44
55from __future__ import annotations
66
7- import atexit
8- import datetime as dt
9- import logging
107import shlex
11- import subprocess
128from typing import TYPE_CHECKING , Any , Optional
139
1410from typing_extensions import override
1511
16- from crossbench import parse
1712from crossbench .plt .arch import MachineArch
1813from crossbench .plt .linux import RemoteLinuxPlatform
19- from crossbench .plt .ssh import SshPlatformMixin , SshPortManager
14+ from crossbench .plt .ssh import SshPlatformMixin
15+ from crossbench .plt .ssh_port_manager import SshPortManager
2016
2117if TYPE_CHECKING :
2218 from crossbench .path import AnyPath , LocalPath
2622
2723class LinuxSshPlatform (SshPlatformMixin , RemoteLinuxPlatform ):
2824
29- PORT_FORWARDING_TIMEOUT = dt .timedelta (seconds = 10 )
3025
3126 def __init__ (self , host_platform : Platform , host : str , port : int ,
3227 ssh_port : int , ssh_user : str ) -> None :
3328 super ().__init__ (host_platform , host , port , ssh_port , ssh_user )
3429 self ._machine : MachineArch | None = None
3530 self ._system_details : dict [str , Any ] | None = None
3631 self ._cpu_details : dict [str , Any ] | None = None
37- # TOOO: create custom PortManager for linux-ssh
38- self ._port_forward_popens : dict [int , subprocess .Popen ] = {}
39- self ._reverse_port_forward_popens : dict [int , subprocess .Popen ] = {}
40- atexit .register (self ._stop_all_port_forward )
4132
4233 def _create_port_manager (self ) -> PortManager :
4334 return SshPortManager (self )
@@ -47,7 +38,7 @@ def _create_port_manager(self) -> PortManager:
4738 def name (self ) -> str :
4839 return "linux_ssh"
4940
50- def _build_ssh_cmd (self , * args : CmdArg , shell : bool = False ) -> ListCmdArgs :
41+ def build_ssh_cmd (self , * args : CmdArg , shell : bool = False ) -> ListCmdArgs :
5142 self .validate_shell_args (args , shell )
5243 ssh_cmd : ListCmdArgs = [
5344 "ssh" , "-p" , f"{ self ._ssh_port } " , f"{ self ._ssh_user } @{ self ._host } "
@@ -64,7 +55,7 @@ def _build_ssh_cmd(self, *args: CmdArg, shell: bool = False) -> ListCmdArgs:
6455
6556 @override
6657 def build_shell_cmd (self , * args : CmdArg , shell : bool = False ) -> ListCmdArgs :
67- return self ._build_ssh_cmd (* args , shell = shell )
58+ return self .build_ssh_cmd (* args , shell = shell )
6859
6960 def processes (self ,
7061 attrs : Optional [list [str ]] = None ) -> list [dict [str , Any ]]:
@@ -99,62 +90,3 @@ def pull(self, from_path: AnyPath, to_path: LocalPath) -> LocalPath:
9990 ]
10091 self ._host_platform .sh_stdout (* scp_cmd )
10192 return to_path
102-
103- def port_forward (self , local_port : int , remote_port : int ) -> int :
104- local_port , remote_port = self ._validate_forwarding_ports (
105- local_port , remote_port )
106- self ._port_forward_popens [local_port ] = self .host_platform .popen (
107- * self ._build_ssh_cmd ("-NL" , f"{ local_port } :localhost:{ remote_port } " ))
108- self .host_platform .wait_for_port (local_port , self .PORT_FORWARDING_TIMEOUT )
109- logging .debug ("Forwarded Remote Port: %s:%s <= %s:%s" , self ._host_platform ,
110- local_port , self , remote_port )
111- return local_port
112-
113- def _validate_forwarding_ports (self , local_port : int ,
114- remote_port : int ) -> tuple [int , int ]:
115- local_port = parse .NumberParser .positive_zero_int (local_port , "local_port" )
116- remote_port = parse .NumberParser .port_number (remote_port , "remote_port" )
117- if not local_port :
118- local_port = self .host_platform .get_free_port ()
119- if local_port in self ._port_forward_popens :
120- raise RuntimeError (f"Cannot forward local port { local_port } twice." )
121- return local_port , remote_port
122-
123- def stop_port_forward (self , local_port : int ) -> None :
124- self ._port_forward_popens .pop (local_port ).terminate ()
125-
126- def reverse_port_forward (self , remote_port : int , local_port : int ) -> int :
127- # TODO: this should likely match with adb, where we support 0
128- # for auto-allocating a remote_port
129- remote_port , local_port = self ._validate_reverse_forwarding_ports (
130- remote_port , local_port )
131- self ._reverse_port_forward_popens [remote_port ] = self .host_platform .popen (
132- * self ._build_ssh_cmd ("-NR" , f"{ remote_port } :localhost:{ local_port } " ))
133- self .wait_for_port (remote_port , self .PORT_FORWARDING_TIMEOUT )
134- logging .debug ("Forwarded Local Port: %s:%s => %s:%s" , self ._host_platform ,
135- local_port , self , remote_port )
136- return remote_port
137-
138- def _validate_reverse_forwarding_ports (self , remote_port : int ,
139- local_port : int ) -> tuple [int , int ]:
140- remote_port = parse .NumberParser .port_number (remote_port , "remote_port" )
141- local_port = parse .NumberParser .positive_zero_int (local_port , "local_port" )
142- if not local_port :
143- local_port = self .host_platform .get_free_port ()
144- if remote_port in self ._reverse_port_forward_popens :
145- raise RuntimeError (f"Cannot forward remote port { remote_port } twice." )
146- return remote_port , local_port
147-
148- def stop_reverse_port_forward (self , remote_port : int ) -> None :
149- self ._reverse_port_forward_popens .pop (remote_port ).terminate ()
150-
151- def _stop_all_port_forward (self ) -> None :
152- for port in list (self ._port_forward_popens .keys ()):
153- self .stop_port_forward (port )
154- for port in list (self ._reverse_port_forward_popens .keys ()):
155- self .stop_reverse_port_forward (port )
156-
157- assert not self ._port_forward_popens , (
158- "Did not stop all port forwarding processes." )
159- assert not self ._reverse_port_forward_popens , (
160- "Did not stop all reverse port forwarding processes." )
0 commit comments