1010# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
13- """An entry point for runtime environment. This must be kept independent of SageMaker PySDK"""
13+ """An utils function for runtime environment. This must be kept independent of SageMaker PySDK"""
1414from __future__ import absolute_import
1515
1616import argparse
2121import time
2222from typing import List
2323
24+ import paramiko
25+
2426if __package__ is None or __package__ == "" :
2527 from runtime_environment_manager import (
2628 get_logger ,
4345logger = get_logger ()
4446
4547
48+ class CustomHostKeyPolicy (paramiko .client .MissingHostKeyPolicy ):
49+ """Class to handle host key policy for SageMaker distributed training SSH connections.
50+
51+ Example:
52+ >>> client = paramiko.SSHClient()
53+ >>> client.set_missing_host_key_policy(CustomHostKeyPolicy())
54+ >>> # Will succeed for SageMaker algorithm containers
55+ >>> client.connect('algo-1234.internal')
56+ >>> # Will raise SSHException for other unknown hosts
57+ >>> client.connect('unknown-host') # raises SSHException
58+ """
59+
60+ def missing_host_key (self , client , hostname , key ):
61+ """Accept host keys for algo-* hostnames, reject others.
62+
63+ Args:
64+ client: The SSHClient instance
65+ hostname: The hostname attempting to connect
66+ key: The host key
67+ Raises:
68+ paramiko.SSHException: If hostname doesn't match algo-* pattern
69+ """
70+ if hostname .startswith ("algo-" ):
71+ client .get_host_keys ().add (hostname , key .get_name (), key )
72+ return
73+ raise paramiko .SSHException (f"Unknown host key for { hostname } " )
74+
75+
4676def _parse_args (sys_args ):
4777 """Parses CLI arguments."""
4878 parser = argparse .ArgumentParser ()
@@ -54,16 +84,12 @@ def _parse_args(sys_args):
5484def _can_connect (host : str , port : int = DEFAULT_SSH_PORT ) -> bool :
5585 """Check if the connection to the provided host and port is possible."""
5686 try :
57- import paramiko
58-
59- logger .debug ("Testing connection to host %s" , host )
60- client = paramiko .SSHClient ()
61- client .load_system_host_keys ()
62- client .set_missing_host_key_policy (paramiko .AutoAddPolicy ())
63- client .connect (host , port = port )
64- client .close ()
65- logger .info ("Can connect to host %s" , host )
66- return True
87+ with paramiko .SSHClient () as client :
88+ client .load_system_host_keys ()
89+ client .set_missing_host_key_policy (CustomHostKeyPolicy ())
90+ client .connect (host , port = port )
91+ logger .info ("Can connect to host %s" , host )
92+ return True
6793 except Exception as e : # pylint: disable=W0703
6894 logger .info ("Cannot connect to host %s" , host )
6995 logger .debug ("Connection failed with exception: %s" , e )
0 commit comments