1- import logging
21import os
2+ import socket
33import subprocess
44import tempfile
55import platform
6+ import time
67
78# we support both pg8000 and psycopg2
89try :
@@ -48,10 +49,10 @@ def __init__(self, conn_params: ConnectionParams):
4849 self .ssh_key = conn_params .ssh_key
4950 self .port = conn_params .port
5051 self .ssh_cmd = ["-o StrictHostKeyChecking=no" ]
51- if self .port :
52- self .ssh_cmd += ["-p" , self .port ]
5352 if self .ssh_key :
5453 self .ssh_cmd += ["-i" , self .ssh_key ]
54+ if self .port :
55+ self .ssh_cmd += ["-p" , self .port ]
5556 self .remote = True
5657 self .username = conn_params .username or self .get_user ()
5758 self .tunnel_process = None
@@ -62,17 +63,36 @@ def __enter__(self):
6263 def __exit__ (self , exc_type , exc_val , exc_tb ):
6364 self .close_ssh_tunnel ()
6465
66+ @staticmethod
67+ def is_port_open (host , port ):
68+ with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as sock :
69+ sock .settimeout (1 ) # Таймаут для попытки соединения
70+ try :
71+ sock .connect ((host , port ))
72+ return True
73+ except socket .error :
74+ return False
75+
6576 def establish_ssh_tunnel (self , local_port , remote_port ):
6677 """
6778 Establish an SSH tunnel from a local port to a remote PostgreSQL port.
6879 """
6980 ssh_cmd = ['-N' , '-L' , f"{ local_port } :localhost:{ remote_port } " ]
7081 self .tunnel_process = self .exec_command (ssh_cmd , get_process = True , timeout = 300 )
82+ timeout = 10
83+ start_time = time .time ()
84+ while time .time () - start_time < timeout :
85+ if self .is_port_open ('localhost' , local_port ):
86+ print ("SSH tunnel established." )
87+ return
88+ time .sleep (0.5 )
89+ raise Exception ("Failed to establish SSH tunnel within the timeout period." )
7190
7291 def close_ssh_tunnel (self ):
73- if hasattr ( self , ' tunnel_process' ) :
92+ if self . tunnel_process :
7493 self .tunnel_process .terminate ()
7594 self .tunnel_process .wait ()
95+ print ("SSH tunnel closed." )
7696 del self .tunnel_process
7797 else :
7898 print ("No active tunnel to close." )
@@ -238,9 +258,9 @@ def mkdtemp(self, prefix=None):
238258 - prefix (str): The prefix of the temporary directory name.
239259 """
240260 if prefix :
241- command = ["ssh" ] + self . ssh_cmd + [ f"{ self .username } @{ self .host } " , f"mktemp -d { prefix } XXXXX" ]
261+ command = ["ssh" + f"{ self .username } @{ self .host } " ] + self . ssh_cmd + [ f"mktemp -d { prefix } XXXXX" ]
242262 else :
243- command = ["ssh" ] + self . ssh_cmd + [ f"{ self .username } @{ self .host } " , "mktemp -d" ]
263+ command = ["ssh" , f"{ self .username } @{ self .host } " ] + self . ssh_cmd + [ "mktemp -d" ]
244264
245265 result = subprocess .run (command , stdout = subprocess .PIPE , stderr = subprocess .PIPE , text = True )
246266
@@ -283,7 +303,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
283303 mode = "r+b" if binary else "r+"
284304
285305 with tempfile .NamedTemporaryFile (mode = mode , delete = False ) as tmp_file :
286- # Because in scp we set up port using -P option instead -p
306+ # Because in scp we set up port using -P option
287307 scp_ssh_cmd = ['-P' if x == '-p' else x for x in self .ssh_cmd ]
288308
289309 if not truncate :
@@ -305,9 +325,9 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
305325 tmp_file .flush ()
306326 scp_cmd = ['scp' ] + scp_ssh_cmd + [tmp_file .name , f"{ self .username } @{ self .host } :{ filename } " ]
307327 subprocess .run (scp_cmd , check = True )
308- remote_directory = os .path .dirname (filename )
309328
310- mkdir_cmd = ['ssh' ] + self .ssh_cmd + [f"{ self .username } @{ self .host } " , f'mkdir -p { remote_directory } ' ]
329+ remote_directory = os .path .dirname (filename )
330+ mkdir_cmd = ['ssh' , f"{ self .username } @{ self .host } " ] + self .ssh_cmd + [f"mkdir -p { remote_directory } " ]
311331 subprocess .run (mkdir_cmd , check = True )
312332
313333 os .remove (tmp_file .name )
@@ -372,7 +392,7 @@ def get_pid(self):
372392 return int (self .exec_command ("echo $$" , encoding = get_default_encoding ()))
373393
374394 def get_process_children (self , pid ):
375- command = ["ssh" ] + self . ssh_cmd + [ f"{ self .username } @{ self .host } " , f"pgrep -P { pid } " ]
395+ command = ["ssh" , f"{ self .username } @{ self .host } " ] + self . ssh_cmd + [ f"pgrep -P { pid } " ]
376396
377397 result = subprocess .run (command , stdout = subprocess .PIPE , stderr = subprocess .PIPE , text = True )
378398
@@ -387,15 +407,16 @@ def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
387407 """
388408 Establish SSH tunnel and connect to a PostgreSQL database.
389409 """
390- self . establish_ssh_tunnel ( local_port = port , remote_port = self . conn_params . port )
391-
410+ local_port = reserve_port ( )
411+ self . establish_ssh_tunnel ( local_port = local_port , remote_port = port )
392412 try :
393413 conn = pglib .connect (
394414 host = host ,
395- port = port ,
415+ port = local_port ,
396416 database = dbname ,
397417 user = user ,
398418 password = password ,
419+ timeout = 10
399420 )
400421 print ("Database connection established successfully." )
401422 return conn
0 commit comments