Skip to content

Commit c3ee280

Browse files
committed
refactor cli and ssh
1 parent 141c34d commit c3ee280

File tree

6 files changed

+84
-38
lines changed

6 files changed

+84
-38
lines changed

bale/drawer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def toggle_drawer():
6666
)
6767
self._table.tailwind.width("full")
6868
self._table.visible = False
69-
for name in ssh.get_hosts("data"):
69+
for name in ssh.get_hosts():
7070
self._add_host_to_table(name)
7171
chevron = ui.button(icon="chevron_left", color=None, on_click=toggle_drawer).props("padding=0px")
7272
chevron.classes("absolute")
@@ -87,7 +87,7 @@ async def _display_host_dialog(self, name=""):
8787
save = None
8888

8989
async def send_key():
90-
s = ssh.Ssh("data", host=host_input.value, hostname=hostname_input.value, username=username_input.value, password=password_input.value)
90+
s = ssh.Ssh(host_input.value, hostname=hostname_input.value, username=username_input.value, password=password_input.value)
9191
result = await s.send_key()
9292
if result.stdout.strip() != "":
9393
el.notify(result.stdout.strip(), multi_line=True, type="positive")
@@ -110,12 +110,12 @@ async def send_key():
110110
c.tailwind.width("full")
111111
with ui.scroll_area() as s:
112112
s.tailwind.height("[160px]")
113-
public_key = await ssh.get_public_key("data")
113+
public_key = await ssh.get_public_key()
114114
ui.label(public_key).classes("text-secondary break-all")
115115
el.DButton("SAVE", on_click=lambda: host_dialog.submit("save")).bind_enabled_from(save_em, "no_errors")
116116
host_input.value = name
117117
if name != "":
118-
s = ssh.Ssh(path="data", host=name)
118+
s = ssh.Ssh(name)
119119
hostname_input.value = s.hostname
120120
username_input.value = s.username
121121

@@ -125,11 +125,11 @@ async def send_key():
125125
default = Tab(spinner=None).common.get("default", "")
126126
if default == name:
127127
Tab(spinner=None).common["default"] = ""
128-
ssh.Ssh(path="data", host=name).remove()
128+
ssh.Ssh(name).remove()
129129
for row in self._table.rows:
130130
if name == row["name"]:
131131
self._table.remove_rows(row)
132-
ssh.Ssh(path="data", host=host_input.value, hostname=hostname_input.value, username=username_input.value)
132+
ssh.Ssh(host_input.value, hostname=hostname_input.value, username=username_input.value)
133133
self._add_host_to_table(host_input.value)
134134

135135
def _modify_host(self, mode):
@@ -162,7 +162,7 @@ async def _selected(self, e):
162162
if self._selection_mode == "remove":
163163
if len(e.selection) > 0:
164164
for row in e.selection:
165-
ssh.Ssh(path="data", host=row["name"]).remove()
165+
ssh.Ssh(row["name"]).remove()
166166
self._table.remove_rows(row)
167167
self._modify_host(None)
168168

bale/interfaces/cli.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,30 +116,49 @@ async def execute(self, command: str, max_output_lines: int = 0) -> Result:
116116
self._terminate.clear()
117117
self._busy = False
118118
return Result(
119-
command=command, return_code=process.returncode, stdout_lines=self.stdout.copy(), stderr_lines=self.stderr.copy(), terminated=terminated, truncated=self._truncated
119+
command=command,
120+
return_code=process.returncode,
121+
stdout_lines=self.stdout.copy(),
122+
stderr_lines=self.stderr.copy(),
123+
terminated=terminated,
124+
truncated=self._truncated,
120125
)
121126

122-
async def shell(self, command: str) -> Result:
127+
async def shell(self, command: str, max_output_lines: int = 0) -> Result:
123128
self._busy = True
124129
try:
125130
process = await asyncio.create_subprocess_shell(command, stdout=PIPE, stderr=PIPE)
126131
if process is not None and process.stdout is not None and process.stderr is not None:
127-
self.clear_buffers()
132+
self.stdout.clear()
133+
self.stderr.clear()
128134
self._terminate.clear()
135+
self._truncated = False
136+
terminated = False
129137
now = datetime.now().strftime("%Y/%m/%d %H:%M:%S")
130138
self.prefix_line = f"<{now}> {command}\n"
131139
for terminal in self._stdout_terminals:
132140
terminal.call_terminal_method("write", "\n" + self.prefix_line)
133141
await asyncio.gather(
142+
self._controller(process=process, max_output_lines=max_output_lines),
134143
self._read_stdout(stream=process.stdout),
135144
self._read_stderr(stream=process.stderr),
136145
)
146+
if self._terminate.is_set():
147+
terminated = True
137148
await process.wait()
138149
except Exception as e:
139150
raise e
140151
finally:
152+
self._terminate.clear()
141153
self._busy = False
142-
return Result(command=command, return_code=process.returncode, stdout_lines=self.stdout.copy(), stderr_lines=self.stderr.copy(), terminated=False)
154+
return Result(
155+
command=command,
156+
return_code=process.returncode,
157+
stdout_lines=self.stdout.copy(),
158+
stderr_lines=self.stderr.copy(),
159+
terminated=terminated,
160+
truncated=self._truncated,
161+
)
143162

144163
def clear_buffers(self):
145164
self.prefix_line = ""

bale/interfaces/ssh.py

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
from typing import Dict, Union
1+
from typing import Dict, Optional, Union
22
import os
3-
import asyncio
43
from pathlib import Path
5-
from bale.result import Result
6-
from bale.interfaces.cli import Cli
4+
from bale.interfaces import cli
75

86

9-
def get_hosts(path):
7+
def get_hosts(path: str = "data"):
108
path = f"{Path(path).resolve()}/config"
119
hosts = []
1210
try:
@@ -20,32 +18,42 @@ def get_hosts(path):
2018
return []
2119

2220

23-
async def get_public_key(path: str) -> str:
21+
async def get_public_key(path: str = "data") -> str:
2422
path = Path(path).resolve()
2523
if "id_rsa.pub" not in os.listdir(path) or "id_rsa" not in os.listdir(path):
26-
await Cli().shell(f"""ssh-keygen -t rsa -N "" -f {path}/id_rsa""")
24+
await cli.Cli().shell(f"""ssh-keygen -t rsa -N "" -f {path}/id_rsa""")
2725
with open(f"{path}/id_rsa.pub", "r", encoding="utf-8") as reader:
2826
return reader.read()
2927

3028

31-
class Ssh(Cli):
32-
def __init__(self, path: str, host: str, hostname: str = "", username: str = "", password: Union[str, None] = None, seperator: bytes = b"\n") -> None:
29+
class Ssh(cli.Cli):
30+
def __init__(
31+
self,
32+
host: str,
33+
hostname: str = "",
34+
username: str = "",
35+
password: Optional[str] = None,
36+
options: Optional[Dict[str, str]] = None,
37+
path: str = "data",
38+
seperator: bytes = b"\n",
39+
) -> None:
3340
super().__init__(seperator=seperator)
3441
self._raw_path: str = path
3542
self._path: Path = Path(path).resolve()
36-
self.host: str = host
43+
self.host: str = host.replace(" ", "")
3744
self.password: Union[str, None] = password
3845
self.use_key: bool = False
3946
if password is None:
4047
self.use_key = True
48+
self.options: Optional[Dict[str, str]] = options
4149
self.key_path: str = f"{self._path}/id_rsa"
42-
self._base_cmd: str = ""
43-
self._full_cmd: str = ""
50+
self._base_command: str = ""
51+
self._full_command: str = ""
4452
self._config_path: str = f"{self._path}/config"
4553
self._config: Dict[str, Dict[str, str]] = {}
4654
self.read_config()
47-
self.hostname: str = hostname or self._config.get(host, {}).get("HostName", "")
48-
self.username: str = username or self._config.get(host, {}).get("User", "")
55+
self.hostname: str = hostname or self._config.get(host.replace(" ", ""), {}).get("HostName", "")
56+
self.username: str = username or self._config.get(host.replace(" ", ""), {}).get("User", "")
4957
self.set_config()
5058

5159
def read_config(self) -> None:
@@ -57,7 +65,7 @@ def read_config(self) -> None:
5765
if line == "" or line.startswith("#"):
5866
continue
5967
if line.startswith("Host "):
60-
current_host = line.split(" ")[1].strip()
68+
current_host = line.split(" ", 1)[1].strip().replace('"', "")
6169
self._config[current_host] = {}
6270
else:
6371
key, value = line.split(" ", 1)
@@ -76,30 +84,40 @@ def write_config(self) -> None:
7684
def set_config(self) -> None:
7785
self._config[self.host] = {
7886
"IdentityFile": self.key_path,
79-
"PasswordAuthentication": "no",
8087
"StrictHostKeychecking": "no",
8188
"IdentitiesOnly": "yes",
8289
}
90+
self._config[self.host]["PasswordAuthentication"] = "no" if self.password is None else "yes"
8391
if self.hostname != "":
8492
self._config[self.host]["HostName"] = self.hostname
8593
if self.username != "":
8694
self._config[self.host]["User"] = self.username
95+
if self.options is not None:
96+
self._config[self.host].update(self.options)
8797
self.write_config()
8898

8999
def remove(self) -> None:
90100
del self._config[self.host]
91101
self.write_config()
92102

93-
async def execute(self, command: str, max_output_lines: int = 0) -> Result:
94-
self._base_cmd = f"{'' if self.use_key else f'sshpass -p {self.password} '} ssh -F {self._config_path} {self.host}"
95-
self._full_cmd = f"{self._base_cmd} {command}"
96-
return await super().execute(self._full_cmd, max_output_lines)
103+
async def execute(self, command: str, max_output_lines: int = 0) -> cli.Result:
104+
self._full_command = f"{self.base_command} {command}"
105+
return await super().execute(self._full_command, max_output_lines)
97106

98-
async def send_key(self) -> Result:
107+
async def shell(self, command: str, max_output_lines: int = 0) -> cli.Result:
108+
self._full_command = f"{self.base_command} {command}"
109+
return await super().shell(self._full_command, max_output_lines)
110+
111+
async def send_key(self) -> cli.Result:
99112
await get_public_key(self._raw_path)
100113
cmd = f"sshpass -p {self.password} " f"ssh-copy-id -o IdentitiesOnly=yes -i {self.key_path} " f"-o StrictHostKeychecking=no {self.username}@{self.hostname}"
101-
return await super().execute(cmd)
114+
return await super().shell(cmd)
102115

103116
@property
104117
def config_path(self):
105118
return self._config_path
119+
120+
@property
121+
def base_command(self):
122+
self._base_command = f'{"" if self.use_key else f"sshpass -p {self.password} "} ssh -F {self._config_path} {self.host}'
123+
return self._base_command

bale/interfaces/zfs.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Union
1+
from typing import Any, Dict, Optional, Union
22
import re
33
from datetime import datetime
44
from dataclasses import dataclass
@@ -243,8 +243,17 @@ async def snapshots(self) -> Result:
243243

244244

245245
class Ssh(ssh.Ssh, Zfs):
246-
def __init__(self, path: str, host: str, hostname: str = "", username: str = "", password: Union[str, None] = None) -> None:
247-
super().__init__(path, host, hostname, username, password)
246+
def __init__(
247+
self,
248+
host: str,
249+
hostname: str = "",
250+
username: str = "",
251+
password: Optional[str] = None,
252+
options: Optional[Dict[str, str]] = None,
253+
path: str = "data",
254+
seperator: bytes = b"\n",
255+
) -> None:
256+
super().__init__(host, hostname, username, password, options, path, seperator)
248257
Zfs.__init__(self)
249258

250259
def notify(self, command: str):

bale/tabs/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def _build(self):
8383

8484
@classmethod
8585
def register_connection(cls, host: str) -> None:
86-
cls._zfs[host] = Ssh(path="data", host=host)
86+
cls._zfs[host] = Ssh(host)
8787

8888
async def _display_result(self, result: Result) -> None:
8989
with ui.dialog() as dialog, el.Card():

bale/tabs/automation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def populate_job_handler(app: str, job_id: str, host: str):
4848
tab = Tab(host=None, spinner=None)
4949
if job_id not in job_handlers:
5050
if app == "remote":
51-
job_handlers[job_id] = ssh.Ssh("data", host=host)
51+
job_handlers[job_id] = ssh.Ssh(host)
5252
else:
5353
job_handlers[job_id] = cli.Cli()
5454
return job_handlers[job_id]

0 commit comments

Comments
 (0)