@@ -50,10 +50,10 @@ def __init__(self, conn_params: ConnectionParams):
5050 self .ssh_key = conn_params .ssh_key
5151 self .port = conn_params .port
5252 self .ssh_cmd = ["-o StrictHostKeyChecking=no" ]
53- if self .ssh_key :
54- self .ssh_cmd += ["-i" , self .ssh_key ]
5553 if self .port :
5654 self .ssh_cmd += ["-p" , self .port ]
55+ if self .ssh_key :
56+ self .ssh_cmd += ["-i" , self .ssh_key ]
5757 self .remote = True
5858 self .username = conn_params .username or self .get_user ()
5959 self .tunnel_process = None
@@ -285,6 +285,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
285285 mode = "r+b" if binary else "r+"
286286
287287 with tempfile .NamedTemporaryFile (mode = mode , delete = False ) as tmp_file :
288+ # Because in scp we set up port using -P option instead -p
288289 scp_ssh_cmd = ['-P' if x == '-p' else x for x in self .ssh_cmd ]
289290
290291 if not truncate :
@@ -304,12 +305,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
304305 tmp_file .write (data )
305306
306307 tmp_file .flush ()
307- # Because in scp we set up port using -P option
308308 scp_cmd = ['scp' ] + scp_ssh_cmd + [tmp_file .name , f"{ self .username } @{ self .host } :{ filename } " ]
309309 subprocess .run (scp_cmd , check = True )
310-
311310 remote_directory = os .path .dirname (filename )
312- mkdir_cmd = ['ssh' ] + scp_ssh_cmd + [f"{ self .username } @{ self .host } " , f"mkdir -p { remote_directory } " ]
311+
312+ mkdir_cmd = ['ssh' ] + self .ssh_cmd + [f"{ self .username } @{ self .host } " , f'mkdir -p { remote_directory } ' ]
313313 subprocess .run (mkdir_cmd , check = True )
314314
315315 os .remove (tmp_file .name )
@@ -387,9 +387,10 @@ def get_process_children(self, pid):
387387 # Database control
388388 def db_connect (self , dbname , user , password = None , host = "localhost" , port = 5432 ):
389389 """
390- Established SSH tunnel and Connects to a PostgreSQL
390+ Establish SSH tunnel and connect to a PostgreSQL database.
391391 """
392- self .establish_ssh_tunnel (local_port = reserve_port (), remote_port = 5432 )
392+ self .establish_ssh_tunnel (local_port = port , remote_port = self .conn_params .port )
393+
393394 try :
394395 conn = pglib .connect (
395396 host = host ,
@@ -398,6 +399,11 @@ def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
398399 user = user ,
399400 password = password ,
400401 )
402+ print ("Database connection established successfully." )
401403 return conn
402404 except Exception as e :
403- raise Exception (f"Could not connect to the database. Error: { e } " )
405+ print (f"Error connecting to the database: { str (e )} " )
406+ if self .tunnel_process :
407+ self .tunnel_process .terminate ()
408+ print ("SSH tunnel closed due to connection failure." )
409+ raise
0 commit comments