1818import time
1919from typing import List
2020
21+ import paramiko
2122from utils import SM_EFA_NCCL_INSTANCES , SM_EFA_RDMA_INSTANCES , get_python_executable , logger
2223
2324FINISHED_STATUS_FILE = "/tmp/done.algo-1"
@@ -74,6 +75,24 @@ def start_sshd_daemon():
7475 logger .info ("Started SSH daemon." )
7576
7677
78+ class CustomHostKeyPolicy (paramiko .client .MissingHostKeyPolicy ):
79+ def missing_host_key (self , client , hostname , key ):
80+ """Accept host keys for algo-* hostnames, reject others.
81+
82+ Args:
83+ client: The SSHClient instance
84+ hostname: The hostname attempting to connect
85+ key: The host key
86+
87+ Raises:
88+ paramiko.SSHException: If hostname doesn't match algo-* pattern
89+ """
90+ if hostname .startswith ("algo-" ):
91+ client .get_host_keys ().add (hostname , key .get_name (), key )
92+ return
93+ raise paramiko .SSHException (f"Unknown host key for { hostname } " )
94+
95+
7796def _can_connect (host : str , port : int = DEFAULT_SSH_PORT ) -> bool :
7897 """Check if the connection to the provided host and port is possible."""
7998 try :
@@ -82,7 +101,7 @@ def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool:
82101 logger .debug ("Testing connection to host %s" , host )
83102 client = paramiko .SSHClient ()
84103 client .load_system_host_keys ()
85- client .set_missing_host_key_policy (paramiko . RejectPolicy ())
104+ client .set_missing_host_key_policy (CustomHostKeyPolicy ())
86105 client .connect (host , port = port )
87106 client .close ()
88107 logger .info ("Can connect to host %s" , host )
0 commit comments