1- import logging
21import os
2+ import socket
33import subprocess
44import tempfile
55import platform
6+ import time
67
78from ..utils import reserve_port
89
@@ -50,10 +51,10 @@ def __init__(self, conn_params: ConnectionParams):
5051 self .ssh_key = conn_params .ssh_key
5152 self .port = conn_params .port
5253 self .ssh_cmd = ["-o StrictHostKeyChecking=no" ]
53- if self .port :
54- self .ssh_cmd += ["-p" , self .port ]
5554 if self .ssh_key :
5655 self .ssh_cmd += ["-i" , self .ssh_key ]
56+ if self .port :
57+ self .ssh_cmd += ["-p" , self .port ]
5758 self .remote = True
5859 self .username = conn_params .username or self .get_user ()
5960 self .tunnel_process = None
@@ -64,17 +65,36 @@ def __enter__(self):
6465 def __exit__ (self , exc_type , exc_val , exc_tb ):
6566 self .close_ssh_tunnel ()
6667
68+ @staticmethod
69+ def is_port_open (host , port ):
70+ with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as sock :
71+ sock .settimeout (1 ) # Таймаут для попытки соединения
72+ try :
73+ sock .connect ((host , port ))
74+ return True
75+ except socket .error :
76+ return False
77+
6778 def establish_ssh_tunnel (self , local_port , remote_port ):
6879 """
6980 Establish an SSH tunnel from a local port to a remote PostgreSQL port.
7081 """
7182 ssh_cmd = ['-N' , '-L' , f"{ local_port } :localhost:{ remote_port } " ]
7283 self .tunnel_process = self .exec_command (ssh_cmd , get_process = True , timeout = 300 )
84+ timeout = 10
85+ start_time = time .time ()
86+ while time .time () - start_time < timeout :
87+ if self .is_port_open ('localhost' , local_port ):
88+ print ("SSH tunnel established." )
89+ return
90+ time .sleep (0.5 )
91+ raise Exception ("Failed to establish SSH tunnel within the timeout period." )
7392
7493 def close_ssh_tunnel (self ):
75- if hasattr ( self , ' tunnel_process' ) :
94+ if self . tunnel_process :
7695 self .tunnel_process .terminate ()
7796 self .tunnel_process .wait ()
97+ print ("SSH tunnel closed." )
7898 del self .tunnel_process
7999 else :
80100 print ("No active tunnel to close." )
@@ -240,9 +260,9 @@ def mkdtemp(self, prefix=None):
240260 - prefix (str): The prefix of the temporary directory name.
241261 """
242262 if prefix :
243- command = ["ssh" ] + self . ssh_cmd + [ f"{ self .username } @{ self .host } " , f"mktemp -d { prefix } XXXXX" ]
263+ command = ["ssh" + f"{ self .username } @{ self .host } " ] + self . ssh_cmd + [ f"mktemp -d { prefix } XXXXX" ]
244264 else :
245- command = ["ssh" ] + self . ssh_cmd + [ f"{ self .username } @{ self .host } " , "mktemp -d" ]
265+ command = ["ssh" , f"{ self .username } @{ self .host } " ] + self . ssh_cmd + [ "mktemp -d" ]
246266
247267 result = subprocess .run (command , stdout = subprocess .PIPE , stderr = subprocess .PIPE , text = True )
248268
@@ -285,7 +305,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
285305 mode = "r+b" if binary else "r+"
286306
287307 with tempfile .NamedTemporaryFile (mode = mode , delete = False ) as tmp_file :
288- # Because in scp we set up port using -P option instead -p
308+ # Because in scp we set up port using -P option
289309 scp_ssh_cmd = ['-P' if x == '-p' else x for x in self .ssh_cmd ]
290310
291311 if not truncate :
@@ -307,9 +327,9 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
307327 tmp_file .flush ()
308328 scp_cmd = ['scp' ] + scp_ssh_cmd + [tmp_file .name , f"{ self .username } @{ self .host } :{ filename } " ]
309329 subprocess .run (scp_cmd , check = True )
310- remote_directory = os .path .dirname (filename )
311330
312- mkdir_cmd = ['ssh' ] + self .ssh_cmd + [f"{ self .username } @{ self .host } " , f'mkdir -p { remote_directory } ' ]
331+ remote_directory = os .path .dirname (filename )
332+ mkdir_cmd = ['ssh' , f"{ self .username } @{ self .host } " ] + self .ssh_cmd + [f"mkdir -p { remote_directory } " ]
313333 subprocess .run (mkdir_cmd , check = True )
314334
315335 os .remove (tmp_file .name )
@@ -374,7 +394,7 @@ def get_pid(self):
374394 return int (self .exec_command ("echo $$" , encoding = get_default_encoding ()))
375395
376396 def get_process_children (self , pid ):
377- command = ["ssh" ] + self . ssh_cmd + [ f"{ self .username } @{ self .host } " , f"pgrep -P { pid } " ]
397+ command = ["ssh" , f"{ self .username } @{ self .host } " ] + self . ssh_cmd + [ f"pgrep -P { pid } " ]
378398
379399 result = subprocess .run (command , stdout = subprocess .PIPE , stderr = subprocess .PIPE , text = True )
380400
@@ -389,15 +409,16 @@ def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
389409 """
390410 Establish SSH tunnel and connect to a PostgreSQL database.
391411 """
392- self . establish_ssh_tunnel ( local_port = port , remote_port = self . conn_params . port )
393-
412+ local_port = reserve_port ( )
413+ self . establish_ssh_tunnel ( local_port = local_port , remote_port = port )
394414 try :
395415 conn = pglib .connect (
396416 host = host ,
397- port = port ,
417+ port = local_port ,
398418 database = dbname ,
399419 user = user ,
400420 password = password ,
421+ timeout = 10
401422 )
402423 print ("Database connection established successfully." )
403424 return conn
0 commit comments