|
1 | 1 | import getpass |
2 | 2 | import os |
3 | | -import logging |
4 | 3 | import platform |
5 | 4 | import subprocess |
6 | 5 | import tempfile |
@@ -55,40 +54,10 @@ def __init__(self, conn_params: ConnectionParams): |
55 | 54 | self.remote = True |
56 | 55 | self.username = conn_params.username or getpass.getuser() |
57 | 56 | self.ssh_dest = f"{self.username}@{self.host}" if conn_params.username else self.host |
58 | | - self.add_known_host(self.host) |
59 | | - self.tunnel_process = None |
60 | 57 |
|
61 | 58 | def __enter__(self): |
62 | 59 | return self |
63 | 60 |
|
64 | | - def __exit__(self, exc_type, exc_val, exc_tb): |
65 | | - self.close_ssh_tunnel() |
66 | | - |
67 | | - def establish_ssh_tunnel(self, local_port, remote_port): |
68 | | - """ |
69 | | - Establish an SSH tunnel from a local port to a remote PostgreSQL port. |
70 | | - """ |
71 | | - ssh_cmd = ['-N', '-L', f"{local_port}:localhost:{remote_port}"] |
72 | | - self.tunnel_process = self.exec_command(ssh_cmd, get_process=True, timeout=300) |
73 | | - |
74 | | - def close_ssh_tunnel(self): |
75 | | - if hasattr(self, 'tunnel_process'): |
76 | | - self.tunnel_process.terminate() |
77 | | - self.tunnel_process.wait() |
78 | | - del self.tunnel_process |
79 | | - else: |
80 | | - print("No active tunnel to close.") |
81 | | - |
82 | | - def add_known_host(self, host): |
83 | | - known_hosts_path = os.path.expanduser("~/.ssh/known_hosts") |
84 | | - cmd = 'ssh-keyscan -H %s >> %s' % (host, known_hosts_path) |
85 | | - |
86 | | - try: |
87 | | - subprocess.check_call(cmd, shell=True) |
88 | | - logging.info("Successfully added %s to known_hosts." % host) |
89 | | - except subprocess.CalledProcessError as e: |
90 | | - raise Exception("Failed to add %s to known_hosts. Error: %s" % (host, str(e))) |
91 | | - |
92 | 61 | def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, |
93 | 62 | encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None, |
94 | 63 | stderr=None, get_process=None, timeout=None): |
@@ -293,6 +262,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal |
293 | 262 | with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file: |
294 | 263 | # For scp the port is specified by a "-P" option |
295 | 264 | scp_args = ['-P' if x == '-p' else x for x in self.ssh_args] |
| 265 | + |
296 | 266 | if not truncate: |
297 | 267 | scp_cmd = ['scp'] + scp_args + [f"{self.ssh_dest}:{filename}", tmp_file.name] |
298 | 268 | subprocess.run(scp_cmd, check=False) # The file might not exist yet |
@@ -391,18 +361,11 @@ def get_process_children(self, pid): |
391 | 361 |
|
392 | 362 | # Database control |
393 | 363 | def db_connect(self, dbname, user, password=None, host="localhost", port=5432): |
394 | | - """ |
395 | | - Established SSH tunnel and Connects to a PostgreSQL |
396 | | - """ |
397 | | - self.establish_ssh_tunnel(local_port=port, remote_port=5432) |
398 | | - try: |
399 | | - conn = pglib.connect( |
400 | | - host=host, |
401 | | - port=port, |
402 | | - database=dbname, |
403 | | - user=user, |
404 | | - password=password, |
405 | | - ) |
406 | | - return conn |
407 | | - except Exception as e: |
408 | | - raise Exception(f"Could not connect to the database. Error: {e}") |
| 364 | + conn = pglib.connect( |
| 365 | + host=host, |
| 366 | + port=port, |
| 367 | + database=dbname, |
| 368 | + user=user, |
| 369 | + password=password, |
| 370 | + ) |
| 371 | + return conn |
0 commit comments