diff --git a/sagemaker-core/src/sagemaker/core/remote_function/client.py b/sagemaker-core/src/sagemaker/core/remote_function/client.py index b140c03901..a38b57662a 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/client.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/client.py @@ -366,7 +366,7 @@ def wrapper(*args, **kwargs): s3_uri=s3_path_join( job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER ), - hmac_key=job.hmac_key, + ) except ServiceError as serr: chained_e = serr.__cause__ @@ -403,7 +403,7 @@ def wrapper(*args, **kwargs): return serialization.deserialize_obj_from_s3( sagemaker_session=job_settings.sagemaker_session, s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER), - hmac_key=job.hmac_key, + ) if job.describe()["TrainingJobStatus"] == "Stopped": @@ -983,7 +983,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session): job_return = serialization.deserialize_obj_from_s3( sagemaker_session=sagemaker_session, s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER), - hmac_key=job.hmac_key, + ) except DeserializationError as e: client_exception = e @@ -995,7 +995,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session): job_exception = serialization.deserialize_exception_from_s3( sagemaker_session=sagemaker_session, s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER), - hmac_key=job.hmac_key, + ) except ServiceError as serr: chained_e = serr.__cause__ @@ -1085,7 +1085,7 @@ def result(self, timeout: float = None) -> Any: self._return = serialization.deserialize_obj_from_s3( sagemaker_session=self._job.sagemaker_session, s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER), - hmac_key=self._job.hmac_key, + ) self._state = _FINISHED return self._return @@ -1094,7 +1094,7 @@ def result(self, timeout: float = None) -> Any: self._exception = serialization.deserialize_exception_from_s3( sagemaker_session=self._job.sagemaker_session, s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER), - hmac_key=self._job.hmac_key, + ) except ServiceError as serr: chained_e = serr.__cause__ diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/pipeline_variables.py b/sagemaker-core/src/sagemaker/core/remote_function/core/pipeline_variables.py index 5278306063..491267b35f 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/pipeline_variables.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/core/pipeline_variables.py @@ -164,7 +164,6 @@ class _DelayedReturnResolver: def __init__( self, delayed_returns: List[_DelayedReturn], - hmac_key: str, properties_resolver: _PropertiesResolver, parameter_resolver: _ParameterResolver, execution_variable_resolver: _ExecutionVariableResolver, @@ -175,7 +174,6 @@ def __init__( Args: delayed_returns: list of delayed returns to resolve. - hmac_key: key used to encrypt serialized and deserialized function and arguments. properties_resolver: resolver used to resolve step properties. parameter_resolver: resolver used to pipeline parameters. execution_variable_resolver: resolver used to resolve execution variables. @@ -197,7 +195,6 @@ def deserialization_task(uri): return uri, deserialize_obj_from_s3( sagemaker_session=settings["sagemaker_session"], s3_uri=uri, - hmac_key=hmac_key, ) with ThreadPoolExecutor() as executor: @@ -247,7 +244,6 @@ def resolve_pipeline_variables( context: Context, func_args: Tuple, func_kwargs: Dict, - hmac_key: str, s3_base_uri: str, **settings, ): @@ -257,7 +253,6 @@ def resolve_pipeline_variables( context: context for the execution. func_args: function args. func_kwargs: function kwargs. - hmac_key: key used to encrypt serialized and deserialized function and arguments. s3_base_uri: the s3 base uri of the function step that the serialized artifacts will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name. **settings: settings to pass to the deserialization function. @@ -280,7 +275,6 @@ def resolve_pipeline_variables( properties_resolver = _PropertiesResolver(context) delayed_return_resolver = _DelayedReturnResolver( delayed_returns=delayed_returns, - hmac_key=hmac_key, properties_resolver=properties_resolver, parameter_resolver=parameter_resolver, execution_variable_resolver=execution_variable_resolver, diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py b/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py index 39517bdc6b..8871f6727f 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py @@ -19,7 +19,6 @@ import io import sys -import hmac import hashlib import pickle @@ -156,7 +155,7 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any: # TODO: use dask serializer in case dask distributed is installed in users' environment. def serialize_func_to_s3( - func: Callable, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None + func: Callable, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None ): """Serializes function and uploads it to S3. @@ -164,7 +163,6 @@ def serialize_func_to_s3( sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. func: function to be serialized and persisted Raises: @@ -173,14 +171,13 @@ def serialize_func_to_s3( _upload_payload_and_metadata_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(func), - hmac_key=hmac_key, s3_uri=s3_uri, sagemaker_session=sagemaker_session, s3_kms_key=s3_kms_key, ) -def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Callable: +def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callable: """Downloads from S3 and then deserializes data objects. This method downloads the serialized training job outputs to a temporary directory and @@ -190,7 +187,6 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: sagemaker_session (sagemaker.core.helper.session.Session): The underlying sagemaker session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func. Returns : The deserialized function. Raises: @@ -203,14 +199,14 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) def serialize_obj_to_s3( - obj: Any, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None + obj: Any, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None ): """Serializes data object and uploads it to S3. @@ -219,7 +215,6 @@ def serialize_obj_to_s3( The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj. obj: object to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. @@ -227,7 +222,6 @@ def serialize_obj_to_s3( _upload_payload_and_metadata_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(obj), - hmac_key=hmac_key, s3_uri=s3_uri, sagemaker_session=sagemaker_session, s3_kms_key=s3_kms_key, @@ -274,14 +268,13 @@ def json_serialize_obj_to_s3( ) -def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any: +def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any: """Downloads from S3 and then deserializes data objects. Args: sagemaker_session (sagemaker.core.helper.session.Session): The underlying sagemaker session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj. Returns : Deserialized python objects. Raises: @@ -295,14 +288,14 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: s bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) def serialize_exception_to_s3( - exc: Exception, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None + exc: Exception, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None ): """Serializes exception with traceback and uploads it to S3. @@ -311,7 +304,6 @@ def serialize_exception_to_s3( The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception. exc: Exception to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. @@ -320,7 +312,6 @@ def serialize_exception_to_s3( _upload_payload_and_metadata_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(exc), - hmac_key=hmac_key, s3_uri=s3_uri, sagemaker_session=sagemaker_session, s3_kms_key=s3_kms_key, @@ -329,7 +320,6 @@ def serialize_exception_to_s3( def _upload_payload_and_metadata_to_s3( bytes_to_upload: Union[bytes, io.BytesIO], - hmac_key: str, s3_uri: str, sagemaker_session: Session, s3_kms_key, @@ -338,7 +328,6 @@ def _upload_payload_and_metadata_to_s3( Args: bytes_to_upload (bytes): Serialized bytes to upload. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which AWS service calls are delegated to. @@ -346,7 +335,7 @@ def _upload_payload_and_metadata_to_s3( """ _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session) - sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key) + sha256_hash = _compute_hash(bytes_to_upload) _upload_bytes_to_s3( _MetaData(sha256_hash).to_json(), @@ -356,14 +345,13 @@ def _upload_payload_and_metadata_to_s3( ) -def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any: +def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> Any: """Downloads from S3 and then deserializes exception. Args: sagemaker_session (sagemaker.core.helper.session.Session): The underlying sagemaker session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception. Returns : Deserialized exception with traceback. Raises: @@ -377,7 +365,7 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_ bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) @@ -403,19 +391,19 @@ def _read_bytes_from_s3(s3_uri, sagemaker_session): ) from e -def _compute_hash(buffer: bytes, secret_key: str) -> str: - """Compute the hmac-sha256 hash""" - return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest() +def _compute_hash(buffer: bytes) -> str: + """Compute the sha256 hash""" + return hashlib.sha256(buffer).hexdigest() -def _perform_integrity_check(expected_hash_value: str, secret_key: str, buffer: bytes): +def _perform_integrity_check(expected_hash_value: str, buffer: bytes): """Performs integrity checks for serialized code/arguments uploaded to s3. Verifies whether the hash read from s3 matches the hash calculated during remote function execution. """ - actual_hash_value = _compute_hash(buffer=buffer, secret_key=secret_key) - if not hmac.compare_digest(expected_hash_value, actual_hash_value): + actual_hash_value = _compute_hash(buffer=buffer) + if expected_hash_value != actual_hash_value: raise DeserializationError( "Integrity check for the serialized function or data failed. " "Please restrict access to your S3 bucket" diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py b/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py index 48724d8e36..c7ee86f8a7 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py @@ -55,7 +55,6 @@ def __init__( self, sagemaker_session: Session, s3_base_uri: str, - hmac_key: str, s3_kms_key: str = None, context: Context = Context(), ): @@ -66,13 +65,11 @@ def __init__( AWS service calls are delegated to. s3_base_uri: the base uri to which serialized artifacts will be uploaded. s3_kms_key: KMS key used to encrypt artifacts uploaded to S3. - hmac_key: Key used to encrypt serialized and deserialized function and arguments. context: Build or run context of a pipeline step. """ self.sagemaker_session = sagemaker_session self.s3_base_uri = s3_base_uri self.s3_kms_key = s3_kms_key - self.hmac_key = hmac_key self.context = context # For pipeline steps, function code is at: base/step_name/build_timestamp/ @@ -114,7 +111,7 @@ def save(self, func, *args, **kwargs): sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), s3_kms_key=self.s3_kms_key, - hmac_key=self.hmac_key, + ) logger.info( @@ -126,7 +123,7 @@ def save(self, func, *args, **kwargs): obj=(args, kwargs), sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), - hmac_key=self.hmac_key, + s3_kms_key=self.s3_kms_key, ) @@ -144,7 +141,7 @@ def save_pipeline_step_function(self, serialized_data): ) serialization._upload_payload_and_metadata_to_s3( bytes_to_upload=serialized_data.func, - hmac_key=self.hmac_key, + s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), sagemaker_session=self.sagemaker_session, s3_kms_key=self.s3_kms_key, @@ -156,7 +153,7 @@ def save_pipeline_step_function(self, serialized_data): ) serialization._upload_payload_and_metadata_to_s3( bytes_to_upload=serialized_data.args, - hmac_key=self.hmac_key, + s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), sagemaker_session=self.sagemaker_session, s3_kms_key=self.s3_kms_key, @@ -172,7 +169,7 @@ def load_and_invoke(self) -> Any: func = serialization.deserialize_func_from_s3( sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), - hmac_key=self.hmac_key, + ) logger.info( @@ -182,7 +179,7 @@ def load_and_invoke(self) -> Any: args, kwargs = serialization.deserialize_obj_from_s3( sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), - hmac_key=self.hmac_key, + ) logger.info("Resolving pipeline variables") @@ -190,7 +187,7 @@ def load_and_invoke(self) -> Any: self.context, args, kwargs, - hmac_key=self.hmac_key, + s3_base_uri=self.s3_base_uri, sagemaker_session=self.sagemaker_session, ) @@ -206,7 +203,7 @@ def load_and_invoke(self) -> Any: obj=result, sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.results_upload_path, RESULTS_FOLDER), - hmac_key=self.hmac_key, + s3_kms_key=self.s3_kms_key, ) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/errors.py b/sagemaker-core/src/sagemaker/core/remote_function/errors.py index d12fde52d6..3f391570cf 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/errors.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/errors.py @@ -70,7 +70,7 @@ def _write_failure_reason_file(failure_msg): f.write(failure_msg) -def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, hmac_key) -> int: +def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: """Handle all exceptions raised during remote function execution. Args: @@ -79,7 +79,6 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, hmac_key) -> AWS service calls are delegated to. s3_base_uri (str): S3 root uri to which resulting serialized exception will be uploaded. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - hmac_key (str): Key used to calculate hmac hash of the serialized exception. Returns : exit_code (int): Exit code to terminate current job. """ @@ -97,7 +96,6 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, hmac_key) -> exc=error, sagemaker_session=sagemaker_session, s3_uri=s3_path_join(s3_base_uri, "exception"), - hmac_key=hmac_key, s3_kms_key=s3_kms_key, ) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py b/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py index d353232b57..2e69f4f116 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py @@ -98,7 +98,7 @@ def _load_pipeline_context(args) -> Context: def _execute_remote_function( - sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, hmac_key, context + sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, context ): """Execute stored remote function""" from sagemaker.core.remote_function.core.stored_function import StoredFunction @@ -107,7 +107,6 @@ def _execute_remote_function( sagemaker_session=sagemaker_session, s3_base_uri=s3_base_uri, s3_kms_key=s3_kms_key, - hmac_key=hmac_key, context=context, ) @@ -138,15 +137,12 @@ def main(sys_args=None): run_in_context = args.run_in_context pipeline_context = _load_pipeline_context(args) - hmac_key = os.getenv("REMOTE_FUNCTION_SECRET_KEY") - sagemaker_session = _get_sagemaker_session(region) _execute_remote_function( sagemaker_session=sagemaker_session, s3_base_uri=s3_base_uri, s3_kms_key=s3_kms_key, run_in_context=run_in_context, - hmac_key=hmac_key, context=pipeline_context, ) @@ -162,7 +158,6 @@ def main(sys_args=None): sagemaker_session=sagemaker_session, s3_base_uri=s3_uri, s3_kms_key=s3_kms_key, - hmac_key=hmac_key, ) finally: sys.exit(exit_code) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/job.py b/sagemaker-core/src/sagemaker/core/remote_function/job.py index bed00e148f..435062db57 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/job.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/job.py @@ -17,7 +17,6 @@ import json import os import re -import secrets import shutil import sys import time @@ -621,11 +620,6 @@ def __init__( {"AWS_DEFAULT_REGION": self.sagemaker_session.boto_region_name} ) - # The following will be overridden by the _Job.compile method. - # However, it needs to be kept here for feature store SDK. - # TODO: update the feature store SDK to set the HMAC key there. - self.environment_variables.update({"REMOTE_FUNCTION_SECRET_KEY": secrets.token_hex(32)}) - if spark_config and image_uri: raise ValueError("spark_config and image_uri cannot be specified at the same time!") @@ -839,19 +833,17 @@ def _get_default_spark_image(session): class _Job: """Helper class that interacts with the SageMaker training service.""" - def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session, hmac_key: str): + def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session): """Initialize a _Job object. Args: job_name (str): The training job name. s3_uri (str): The training job output S3 uri. sagemaker_session (Session): SageMaker boto session. - hmac_key (str): Remote function secret key. """ self.job_name = job_name self.s3_uri = s3_uri self.sagemaker_session = sagemaker_session - self.hmac_key = hmac_key self._last_describe_response = None @staticmethod @@ -867,9 +859,8 @@ def from_describe_response(describe_training_job_response, sagemaker_session): """ job_name = describe_training_job_response["TrainingJobName"] s3_uri = describe_training_job_response["OutputDataConfig"]["S3OutputPath"] - hmac_key = describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"] - job = _Job(job_name, s3_uri, sagemaker_session, hmac_key) + job = _Job(job_name, s3_uri, sagemaker_session) job._last_describe_response = describe_training_job_response return job @@ -907,7 +898,6 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non job_name, s3_base_uri, job_settings.sagemaker_session, - training_job_request["Environment"]["REMOTE_FUNCTION_SECRET_KEY"], ) @staticmethod @@ -935,18 +925,11 @@ def compile( jobs_container_entrypoint = JOBS_CONTAINER_ENTRYPOINT[:] - # generate hmac key for integrity check - if step_compilation_context is None: - hmac_key = secrets.token_hex(32) - else: - hmac_key = step_compilation_context.function_step_secret_token - # serialize function and arguments if step_compilation_context is None: stored_function = StoredFunction( sagemaker_session=job_settings.sagemaker_session, s3_base_uri=s3_base_uri, - hmac_key=hmac_key, s3_kms_key=job_settings.s3_kms_key, ) stored_function.save(func, *func_args, **func_kwargs) @@ -954,7 +937,6 @@ def compile( stored_function = StoredFunction( sagemaker_session=job_settings.sagemaker_session, s3_base_uri=s3_base_uri, - hmac_key=hmac_key, s3_kms_key=job_settings.s3_kms_key, context=Context( step_name=step_compilation_context.step_name, @@ -1114,7 +1096,6 @@ def compile( request_dict["EnableManagedSpotTraining"] = job_settings.use_spot_instances request_dict["Environment"] = job_settings.environment_variables - request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key}) extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri) extended_request = _extend_mpirun_to_request(extended_request, job_settings) diff --git a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_bootstrap_runtime_environment.py b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_bootstrap_runtime_environment.py index cc8319f935..461a3ecb73 100644 --- a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_bootstrap_runtime_environment.py +++ b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_bootstrap_runtime_environment.py @@ -10,20 +10,23 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +"""Tests for bootstrap_runtime_environment module.""" +from __future__ import absolute_import -import pytest -from unittest.mock import Mock, patch, mock_open, MagicMock import json -import sys +import os +import pytest +import subprocess +from unittest.mock import patch, MagicMock, mock_open, call from sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment import ( + _parse_args, _bootstrap_runtime_env_for_remote_function, _bootstrap_runtime_env_for_pipeline_step, _handle_pre_exec_scripts, _install_dependencies, _unpack_user_workspace, _write_failure_reason_file, - _parse_args, log_key_value, log_env_variables, mask_sensitive_info, @@ -35,6 +38,11 @@ main, SUCCESS_EXIT_CODE, DEFAULT_FAILURE_CODE, + FAILURE_REASON_PATH, + REMOTE_FUNCTION_WORKSPACE, + BASE_CHANNEL_PATH, + JOB_REMOTE_FUNCTION_WORKSPACE, + SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME, SENSITIVE_KEYWORDS, HIDDEN_VALUE, ) @@ -43,506 +51,629 @@ ) -class TestBootstrapRuntimeEnvironment: - """Test cases for bootstrap runtime environment functions""" - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies" - ) - def test_bootstrap_runtime_env_for_remote_function( - self, mock_install, mock_handle, mock_unpack - ): - """Test _bootstrap_runtime_env_for_remote_function""" - mock_unpack.return_value = "/workspace" - dependency_settings = _DependencySettings(dependency_file="requirements.txt") - - _bootstrap_runtime_env_for_remote_function( - client_python_version="3.8", conda_env="myenv", dependency_settings=dependency_settings - ) - - mock_unpack.assert_called_once() - mock_handle.assert_called_once_with("/workspace") - mock_install.assert_called_once() - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace" - ) - def test_bootstrap_runtime_env_for_remote_function_no_workspace(self, mock_unpack): - """Test _bootstrap_runtime_env_for_remote_function with no workspace""" - mock_unpack.return_value = None - - _bootstrap_runtime_env_for_remote_function(client_python_version="3.8") - - mock_unpack.assert_called_once() - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.exists" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.mkdir" - ) - def test_bootstrap_runtime_env_for_pipeline_step(self, mock_mkdir, mock_exists, mock_unpack): - """Test _bootstrap_runtime_env_for_pipeline_step""" - mock_unpack.return_value = None - mock_exists.return_value = False - - _bootstrap_runtime_env_for_pipeline_step( - client_python_version="3.8", func_step_workspace="workspace" - ) - - mock_mkdir.assert_called_once() - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.isfile" - ) - def test_handle_pre_exec_scripts_exists(self, mock_isfile, mock_manager_class): - """Test _handle_pre_exec_scripts when script exists""" - mock_isfile.return_value = True - mock_manager = Mock() - mock_manager_class.return_value = mock_manager - - _handle_pre_exec_scripts("/workspace") - - mock_manager.run_pre_exec_script.assert_called_once() - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.isfile" - ) - def test_handle_pre_exec_scripts_not_exists(self, mock_isfile, mock_manager_class): - """Test _handle_pre_exec_scripts when script doesn't exist""" - mock_isfile.return_value = False - mock_manager = Mock() - mock_manager_class.return_value = mock_manager - - _handle_pre_exec_scripts("/workspace") - - mock_manager.run_pre_exec_script.assert_not_called() - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.join" - ) - def test_install_dependencies_with_file(self, mock_join, mock_manager_class): - """Test _install_dependencies with dependency file""" - mock_join.return_value = "/workspace/requirements.txt" - mock_manager = Mock() - mock_manager_class.return_value = mock_manager - - dependency_settings = _DependencySettings(dependency_file="requirements.txt") - - _install_dependencies( - dependency_file_dir="/workspace", - conda_env="myenv", - client_python_version="3.8", - channel_name="channel", - dependency_settings=dependency_settings, - ) - - mock_manager.bootstrap.assert_called_once() - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager" - ) - def test_install_dependencies_no_file(self, mock_manager_class): - """Test _install_dependencies with no dependency file""" - mock_manager = Mock() - mock_manager_class.return_value = mock_manager - - dependency_settings = _DependencySettings(dependency_file=None) - - _install_dependencies( - dependency_file_dir="/workspace", - conda_env=None, - client_python_version="3.8", - channel_name="channel", - dependency_settings=dependency_settings, - ) - - mock_manager.bootstrap.assert_not_called() - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.exists" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.isfile" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.shutil.unpack_archive" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.pathlib.Path" - ) - def test_unpack_user_workspace_success(self, mock_path, mock_unpack, mock_isfile, mock_exists): - """Test _unpack_user_workspace successfully unpacks workspace""" - mock_exists.return_value = True - mock_isfile.return_value = True - mock_path.return_value.absolute.return_value = "/workspace" - - result = _unpack_user_workspace() - - assert result is not None - mock_unpack.assert_called_once() - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.exists" - ) - def test_unpack_user_workspace_no_directory(self, mock_exists): - """Test _unpack_user_workspace when directory doesn't exist""" - mock_exists.return_value = False - - result = _unpack_user_workspace() - - assert result is None - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.exists" - ) - @patch("builtins.open", new_callable=mock_open) - def test_write_failure_reason_file(self, mock_file, mock_exists): - """Test _write_failure_reason_file""" - mock_exists.return_value = False - - _write_failure_reason_file("Test error message") - - mock_file.assert_called_once() - mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error message") - - def test_parse_args(self): - """Test _parse_args""" - args = _parse_args( - [ - "--job_conda_env", - "myenv", - "--client_python_version", - "3.8", - "--dependency_settings", - '{"dependency_file": "requirements.txt"}', - ] - ) - - assert args.job_conda_env == "myenv" - assert args.client_python_version == "3.8" - assert args.dependency_settings == '{"dependency_file": "requirements.txt"}' - - -class TestLoggingFunctions: - """Test cases for logging functions""" - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger" - ) - def test_log_key_value_normal(self, mock_logger): - """Test log_key_value with normal key""" - log_key_value("MY_KEY", "my_value") - +class TestParseArgs: + """Test _parse_args function.""" + + def test_parse_required_args(self): + """Test parsing required arguments.""" + args = [ + "--client_python_version", "3.8", + ] + parsed = _parse_args(args) + assert parsed.client_python_version == "3.8" + + def test_parse_all_args(self): + """Test parsing all arguments.""" + args = [ + "--job_conda_env", "my-env", + "--client_python_version", "3.9", + "--client_sagemaker_pysdk_version", "2.100.0", + "--pipeline_execution_id", "exec-123", + "--dependency_settings", '{"dependency_file": "requirements.txt"}', + "--func_step_s3_dir", "s3://bucket/func", + "--distribution", "torchrun", + "--user_nproc_per_node", "4", + ] + parsed = _parse_args(args) + assert parsed.job_conda_env == "my-env" + assert parsed.client_python_version == "3.9" + assert parsed.client_sagemaker_pysdk_version == "2.100.0" + assert parsed.pipeline_execution_id == "exec-123" + assert parsed.dependency_settings == '{"dependency_file": "requirements.txt"}' + assert parsed.func_step_s3_dir == "s3://bucket/func" + assert parsed.distribution == "torchrun" + assert parsed.user_nproc_per_node == "4" + + def test_parse_default_values(self): + """Test default values for optional arguments.""" + args = [ + "--client_python_version", "3.8", + ] + parsed = _parse_args(args) + assert parsed.job_conda_env is None + assert parsed.client_sagemaker_pysdk_version is None + assert parsed.pipeline_execution_id is None + assert parsed.dependency_settings is None + assert parsed.func_step_s3_dir is None + assert parsed.distribution is None + assert parsed.user_nproc_per_node is None + + +class TestLogKeyValue: + """Test log_key_value function.""" + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger") + def test_logs_regular_value(self, mock_logger): + """Test logs regular key-value pair.""" + log_key_value("my_name", "my_value") + mock_logger.info.assert_called_once_with("%s=%s", "my_name", "my_value") + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger") + def test_masks_sensitive_key(self, mock_logger): + """Test masks sensitive keywords.""" + for keyword in ["PASSWORD", "SECRET", "TOKEN", "KEY", "PRIVATE", "CREDENTIALS"]: + mock_logger.reset_mock() + log_key_value(f"my_{keyword}", "sensitive_value") + mock_logger.info.assert_called_once_with("%s=%s", f"my_{keyword}", HIDDEN_VALUE) + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger") + def test_logs_dict_value(self, mock_logger): + """Test logs dictionary value.""" + value = {"field1": "value1", "field2": "value2"} + log_key_value("my_config", value) + mock_logger.info.assert_called_once_with("%s=%s", "my_config", json.dumps(value)) + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger") + def test_logs_json_string_value(self, mock_logger): + """Test logs JSON string value.""" + value = '{"key1": "value1"}' + log_key_value("my_key", value) mock_logger.info.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger" - ) - def test_log_key_value_sensitive(self, mock_logger): - """Test log_key_value with sensitive key""" - log_key_value("MY_PASSWORD", "secret123") - mock_logger.info.assert_called_once() - call_args = mock_logger.info.call_args[0] - assert HIDDEN_VALUE in str(call_args) - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger" - ) - def test_log_key_value_dict(self, mock_logger): - """Test log_key_value with dictionary value""" - log_key_value("MY_CONFIG", {"key": "value"}) - - mock_logger.info.assert_called_once() +class TestLogEnvVariables: + """Test log_env_variables function.""" - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.logger" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.environ", - {"ENV_VAR": "value"}, - ) - def test_log_env_variables(self, mock_logger): - """Test log_env_variables""" - log_env_variables({"CUSTOM_VAR": "custom_value"}) + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_key_value") + @patch.dict("os.environ", {"ENV_VAR1": "value1", "ENV_VAR2": "value2"}) + def test_logs_env_and_dict_variables(self, mock_log_kv): + """Test logs both environment and dictionary variables.""" + env_dict = {"DICT_VAR1": "dict_value1", "DICT_VAR2": "dict_value2"} + log_env_variables(env_dict) + + # Should be called for env vars and dict vars + assert mock_log_kv.call_count >= 4 - assert mock_logger.info.call_count >= 2 - def test_mask_sensitive_info(self): - """Test mask_sensitive_info""" - data = {"username": "user", "password": "secret", "nested": {"api_key": "key123"}} +class TestMaskSensitiveInfo: + """Test mask_sensitive_info function.""" + def test_masks_sensitive_keys_in_dict(self): + """Test masks sensitive keys in dictionary.""" + data = { + "username": "user", + "password": "secret123", + "api_key": "key123", + } result = mask_sensitive_info(data) - - assert result["password"] == HIDDEN_VALUE - assert result["nested"]["api_key"] == HIDDEN_VALUE assert result["username"] == "user" + assert result["password"] == HIDDEN_VALUE + assert result["api_key"] == HIDDEN_VALUE + + def test_masks_nested_dict(self): + """Test masks sensitive keys in nested dictionary.""" + data = { + "config": { + "username": "user", + "secret": "secret123", + } + } + result = mask_sensitive_info(data) + assert result["config"]["username"] == "user" + assert result["config"]["secret"] == HIDDEN_VALUE + def test_returns_non_dict_unchanged(self): + """Test returns non-dictionary unchanged.""" + data = "string_value" + result = mask_sensitive_info(data) + assert result == "string_value" -class TestResourceFunctions: - """Test cases for resource detection functions""" - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.multiprocessing.cpu_count" - ) - def test_num_cpus(self, mock_cpu_count): - """Test num_cpus""" - mock_cpu_count.return_value = 4 +class TestNumCpus: + """Test num_cpus function.""" - result = num_cpus() + @patch("multiprocessing.cpu_count") + def test_returns_cpu_count(self, mock_cpu_count): + """Test returns CPU count.""" + mock_cpu_count.return_value = 8 + assert num_cpus() == 8 - assert result == 4 - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.subprocess.check_output" - ) - def test_num_gpus_with_gpus(self, mock_check_output): - """Test num_gpus when GPUs are present""" - mock_check_output.return_value = b"GPU 0: Tesla V100\nGPU 1: Tesla V100\n" +class TestNumGpus: + """Test num_gpus function.""" - result = num_gpus() + @patch("subprocess.check_output") + def test_returns_gpu_count(self, mock_check_output): + """Test returns GPU count.""" + mock_check_output.return_value = b"GPU 0: Tesla V100\nGPU 1: Tesla V100\n" + assert num_gpus() == 2 - assert result == 2 + @patch("subprocess.check_output") + def test_returns_zero_on_error(self, mock_check_output): + """Test returns zero when nvidia-smi fails.""" + mock_check_output.side_effect = subprocess.CalledProcessError(1, "nvidia-smi") + assert num_gpus() == 0 - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.subprocess.check_output" - ) - def test_num_gpus_no_gpus(self, mock_check_output): - """Test num_gpus when no GPUs are present""" + @patch("subprocess.check_output") + def test_returns_zero_on_os_error(self, mock_check_output): + """Test returns zero when nvidia-smi not found.""" mock_check_output.side_effect = OSError() + assert num_gpus() == 0 - result = num_gpus() - assert result == 0 +class TestNumNeurons: + """Test num_neurons function.""" - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.subprocess.check_output" - ) - def test_num_neurons_with_neurons(self, mock_check_output): - """Test num_neurons when neurons are present""" - mock_check_output.return_value = b'[{"nc_count": 2}, {"nc_count": 2}]' + @patch("subprocess.check_output") + def test_returns_neuron_count(self, mock_check_output): + """Test returns neuron core count.""" + mock_output = json.dumps([{"nc_count": 2}, {"nc_count": 4}]) + mock_check_output.return_value = mock_output.encode("utf-8") + assert num_neurons() == 6 - result = num_neurons() - - assert result == 4 - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.subprocess.check_output" - ) - def test_num_neurons_no_neurons(self, mock_check_output): - """Test num_neurons when no neurons are present""" + @patch("subprocess.check_output") + def test_returns_zero_on_os_error(self, mock_check_output): + """Test returns zero when neuron-ls not found.""" mock_check_output.side_effect = OSError() + assert num_neurons() == 0 - result = num_neurons() - - assert result == 0 - - -class TestSerializationFunctions: - """Test cases for serialization functions""" - - def test_safe_serialize_string(self): - """Test safe_serialize with string""" - result = safe_serialize("test_string") - - assert result == "test_string" + @patch("subprocess.check_output") + def test_returns_zero_on_called_process_error(self, mock_check_output): + """Test returns zero when neuron-ls fails.""" + error = subprocess.CalledProcessError(1, "neuron-ls") + error.output = b"error=No neuron devices found" + mock_check_output.side_effect = error + assert num_neurons() == 0 - def test_safe_serialize_dict(self): - """Test safe_serialize with dictionary""" - result = safe_serialize({"key": "value"}) - assert result == '{"key": "value"}' +class TestSafeSerialize: + """Test safe_serialize function.""" - def test_safe_serialize_list(self): - """Test safe_serialize with list""" - result = safe_serialize([1, 2, 3]) + def test_returns_string_as_is(self): + """Test returns string without quotes.""" + assert safe_serialize("test_string") == "test_string" - assert result == "[1, 2, 3]" + def test_serializes_dict(self): + """Test serializes dictionary.""" + data = {"key": "value"} + assert safe_serialize(data) == '{"key": "value"}' - def test_safe_serialize_non_serializable(self): - """Test safe_serialize with non-serializable object""" + def test_serializes_list(self): + """Test serializes list.""" + data = [1, 2, 3] + assert safe_serialize(data) == "[1, 2, 3]" - class CustomObject: + def test_returns_str_for_non_serializable(self): + """Test returns str() for non-serializable objects.""" + class CustomObj: def __str__(self): return "custom_object" - - result = safe_serialize(CustomObject()) - - assert "custom_object" in result + + obj = CustomObj() + assert safe_serialize(obj) == "custom_object" class TestSetEnv: - """Test cases for set_env function""" + """Test set_env function.""" @patch("builtins.open", new_callable=mock_open) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.environ", - {"TRAINING_JOB_NAME": "test-job"}, - ) - def test_set_env_basic(self, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test set_env with basic configuration""" - mock_cpus.return_value = 4 - mock_gpus.return_value = 0 + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") + @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) + def test_sets_basic_env_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): + """Test sets basic environment variables.""" + mock_cpus.return_value = 8 + mock_gpus.return_value = 2 mock_neurons.return_value = 0 - + resource_config = { "current_host": "algo-1", - "current_instance_type": "ml.m5.xlarge", - "hosts": ["algo-1"], + "current_instance_type": "ml.p3.2xlarge", + "hosts": ["algo-1", "algo-2"], "network_interface_name": "eth0", } - + set_env(resource_config) - + mock_file.assert_called_once() + mock_log_env.assert_called_once() @patch("builtins.open", new_callable=mock_open) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.environ", - {"TRAINING_JOB_NAME": "test-job"}, - ) - def test_set_env_with_torchrun(self, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test set_env with torchrun distribution""" - mock_cpus.return_value = 4 + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") + @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) + def test_sets_torchrun_distribution_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): + """Test sets torchrun distribution environment variables.""" + mock_cpus.return_value = 8 mock_gpus.return_value = 2 mock_neurons.return_value = 0 - + resource_config = { "current_host": "algo-1", - "current_instance_type": "ml.p3.2xlarge", - "hosts": ["algo-1", "algo-2"], + "current_instance_type": "ml.p4d.24xlarge", + "hosts": ["algo-1"], "network_interface_name": "eth0", } - + set_env(resource_config, distribution="torchrun") - + + # Verify file was written mock_file.assert_called_once() @patch("builtins.open", new_callable=mock_open) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.environ", - {"TRAINING_JOB_NAME": "test-job"}, - ) - def test_set_env_with_mpirun(self, mock_neurons, mock_gpus, mock_cpus, mock_file): - """Test set_env with mpirun distribution""" - mock_cpus.return_value = 4 + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") + @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) + def test_sets_mpirun_distribution_vars(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): + """Test sets mpirun distribution environment variables.""" + mock_cpus.return_value = 8 mock_gpus.return_value = 2 mock_neurons.return_value = 0 - + resource_config = { "current_host": "algo-1", "current_instance_type": "ml.p3.2xlarge", "hosts": ["algo-1", "algo-2"], "network_interface_name": "eth0", } - + set_env(resource_config, distribution="mpirun") + + mock_file.assert_called_once() + @patch("builtins.open", new_callable=mock_open) + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.log_env_variables") + @patch.dict("os.environ", {"TRAINING_JOB_NAME": "test-job"}) + def test_uses_user_nproc_per_node(self, mock_log_env, mock_neurons, mock_gpus, mock_cpus, mock_file): + """Test uses user-specified nproc_per_node.""" + mock_cpus.return_value = 8 + mock_gpus.return_value = 2 + mock_neurons.return_value = 0 + + resource_config = { + "current_host": "algo-1", + "current_instance_type": "ml.p3.2xlarge", + "hosts": ["algo-1"], + "network_interface_name": "eth0", + } + + set_env(resource_config, user_nproc_per_node="4") + mock_file.assert_called_once() +class TestWriteFailureReasonFile: + """Test _write_failure_reason_file function.""" + + @patch("builtins.open", new_callable=mock_open) + @patch("os.path.exists") + def test_writes_failure_file(self, mock_exists, mock_file): + """Test writes failure reason file.""" + mock_exists.return_value = False + + _write_failure_reason_file("Test error message") + + mock_file.assert_called_once_with(FAILURE_REASON_PATH, "w") + mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error message") + + @patch("builtins.open", new_callable=mock_open) + @patch("os.path.exists") + def test_does_not_write_if_exists(self, mock_exists, mock_file): + """Test does not write if failure file already exists.""" + mock_exists.return_value = True + + _write_failure_reason_file("Test error message") + + mock_file.assert_not_called() + + +class TestUnpackUserWorkspace: + """Test _unpack_user_workspace function.""" + + @patch("os.path.exists") + def test_returns_none_if_dir_not_exists(self, mock_exists): + """Test returns None if workspace directory doesn't exist.""" + mock_exists.return_value = False + + result = _unpack_user_workspace() + + assert result is None + + @patch("os.path.isfile") + @patch("os.path.exists") + def test_returns_none_if_archive_not_exists(self, mock_exists, mock_isfile): + """Test returns None if workspace archive doesn't exist.""" + mock_exists.return_value = True + mock_isfile.return_value = False + + result = _unpack_user_workspace() + + assert result is None + + @patch("shutil.unpack_archive") + @patch("os.path.isfile") + @patch("os.path.exists") + @patch("os.getcwd") + def test_unpacks_workspace_successfully(self, mock_getcwd, mock_exists, mock_isfile, mock_unpack): + """Test unpacks workspace successfully.""" + mock_getcwd.return_value = "/tmp/workspace" + mock_exists.return_value = True + mock_isfile.return_value = True + + result = _unpack_user_workspace() + + mock_unpack.assert_called_once() + assert result is not None + + +class TestHandlePreExecScripts: + """Test _handle_pre_exec_scripts function.""" + + @patch("os.path.isfile") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") + def test_runs_pre_exec_script(self, mock_manager_class, mock_isfile): + """Test runs pre-execution script.""" + mock_isfile.return_value = True + mock_manager = MagicMock() + mock_manager_class.return_value = mock_manager + + _handle_pre_exec_scripts("/tmp/scripts") + + mock_manager.run_pre_exec_script.assert_called_once() + + +class TestInstallDependencies: + """Test _install_dependencies function.""" + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") + def test_installs_with_dependency_settings(self, mock_manager_class): + """Test installs dependencies with dependency settings.""" + mock_manager = MagicMock() + mock_manager_class.return_value = mock_manager + + dep_settings = _DependencySettings(dependency_file="requirements.txt") + + _install_dependencies( + "/tmp/deps", + "my-env", + "3.8", + "channel", + dep_settings + ) + + mock_manager.bootstrap.assert_called_once() + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") + def test_skips_if_no_dependency_file(self, mock_manager_class): + """Test skips installation if no dependency file.""" + mock_manager = MagicMock() + mock_manager_class.return_value = mock_manager + + dep_settings = _DependencySettings(dependency_file=None) + + _install_dependencies( + "/tmp/deps", + "my-env", + "3.8", + "channel", + dep_settings + ) + + mock_manager.bootstrap.assert_not_called() + + @patch("os.listdir") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") + def test_finds_dependency_file_legacy(self, mock_manager_class, mock_listdir): + """Test finds dependency file in legacy mode.""" + mock_manager = MagicMock() + mock_manager_class.return_value = mock_manager + mock_listdir.return_value = ["requirements.txt", "script.py"] + + _install_dependencies( + "/tmp/deps", + "my-env", + "3.8", + "channel", + None + ) + + mock_manager.bootstrap.assert_called_once() + + +class TestBootstrapRuntimeEnvForRemoteFunction: + """Test _bootstrap_runtime_env_for_remote_function function.""" + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") + def test_bootstraps_successfully(self, mock_unpack, mock_handle_scripts, mock_install): + """Test bootstraps runtime environment successfully.""" + mock_unpack.return_value = "/tmp/workspace" + + _bootstrap_runtime_env_for_remote_function("3.8", "my-env", None) + + mock_unpack.assert_called_once() + mock_handle_scripts.assert_called_once() + mock_install.assert_called_once() + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") + def test_returns_early_if_no_workspace(self, mock_unpack): + """Test returns early if no workspace to unpack.""" + mock_unpack.return_value = None + + _bootstrap_runtime_env_for_remote_function("3.8", "my-env", None) + + mock_unpack.assert_called_once() + + +class TestBootstrapRuntimeEnvForPipelineStep: + """Test _bootstrap_runtime_env_for_pipeline_step function.""" + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts") + @patch("shutil.copy") + @patch("os.listdir") + @patch("os.path.exists") + @patch("os.mkdir") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") + def test_bootstraps_with_workspace(self, mock_unpack, mock_mkdir, mock_exists, mock_listdir, mock_copy, mock_handle_scripts, mock_install): + """Test bootstraps pipeline step with workspace.""" + mock_unpack.return_value = "/tmp/workspace" + mock_exists.return_value = True + mock_listdir.return_value = ["requirements.txt"] + + _bootstrap_runtime_env_for_pipeline_step("3.8", "func_step", "my-env", None) + + mock_unpack.assert_called_once() + mock_handle_scripts.assert_called_once() + mock_install.assert_called_once() + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._install_dependencies") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._handle_pre_exec_scripts") + @patch("os.path.exists") + @patch("os.mkdir") + @patch("os.getcwd") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._unpack_user_workspace") + def test_creates_workspace_if_none(self, mock_unpack, mock_getcwd, mock_mkdir, mock_exists, mock_handle_scripts, mock_install): + """Test creates workspace directory if none exists.""" + mock_unpack.return_value = None + mock_getcwd.return_value = "/tmp" + mock_exists.return_value = False + + _bootstrap_runtime_env_for_pipeline_step("3.8", "func_step", "my-env", None) + + mock_mkdir.assert_called_once() + + class TestMain: - """Test cases for main function""" - - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._bootstrap_runtime_env_for_remote_function" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.getpass.getuser" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.os.path.exists" - ) - def test_main_success( - self, mock_exists, mock_getuser, mock_manager_class, mock_bootstrap, mock_parse - ): - """Test main function successful execution""" - mock_args = Mock() + """Test main function.""" + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.set_env") + @patch("builtins.open", new_callable=mock_open, read_data='{"current_host": "algo-1", "current_instance_type": "ml.m5.xlarge", "hosts": ["algo-1"], "network_interface_name": "eth0"}') + @patch("os.path.exists") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._bootstrap_runtime_env_for_remote_function") + @patch("getpass.getuser") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args") + def test_main_success(self, mock_parse_args, mock_getuser, mock_bootstrap, mock_manager_class, mock_exists, mock_file, mock_set_env): + """Test main function successful execution.""" + mock_getuser.return_value = "root" + mock_exists.return_value = True + mock_manager = MagicMock() + mock_manager_class.return_value = mock_manager + + # Mock parsed args + mock_args = MagicMock() mock_args.client_python_version = "3.8" - mock_args.client_sagemaker_pysdk_version = "2.0.0" + mock_args.client_sagemaker_pysdk_version = None mock_args.job_conda_env = None mock_args.pipeline_execution_id = None mock_args.dependency_settings = None mock_args.func_step_s3_dir = None mock_args.distribution = None mock_args.user_nproc_per_node = None - mock_parse.return_value = mock_args - - mock_getuser.return_value = "root" - mock_exists.return_value = False - - mock_manager = Mock() - mock_manager_class.return_value = mock_manager - + mock_parse_args.return_value = mock_args + + args = [ + "--client_python_version", "3.8", + ] + with pytest.raises(SystemExit) as exc_info: - main([]) - + main(args) + assert exc_info.value.code == SUCCESS_EXIT_CODE + mock_bootstrap.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._write_failure_reason_file" - ) - def test_main_failure(self, mock_write_failure, mock_parse): - """Test main function with failure""" - mock_parse.side_effect = Exception("Test error") - + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._write_failure_reason_file") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") + @patch("getpass.getuser") + def test_main_handles_exception(self, mock_getuser, mock_manager_class, mock_write_failure): + """Test main function handles exceptions.""" + mock_getuser.return_value = "root" + mock_manager = MagicMock() + mock_manager._validate_python_version.side_effect = Exception("Test error") + mock_manager_class.return_value = mock_manager + + args = [ + "--client_python_version", "3.8", + ] + with pytest.raises(SystemExit) as exc_info: - main([]) - + main(args) + assert exc_info.value.code == DEFAULT_FAILURE_CODE mock_write_failure.assert_called_once() + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.set_env") + @patch("builtins.open", new_callable=mock_open, read_data='{"current_host": "algo-1", "current_instance_type": "ml.m5.xlarge", "hosts": ["algo-1"], "network_interface_name": "eth0"}') + @patch("os.path.exists") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._bootstrap_runtime_env_for_pipeline_step") + @patch("getpass.getuser") + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment._parse_args") + def test_main_pipeline_execution(self, mock_parse_args, mock_getuser, mock_bootstrap, mock_manager_class, mock_exists, mock_file, mock_set_env): + """Test main function for pipeline execution.""" + mock_getuser.return_value = "root" + mock_exists.return_value = True + mock_manager = MagicMock() + mock_manager_class.return_value = mock_manager + + # Mock parsed args + mock_args = MagicMock() + mock_args.client_python_version = "3.8" + mock_args.client_sagemaker_pysdk_version = None + mock_args.job_conda_env = None + mock_args.pipeline_execution_id = "exec-123" + mock_args.dependency_settings = None + mock_args.func_step_s3_dir = "s3://bucket/func" + mock_args.distribution = None + mock_args.user_nproc_per_node = None + mock_parse_args.return_value = mock_args + + args = [ + "--client_python_version", "3.8", + "--pipeline_execution_id", "exec-123", + "--func_step_s3_dir", "s3://bucket/func", + ] + + with pytest.raises(SystemExit) as exc_info: + main(args) + + assert exc_info.value.code == SUCCESS_EXIT_CODE + mock_bootstrap.assert_called_once() + + @patch("sagemaker.core.remote_function.runtime_environment.bootstrap_runtime_environment.RuntimeEnvironmentManager") + @patch("getpass.getuser") + def test_main_non_root_user(self, mock_getuser, mock_manager_class): + """Test main function with non-root user.""" + mock_getuser.return_value = "ubuntu" + mock_manager = MagicMock() + mock_manager_class.return_value = mock_manager + + args = [ + "--client_python_version", "3.8", + ] + + with pytest.raises(SystemExit): + main(args) + + mock_manager.change_dir_permission.assert_called_once() diff --git a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_mpi_utils_remote.py b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_mpi_utils_remote.py index e075489b6b..b84dda5c1a 100644 --- a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_mpi_utils_remote.py +++ b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_mpi_utils_remote.py @@ -10,10 +10,14 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +"""Tests for mpi_utils_remote module.""" +from __future__ import absolute_import +import os import pytest -from unittest.mock import Mock, patch, MagicMock, mock_open import subprocess +import time +from unittest.mock import patch, MagicMock, mock_open, call import paramiko from sagemaker.core.remote_function.runtime_environment.mpi_utils_remote import ( @@ -32,6 +36,7 @@ main, SUCCESS_EXIT_CODE, DEFAULT_FAILURE_CODE, + FAILURE_REASON_PATH, FINISHED_STATUS_FILE, READY_FILE, DEFAULT_SSH_PORT, @@ -39,328 +44,381 @@ class TestCustomHostKeyPolicy: - """Test cases for CustomHostKeyPolicy class""" + """Test CustomHostKeyPolicy class.""" - def test_missing_host_key_algo_hostname(self): - """Test missing_host_key accepts algo-* hostnames""" + def test_accepts_algo_hostname(self): + """Test accepts hostnames starting with algo-.""" policy = CustomHostKeyPolicy() - client = Mock() - client.get_host_keys.return_value = Mock() - key = Mock() - key.get_name.return_value = "ssh-rsa" - + mock_client = MagicMock() + mock_hostname = "algo-1234" + mock_key = MagicMock() + mock_key.get_name.return_value = "ssh-rsa" + # Should not raise exception - policy.missing_host_key(client, "algo-1", key) - - client.get_host_keys().add.assert_called_once() + policy.missing_host_key(mock_client, mock_hostname, mock_key) + + mock_client.get_host_keys().add.assert_called_once_with(mock_hostname, "ssh-rsa", mock_key) - def test_missing_host_key_unknown_hostname(self): - """Test missing_host_key rejects unknown hostnames""" + def test_rejects_non_algo_hostname(self): + """Test rejects hostnames not starting with algo-.""" policy = CustomHostKeyPolicy() - client = Mock() - key = Mock() + mock_client = MagicMock() + mock_hostname = "unknown-host" + mock_key = MagicMock() + + with pytest.raises(paramiko.SSHException): + policy.missing_host_key(mock_client, mock_hostname, mock_key) + + +class TestParseArgs: + """Test _parse_args function.""" - with pytest.raises(paramiko.SSHException, match="Unknown host key"): - policy.missing_host_key(client, "unknown-host", key) + def test_parse_default_args(self): + """Test parsing with default arguments.""" + args = [] + parsed = _parse_args(args) + assert parsed.job_ended == "0" + def test_parse_job_ended_true(self): + """Test parsing with job_ended set to true.""" + args = ["--job_ended", "1"] + parsed = _parse_args(args) + assert parsed.job_ended == "1" -class TestConnectionFunctions: - """Test cases for connection functions""" + def test_parse_job_ended_false(self): + """Test parsing with job_ended set to false.""" + args = ["--job_ended", "0"] + parsed = _parse_args(args) + assert parsed.job_ended == "0" - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.paramiko.SSHClient") + +class TestCanConnect: + """Test _can_connect function.""" + + @patch("paramiko.SSHClient") def test_can_connect_success(self, mock_ssh_client_class): - """Test _can_connect when connection succeeds""" - mock_client = Mock() + """Test successful connection.""" + mock_client = MagicMock() mock_ssh_client_class.return_value.__enter__.return_value = mock_client - + result = _can_connect("algo-1", DEFAULT_SSH_PORT) - + assert result is True mock_client.connect.assert_called_once_with("algo-1", port=DEFAULT_SSH_PORT) - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.paramiko.SSHClient") + @patch("paramiko.SSHClient") def test_can_connect_failure(self, mock_ssh_client_class): - """Test _can_connect when connection fails""" - mock_client = Mock() + """Test failed connection.""" + mock_client = MagicMock() mock_client.connect.side_effect = Exception("Connection failed") mock_ssh_client_class.return_value.__enter__.return_value = mock_client - + result = _can_connect("algo-1", DEFAULT_SSH_PORT) - + assert result is False - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.subprocess.run") - def test_write_file_to_host_success(self, mock_run): - """Test _write_file_to_host when write succeeds""" - mock_run.return_value = Mock() + @patch("paramiko.SSHClient") + def test_can_connect_uses_custom_port(self, mock_ssh_client_class): + """Test connection with custom port.""" + mock_client = MagicMock() + mock_ssh_client_class.return_value.__enter__.return_value = mock_client + + _can_connect("algo-1", 2222) + + mock_client.connect.assert_called_once_with("algo-1", port=2222) - result = _write_file_to_host("algo-1", "/tmp/status") +class TestWriteFileToHost: + """Test _write_file_to_host function.""" + + @patch("subprocess.run") + def test_write_file_success(self, mock_run): + """Test successful file write.""" + mock_run.return_value = MagicMock(returncode=0) + + result = _write_file_to_host("algo-1", "/tmp/status") + assert result is True mock_run.assert_called_once() - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.subprocess.run") - def test_write_file_to_host_failure(self, mock_run): - """Test _write_file_to_host when write fails""" + @patch("subprocess.run") + def test_write_file_failure(self, mock_run): + """Test failed file write.""" mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") - + result = _write_file_to_host("algo-1", "/tmp/status") - + assert result is False - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.path.exists") + +class TestWriteFailureReasonFile: + """Test _write_failure_reason_file function.""" + @patch("builtins.open", new_callable=mock_open) - def test_write_failure_reason_file(self, mock_file, mock_exists): - """Test _write_failure_reason_file""" + @patch("os.path.exists") + def test_writes_failure_file(self, mock_exists, mock_file): + """Test writes failure reason file.""" mock_exists.return_value = False + + _write_failure_reason_file("Test error message") + + mock_file.assert_called_once_with(FAILURE_REASON_PATH, "w") + mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error message") - _write_failure_reason_file("Test error") - - mock_file.assert_called_once() - mock_file().write.assert_called_once_with("RuntimeEnvironmentError: Test error") + @patch("builtins.open", new_callable=mock_open) + @patch("os.path.exists") + def test_does_not_write_if_exists(self, mock_exists, mock_file): + """Test does not write if failure file already exists.""" + mock_exists.return_value = True + + _write_failure_reason_file("Test error message") + + mock_file.assert_not_called() -class TestWaitFunctions: - """Test cases for wait functions""" +class TestWaitForMaster: + """Test _wait_for_master function.""" + @patch("time.sleep") @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep") - def test_wait_for_master_success(self, mock_sleep, mock_can_connect): - """Test _wait_for_master when master becomes available""" - mock_can_connect.side_effect = [False, False, True] - + def test_wait_for_master_success(self, mock_can_connect, mock_sleep): + """Test successful wait for master.""" + mock_can_connect.return_value = True + _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300) + + mock_can_connect.assert_called_once_with("algo-1", DEFAULT_SSH_PORT) - assert mock_can_connect.call_count == 3 - + @patch("time.time") + @patch("time.sleep") @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.time") - def test_wait_for_master_timeout(self, mock_time, mock_sleep, mock_can_connect): - """Test _wait_for_master when timeout occurs""" + def test_wait_for_master_timeout(self, mock_can_connect, mock_sleep, mock_time): + """Test timeout waiting for master.""" mock_can_connect.return_value = False - mock_time.side_effect = [0, 100, 200, 301, 301] - - with pytest.raises(TimeoutError, match="Timed out waiting for master"): + # Need enough values for all time.time() calls in the loop + mock_time.side_effect = [0] + [i * 5 for i in range(1, 100)] # Simulate time passing + + with pytest.raises(TimeoutError): _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300) - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.path.exists") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep") - def test_wait_for_status_file(self, mock_sleep, mock_exists): - """Test _wait_for_status_file""" - mock_exists.side_effect = [False, False, True] + @patch("time.time") + @patch("time.sleep") + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") + def test_wait_for_master_retries(self, mock_can_connect, mock_sleep, mock_time): + """Test retries before successful connection.""" + mock_can_connect.side_effect = [False, False, True] + # Return value instead of side_effect for time.time() + mock_time.return_value = 0 + + _wait_for_master("algo-1", DEFAULT_SSH_PORT, timeout=300) + + assert mock_can_connect.call_count == 3 + +class TestWaitForStatusFile: + """Test _wait_for_status_file function.""" + + @patch("time.sleep") + @patch("os.path.exists") + def test_wait_for_status_file_exists(self, mock_exists, mock_sleep): + """Test wait for status file that exists.""" + mock_exists.return_value = True + _wait_for_status_file("/tmp/status") + + mock_exists.assert_called_once_with("/tmp/status") + @patch("time.sleep") + @patch("os.path.exists") + def test_wait_for_status_file_waits(self, mock_exists, mock_sleep): + """Test waits until status file exists.""" + mock_exists.side_effect = [False, False, True] + + _wait_for_status_file("/tmp/status") + assert mock_exists.call_count == 3 + assert mock_sleep.call_count == 2 + + +class TestWaitForWorkers: + """Test _wait_for_workers function.""" + + @patch("os.path.exists") + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") + def test_wait_for_workers_empty_list(self, mock_can_connect, mock_exists): + """Test wait for workers with empty list.""" + _wait_for_workers([], DEFAULT_SSH_PORT, timeout=300) + + mock_can_connect.assert_not_called() + @patch("time.sleep") + @patch("os.path.exists") @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.path.exists") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep") - def test_wait_for_workers_success(self, mock_sleep, mock_exists, mock_can_connect): - """Test _wait_for_workers when all workers become available""" + def test_wait_for_workers_success(self, mock_can_connect, mock_exists, mock_sleep): + """Test successful wait for workers.""" mock_can_connect.return_value = True mock_exists.return_value = True - + _wait_for_workers(["algo-2", "algo-3"], DEFAULT_SSH_PORT, timeout=300) - + assert mock_can_connect.call_count == 2 + @patch("time.time") + @patch("time.sleep") + @patch("os.path.exists") @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._can_connect") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.time") - def test_wait_for_workers_timeout(self, mock_time, mock_sleep, mock_can_connect): - """Test _wait_for_workers when timeout occurs""" + def test_wait_for_workers_timeout(self, mock_can_connect, mock_exists, mock_sleep, mock_time): + """Test timeout waiting for workers.""" mock_can_connect.return_value = False - mock_time.side_effect = [0, 100, 200, 301, 301] - - with pytest.raises(TimeoutError, match="Timed out waiting for workers"): + mock_exists.return_value = False + # Need enough values for all time.time() calls in the loop + mock_time.side_effect = [0] + [i * 5 for i in range(1, 100)] + + with pytest.raises(TimeoutError): _wait_for_workers(["algo-2"], DEFAULT_SSH_PORT, timeout=300) - def test_wait_for_workers_no_workers(self): - """Test _wait_for_workers with no workers""" - # Should not raise exception - _wait_for_workers([], DEFAULT_SSH_PORT, timeout=300) - -class TestBootstrapFunctions: - """Test cases for bootstrap functions""" +class TestBootstrapMasterNode: + """Test bootstrap_master_node function.""" @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_workers") def test_bootstrap_master_node(self, mock_wait): - """Test bootstrap_master_node""" - bootstrap_master_node(["algo-2", "algo-3"]) + """Test bootstrap master node.""" + worker_hosts = ["algo-2", "algo-3"] + + bootstrap_master_node(worker_hosts) + + mock_wait.assert_called_once_with(worker_hosts) - mock_wait.assert_called_once_with(["algo-2", "algo-3"]) +class TestBootstrapWorkerNode: + """Test bootstrap_worker_node function.""" + + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_status_file") + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_master") - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._wait_for_status_file" - ) - def test_bootstrap_worker_node(self, mock_wait_status, mock_write, mock_wait_master): - """Test bootstrap_worker_node""" + def test_bootstrap_worker_node(self, mock_wait_master, mock_write, mock_wait_status): + """Test bootstrap worker node.""" bootstrap_worker_node("algo-1", "algo-2", "/tmp/status") - + mock_wait_master.assert_called_once_with("algo-1") mock_write.assert_called_once() mock_wait_status.assert_called_once_with("/tmp/status") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.path.exists") - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.subprocess.Popen") - def test_start_sshd_daemon_success(self, mock_popen, mock_exists): - """Test start_sshd_daemon when sshd exists""" - mock_exists.return_value = True - start_sshd_daemon() +class TestStartSshdDaemon: + """Test start_sshd_daemon function.""" - mock_popen.assert_called_once() + @patch("subprocess.Popen") + @patch("os.path.exists") + def test_starts_sshd_successfully(self, mock_exists, mock_popen): + """Test starts SSH daemon successfully.""" + mock_exists.return_value = True + + start_sshd_daemon() + + mock_popen.assert_called_once_with(["/usr/sbin/sshd", "-D"]) - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.path.exists") - def test_start_sshd_daemon_not_found(self, mock_exists): - """Test start_sshd_daemon when sshd not found""" + @patch("os.path.exists") + def test_raises_error_if_sshd_not_found(self, mock_exists): + """Test raises error if SSH daemon not found.""" mock_exists.return_value = False - - with pytest.raises(RuntimeError, match="SSH daemon not found"): + + with pytest.raises(RuntimeError): start_sshd_daemon() - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host" - ) - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep") - def test_write_status_file_to_workers_success(self, mock_sleep, mock_write): - """Test write_status_file_to_workers when writes succeed""" - mock_write.return_value = True - write_status_file_to_workers(["algo-2", "algo-3"], "/tmp/status") +class TestWriteStatusFileToWorkers: + """Test write_status_file_to_workers function.""" + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") + def test_writes_to_all_workers(self, mock_write): + """Test writes status file to all workers.""" + mock_write.return_value = True + worker_hosts = ["algo-2", "algo-3"] + + write_status_file_to_workers(worker_hosts, "/tmp/status") + assert mock_write.call_count == 2 - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host" - ) - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.time.sleep") - def test_write_status_file_to_workers_timeout(self, mock_sleep, mock_write): - """Test write_status_file_to_workers when timeout occurs""" + @patch("time.sleep") + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") + def test_retries_on_failure(self, mock_write, mock_sleep): + """Test retries writing status file on failure.""" + mock_write.side_effect = [False, False, True] + worker_hosts = ["algo-2"] + + write_status_file_to_workers(worker_hosts, "/tmp/status") + + assert mock_write.call_count == 3 + assert mock_sleep.call_count == 2 + + @patch("time.sleep") + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_file_to_host") + def test_raises_timeout_after_retries(self, mock_write, mock_sleep): + """Test raises timeout after max retries.""" mock_write.return_value = False - - with pytest.raises(TimeoutError, match="Timed out waiting"): - write_status_file_to_workers(["algo-2"], "/tmp/status") - - -class TestParseArgs: - """Test cases for _parse_args function""" - - def test_parse_args_job_ended_false(self): - """Test _parse_args with job_ended=0""" - args = _parse_args(["--job_ended", "0"]) - - assert args.job_ended == "0" - - def test_parse_args_job_ended_true(self): - """Test _parse_args with job_ended=1""" - args = _parse_args(["--job_ended", "1"]) - - assert args.job_ended == "1" - - def test_parse_args_default(self): - """Test _parse_args with default values""" - args = _parse_args([]) - - assert args.job_ended == "0" + worker_hosts = ["algo-2"] + + with pytest.raises(TimeoutError): + write_status_file_to_workers(worker_hosts, "/tmp/status") class TestMain: - """Test cases for main function""" + """Test main function.""" - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._parse_args") + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.environ", - {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}, - ) - def test_main_worker_node_job_running(self, mock_bootstrap_worker, mock_start_sshd, mock_parse): - """Test main for worker node when job is running""" - mock_args = Mock() - mock_args.job_ended = "0" - mock_parse.return_value = mock_args - - main([]) - + @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}) + def test_main_worker_node_running(self, mock_start_sshd, mock_bootstrap_worker): + """Test main function for worker node during job run.""" + args = ["--job_ended", "0"] + + main(args) + mock_start_sshd.assert_called_once() mock_bootstrap_worker.assert_called_once() - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._parse_args") + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node" - ) - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.json.loads") - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.environ", - { - "SM_MASTER_ADDR": "algo-1", - "SM_CURRENT_HOST": "algo-1", - "SM_HOSTS": '["algo-1", "algo-2", "algo-3"]', - }, - ) - def test_main_master_node_job_running( - self, mock_json_loads, mock_bootstrap_master, mock_start_sshd, mock_parse - ): - """Test main for master node when job is running""" - mock_args = Mock() - mock_args.job_ended = "0" - mock_parse.return_value = mock_args - mock_json_loads.return_value = ["algo-1", "algo-2", "algo-3"] - - main([]) - + @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-1", "SM_HOSTS": '["algo-1", "algo-2"]'}) + def test_main_master_node_running(self, mock_start_sshd, mock_bootstrap_master): + """Test main function for master node during job run.""" + args = ["--job_ended", "0"] + + main(args) + mock_start_sshd.assert_called_once() - mock_bootstrap_master.assert_called_once_with(["algo-2", "algo-3"]) - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._parse_args") - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.write_status_file_to_workers" - ) - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.json.loads") - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.environ", - { - "SM_MASTER_ADDR": "algo-1", - "SM_CURRENT_HOST": "algo-1", - "SM_HOSTS": '["algo-1", "algo-2"]', - }, - ) - def test_main_master_node_job_ended(self, mock_json_loads, mock_write_status, mock_parse): - """Test main for master node when job has ended""" - mock_args = Mock() - mock_args.job_ended = "1" - mock_parse.return_value = mock_args - mock_json_loads.return_value = ["algo-1", "algo-2"] - - main([]) - - mock_write_status.assert_called_once_with(["algo-2"]) - - @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._parse_args") - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_failure_reason_file" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.os.environ", - {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}, - ) - def test_main_with_exception(self, mock_write_failure, mock_parse): - """Test main when exception occurs""" - mock_parse.side_effect = Exception("Test error") - + mock_bootstrap_master.assert_called_once() + + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.write_status_file_to_workers") + @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-1", "SM_HOSTS": '["algo-1", "algo-2"]'}) + def test_main_master_node_job_ended(self, mock_write_status): + """Test main function for master node after job ends.""" + args = ["--job_ended", "1"] + + main(args) + + mock_write_status.assert_called_once() + + @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}) + def test_main_worker_node_job_ended(self): + """Test main function for worker node after job ends.""" + args = ["--job_ended", "1"] + + # Should not raise any exceptions + main(args) + + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote._write_failure_reason_file") + @patch("sagemaker.core.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") + @patch.dict("os.environ", {"SM_MASTER_ADDR": "algo-1", "SM_CURRENT_HOST": "algo-2"}) + def test_main_handles_exception(self, mock_start_sshd, mock_write_failure): + """Test main function handles exceptions.""" + mock_start_sshd.side_effect = Exception("Test error") + args = ["--job_ended", "0"] + with pytest.raises(SystemExit) as exc_info: - main([]) - + main(args) + assert exc_info.value.code == DEFAULT_FAILURE_CODE mock_write_failure.assert_called_once() diff --git a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py index be2f1430d6..a300daf2b3 100644 --- a/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py +++ b/sagemaker-core/tests/unit/remote_function/runtime_environment/test_runtime_environment_manager.py @@ -10,16 +10,20 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +"""Tests for runtime_environment_manager module.""" +from __future__ import absolute_import -import pytest -from unittest.mock import Mock, patch, MagicMock, mock_open +import json +import os import subprocess import sys +import pytest +from unittest.mock import patch, MagicMock, mock_open, call from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( + _DependencySettings, RuntimeEnvironmentManager, RuntimeEnvironmentError, - _DependencySettings, get_logger, _run_and_get_output_shell_cmd, _run_pre_execution_command_script, @@ -31,467 +35,465 @@ class TestDependencySettings: - """Test cases for _DependencySettings class""" + """Test _DependencySettings class.""" + + def test_init_with_no_file(self): + """Test initialization without dependency file.""" + settings = _DependencySettings() + assert settings.dependency_file is None def test_init_with_file(self): - """Test initialization with dependency file""" + """Test initialization with dependency file.""" settings = _DependencySettings(dependency_file="requirements.txt") - assert settings.dependency_file == "requirements.txt" - def test_init_without_file(self): - """Test initialization without dependency file""" - settings = _DependencySettings() - - assert settings.dependency_file is None - def test_to_string(self): - """Test to_string method""" + """Test converts to JSON string.""" settings = _DependencySettings(dependency_file="requirements.txt") - result = settings.to_string() + assert result == '{"dependency_file": "requirements.txt"}' - assert "requirements.txt" in result - - def test_from_string(self): - """Test from_string method""" + def test_from_string_with_file(self): + """Test creates from JSON string with file.""" json_str = '{"dependency_file": "requirements.txt"}' - settings = _DependencySettings.from_string(json_str) - assert settings.dependency_file == "requirements.txt" - def test_from_string_none(self): - """Test from_string with None""" + def test_from_string_with_none(self): + """Test creates from None.""" settings = _DependencySettings.from_string(None) - assert settings is None - def test_from_dependency_file_path(self): - """Test from_dependency_file_path method""" - settings = _DependencySettings.from_dependency_file_path("/path/to/requirements.txt") - - assert settings.dependency_file == "requirements.txt" + def test_from_dependency_file_path_with_none(self): + """Test creates from None file path.""" + settings = _DependencySettings.from_dependency_file_path(None) + assert settings.dependency_file is None - def test_from_dependency_file_path_auto_capture(self): - """Test from_dependency_file_path with auto_capture""" + def test_from_dependency_file_path_with_auto_capture(self): + """Test creates from auto_capture.""" settings = _DependencySettings.from_dependency_file_path("auto_capture") - assert settings.dependency_file == "env_snapshot.yml" - def test_from_dependency_file_path_none(self): - """Test from_dependency_file_path with None""" - settings = _DependencySettings.from_dependency_file_path(None) + def test_from_dependency_file_path_with_path(self): + """Test creates from file path.""" + settings = _DependencySettings.from_dependency_file_path("/path/to/requirements.txt") + assert settings.dependency_file == "requirements.txt" - assert settings.dependency_file is None + +class TestGetLogger: + """Test get_logger function.""" + + def test_returns_logger(self): + """Test returns logger instance.""" + logger = get_logger() + assert logger is not None + assert logger.name == "sagemaker.remote_function" class TestRuntimeEnvironmentManager: - """Test cases for RuntimeEnvironmentManager class""" + """Test RuntimeEnvironmentManager class.""" def test_init(self): - """Test initialization""" + """Test initialization.""" manager = RuntimeEnvironmentManager() - assert manager is not None - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.path.isfile" - ) - def test_snapshot_with_requirements_txt(self, mock_isfile): - """Test snapshot with requirements.txt""" - mock_isfile.return_value = True + @patch("os.path.isfile") + def test_snapshot_returns_none_for_none(self, mock_isfile): + """Test snapshot returns None when dependencies is None.""" manager = RuntimeEnvironmentManager() + result = manager.snapshot(None) + assert result is None - result = manager.snapshot("requirements.txt") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._capture_from_local_runtime") + def test_snapshot_auto_capture(self, mock_capture): + """Test snapshot with auto_capture.""" + mock_capture.return_value = "/path/to/env_snapshot.yml" + manager = RuntimeEnvironmentManager() + result = manager.snapshot("auto_capture") + assert result == "/path/to/env_snapshot.yml" + mock_capture.assert_called_once() + @patch("os.path.isfile") + def test_snapshot_with_txt_file(self, mock_isfile): + """Test snapshot with requirements.txt file.""" + mock_isfile.return_value = True + manager = RuntimeEnvironmentManager() + result = manager.snapshot("requirements.txt") assert result == "requirements.txt" - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.path.isfile" - ) - def test_snapshot_with_conda_yml(self, mock_isfile): - """Test snapshot with conda environment.yml""" + @patch("os.path.isfile") + def test_snapshot_with_yml_file(self, mock_isfile): + """Test snapshot with conda.yml file.""" mock_isfile.return_value = True manager = RuntimeEnvironmentManager() - result = manager.snapshot("environment.yml") - assert result == "environment.yml" - @patch.object(RuntimeEnvironmentManager, "_capture_from_local_runtime") - def test_snapshot_with_auto_capture(self, mock_capture): - """Test snapshot with auto_capture""" - mock_capture.return_value = "env_snapshot.yml" - manager = RuntimeEnvironmentManager() - - result = manager.snapshot("auto_capture") - - assert result == "env_snapshot.yml" - mock_capture.assert_called_once() - - def test_snapshot_with_none(self): - """Test snapshot with None""" - manager = RuntimeEnvironmentManager() - - result = manager.snapshot(None) - - assert result is None - - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.path.isfile" - ) - def test_snapshot_with_invalid_file(self, mock_isfile): - """Test snapshot with invalid file""" + @patch("os.path.isfile") + def test_snapshot_raises_error_for_invalid_file(self, mock_isfile): + """Test snapshot raises error for invalid file.""" mock_isfile.return_value = False manager = RuntimeEnvironmentManager() + with pytest.raises(ValueError): + manager.snapshot("requirements.txt") - with pytest.raises(ValueError, match="No dependencies file named"): - manager.snapshot("invalid.txt") - - @patch.object(RuntimeEnvironmentManager, "_get_active_conda_env_name") - @patch.object(RuntimeEnvironmentManager, "_get_active_conda_env_prefix") - @patch.object(RuntimeEnvironmentManager, "_export_conda_env_from_prefix") - def test_capture_from_local_runtime_with_conda_env(self, mock_export, mock_prefix, mock_name): - """Test _capture_from_local_runtime with conda environment""" - mock_name.return_value = "myenv" - mock_prefix.return_value = "/opt/conda/envs/myenv" + def test_snapshot_raises_error_for_invalid_format(self): + """Test snapshot raises error for invalid format.""" manager = RuntimeEnvironmentManager() + with pytest.raises(ValueError): + manager.snapshot("invalid.json") - result = manager._capture_from_local_runtime() - - assert "env_snapshot.yml" in result - mock_export.assert_called_once() - - @patch.object(RuntimeEnvironmentManager, "_get_active_conda_env_name") - @patch.object(RuntimeEnvironmentManager, "_get_active_conda_env_prefix") - def test_capture_from_local_runtime_no_conda_env(self, mock_prefix, mock_name): - """Test _capture_from_local_runtime without conda environment""" - mock_name.return_value = None - mock_prefix.return_value = None - manager = RuntimeEnvironmentManager() - - with pytest.raises(ValueError, match="No conda environment"): - manager._capture_from_local_runtime() - - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.getenv" - ) + @patch("os.getenv") def test_get_active_conda_env_prefix(self, mock_getenv): - """Test _get_active_conda_env_prefix""" + """Test gets active conda environment prefix.""" mock_getenv.return_value = "/opt/conda/envs/myenv" manager = RuntimeEnvironmentManager() - result = manager._get_active_conda_env_prefix() - assert result == "/opt/conda/envs/myenv" - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.getenv" - ) + @patch("os.getenv") def test_get_active_conda_env_name(self, mock_getenv): - """Test _get_active_conda_env_name""" + """Test gets active conda environment name.""" mock_getenv.return_value = "myenv" manager = RuntimeEnvironmentManager() - result = manager._get_active_conda_env_name() - assert result == "myenv" - @patch.object(RuntimeEnvironmentManager, "_install_req_txt_in_conda_env") - @patch.object(RuntimeEnvironmentManager, "_write_conda_env_to_file") - def test_bootstrap_with_requirements_txt_and_conda_env(self, mock_write, mock_install): - """Test bootstrap with requirements.txt and conda environment""" + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._export_conda_env_from_prefix") + @patch("os.getcwd") + @patch("os.getenv") + def test_capture_from_local_runtime(self, mock_getenv, mock_getcwd, mock_export): + """Test captures from local runtime.""" + mock_getenv.side_effect = lambda x: "myenv" if x == "CONDA_DEFAULT_ENV" else "/opt/conda/envs/myenv" + mock_getcwd.return_value = "/tmp" manager = RuntimeEnvironmentManager() + result = manager._capture_from_local_runtime() + assert result == "/tmp/env_snapshot.yml" + mock_export.assert_called_once() - manager.bootstrap( - local_dependencies_file="requirements.txt", - client_python_version="3.8", - conda_env="myenv", - ) - - mock_install.assert_called_once_with("myenv", "requirements.txt") - mock_write.assert_called_once_with("myenv") - - @patch.object(RuntimeEnvironmentManager, "_install_requirements_txt") - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._python_executable" - ) - def test_bootstrap_with_requirements_txt_no_conda_env(self, mock_python_exec, mock_install): - """Test bootstrap with requirements.txt without conda environment""" - mock_python_exec.return_value = "/usr/bin/python3" + @patch("os.getenv") + def test_capture_from_local_runtime_raises_error_no_conda(self, mock_getenv): + """Test raises error when no conda environment active.""" + mock_getenv.return_value = None manager = RuntimeEnvironmentManager() + with pytest.raises(ValueError): + manager._capture_from_local_runtime() - manager.bootstrap(local_dependencies_file="requirements.txt", client_python_version="3.8") - + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._install_requirements_txt") + def test_bootstrap_with_txt_file_no_conda(self, mock_install): + """Test bootstrap with requirements.txt without conda.""" + manager = RuntimeEnvironmentManager() + manager.bootstrap("requirements.txt", "3.8", None) mock_install.assert_called_once() - @patch.object(RuntimeEnvironmentManager, "_update_conda_env") - @patch.object(RuntimeEnvironmentManager, "_write_conda_env_to_file") - def test_bootstrap_with_conda_yml_and_conda_env(self, mock_write, mock_update): - """Test bootstrap with conda yml and existing conda environment""" + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._install_req_txt_in_conda_env") + def test_bootstrap_with_txt_file_with_conda(self, mock_install, mock_write): + """Test bootstrap with requirements.txt with conda.""" manager = RuntimeEnvironmentManager() + manager.bootstrap("requirements.txt", "3.8", "myenv") + mock_install.assert_called_once() + mock_write.assert_called_once() - manager.bootstrap( - local_dependencies_file="environment.yml", - client_python_version="3.8", - conda_env="myenv", - ) - + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._update_conda_env") + def test_bootstrap_with_yml_file_with_conda(self, mock_update, mock_write): + """Test bootstrap with conda.yml with existing conda env.""" + manager = RuntimeEnvironmentManager() + manager.bootstrap("environment.yml", "3.8", "myenv") mock_update.assert_called_once() mock_write.assert_called_once() - @patch.object(RuntimeEnvironmentManager, "_create_conda_env") - @patch.object(RuntimeEnvironmentManager, "_validate_python_version") - @patch.object(RuntimeEnvironmentManager, "_write_conda_env_to_file") - def test_bootstrap_with_conda_yml_no_conda_env(self, mock_write, mock_validate, mock_create): - """Test bootstrap with conda yml without existing conda environment""" + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._write_conda_env_to_file") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._validate_python_version") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._create_conda_env") + def test_bootstrap_with_yml_file_without_conda(self, mock_create, mock_validate, mock_write): + """Test bootstrap with conda.yml without existing conda env.""" manager = RuntimeEnvironmentManager() - - manager.bootstrap(local_dependencies_file="environment.yml", client_python_version="3.8") - + manager.bootstrap("environment.yml", "3.8", None) mock_create.assert_called_once() mock_validate.assert_called_once() mock_write.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.path.isfile" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script" - ) - def test_run_pre_exec_script_exists(self, mock_run_script, mock_isfile): - """Test run_pre_exec_script when script exists""" + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script") + @patch("os.path.isfile") + def test_run_pre_exec_script_exists(self, mock_isfile, mock_run_script): + """Test runs pre-execution script when it exists.""" mock_isfile.return_value = True mock_run_script.return_value = (0, "") manager = RuntimeEnvironmentManager() - manager.run_pre_exec_script("/path/to/script.sh") - mock_run_script.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.os.path.isfile" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script" - ) - def test_run_pre_exec_script_fails(self, mock_run_script, mock_isfile): - """Test run_pre_exec_script when script fails""" + @patch("os.path.isfile") + def test_run_pre_exec_script_not_exists(self, mock_isfile): + """Test handles pre-execution script not existing.""" + mock_isfile.return_value = False + manager = RuntimeEnvironmentManager() + # Should not raise exception + manager.run_pre_exec_script("/path/to/script.sh") + + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_pre_execution_command_script") + @patch("os.path.isfile") + def test_run_pre_exec_script_raises_error_on_failure(self, mock_isfile, mock_run_script): + """Test raises error when pre-execution script fails.""" mock_isfile.return_value = True mock_run_script.return_value = (1, "Error message") manager = RuntimeEnvironmentManager() - - with pytest.raises(RuntimeEnvironmentError, match="Encountered error"): + with pytest.raises(RuntimeEnvironmentError): manager.run_pre_exec_script("/path/to/script.sh") - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.run" - ) + @patch("subprocess.run") def test_change_dir_permission_success(self, mock_run): - """Test change_dir_permission successfully""" + """Test changes directory permissions successfully.""" manager = RuntimeEnvironmentManager() - manager.change_dir_permission(["/tmp/dir1", "/tmp/dir2"], "777") - mock_run.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.run" - ) - def test_change_dir_permission_failure(self, mock_run): - """Test change_dir_permission with failure""" - mock_run.side_effect = subprocess.CalledProcessError( - 1, "chmod", stderr=b"Permission denied" - ) + @patch("subprocess.run") + def test_change_dir_permission_raises_error_on_failure(self, mock_run): + """Test raises error when permission change fails.""" + mock_run.side_effect = subprocess.CalledProcessError(1, "chmod", stderr=b"Permission denied") manager = RuntimeEnvironmentManager() + with pytest.raises(RuntimeEnvironmentError): + manager.change_dir_permission(["/tmp/dir1"], "777") + @patch("subprocess.run") + def test_change_dir_permission_raises_error_no_sudo(self, mock_run): + """Test raises error when sudo not found.""" + mock_run.side_effect = FileNotFoundError("[Errno 2] No such file or directory: 'sudo'") + manager = RuntimeEnvironmentManager() with pytest.raises(RuntimeEnvironmentError): - manager.change_dir_permission(["/tmp/dir"], "777") + manager.change_dir_permission(["/tmp/dir1"], "777") - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd" - ) + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") def test_install_requirements_txt(self, mock_run_cmd): - """Test _install_requirements_txt""" + """Test installs requirements.txt.""" manager = RuntimeEnvironmentManager() - - manager._install_requirements_txt("/path/to/requirements.txt", "/usr/bin/python3") - + manager._install_requirements_txt("/path/to/requirements.txt", "/usr/bin/python") mock_run_cmd.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd" - ) - @patch.object(RuntimeEnvironmentManager, "_get_conda_exe") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") def test_create_conda_env(self, mock_get_conda, mock_run_cmd): - """Test _create_conda_env""" + """Test creates conda environment.""" mock_get_conda.return_value = "conda" manager = RuntimeEnvironmentManager() - manager._create_conda_env("myenv", "/path/to/environment.yml") + mock_run_cmd.assert_called_once() + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") + def test_install_req_txt_in_conda_env(self, mock_get_conda, mock_run_cmd): + """Test installs requirements.txt in conda environment.""" + mock_get_conda.return_value = "conda" + manager = RuntimeEnvironmentManager() + manager._install_req_txt_in_conda_env("myenv", "/path/to/requirements.txt") mock_run_cmd.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd" - ) - @patch.object(RuntimeEnvironmentManager, "_get_conda_exe") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._run_shell_cmd") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") def test_update_conda_env(self, mock_get_conda, mock_run_cmd): - """Test _update_conda_env""" + """Test updates conda environment.""" mock_get_conda.return_value = "conda" manager = RuntimeEnvironmentManager() - manager._update_conda_env("myenv", "/path/to/environment.yml") - mock_run_cmd.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen" - ) - def test_get_conda_exe_mamba(self, mock_popen): - """Test _get_conda_exe returns mamba""" - mock_process = Mock() + @patch("builtins.open", new_callable=mock_open) + @patch("subprocess.Popen") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") + def test_export_conda_env_from_prefix(self, mock_get_conda, mock_popen, mock_file): + """Test exports conda environment.""" + mock_get_conda.return_value = "conda" + mock_process = MagicMock() + mock_process.communicate.return_value = (b"env output", b"") mock_process.wait.return_value = 0 mock_popen.return_value = mock_process + manager = RuntimeEnvironmentManager() + manager._export_conda_env_from_prefix("/opt/conda/envs/myenv", "/tmp/env.yml") + + mock_popen.assert_called_once() + mock_file.assert_called_once_with("/tmp/env.yml", "w") + @patch("builtins.open", new_callable=mock_open) + @patch("os.getcwd") + def test_write_conda_env_to_file(self, mock_getcwd, mock_file): + """Test writes conda environment name to file.""" + mock_getcwd.return_value = "/tmp" + manager = RuntimeEnvironmentManager() + manager._write_conda_env_to_file("myenv") + mock_file.assert_called_once_with("/tmp/remote_function_conda_env.txt", "w") + mock_file().write.assert_called_once_with("myenv") + + @patch("subprocess.Popen") + def test_get_conda_exe_returns_mamba(self, mock_popen): + """Test returns mamba when available.""" + mock_popen.return_value.wait.side_effect = [0, 1] # mamba exists, conda doesn't + manager = RuntimeEnvironmentManager() result = manager._get_conda_exe() - assert result == "mamba" - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen" - ) - def test_get_conda_exe_conda(self, mock_popen): - """Test _get_conda_exe returns conda""" - mock_process = Mock() - mock_process.wait.side_effect = [1, 0] # mamba not found, conda found - mock_popen.return_value = mock_process + @patch("subprocess.Popen") + def test_get_conda_exe_returns_conda(self, mock_popen): + """Test returns conda when mamba not available.""" + mock_popen.return_value.wait.side_effect = [1, 0] # mamba doesn't exist, conda does manager = RuntimeEnvironmentManager() - result = manager._get_conda_exe() - assert result == "conda" - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen" - ) - def test_get_conda_exe_not_found(self, mock_popen): - """Test _get_conda_exe when neither mamba nor conda found""" - mock_process = Mock() - mock_process.wait.return_value = 1 - mock_popen.return_value = mock_process + @patch("subprocess.Popen") + def test_get_conda_exe_raises_error(self, mock_popen): + """Test raises error when neither conda nor mamba available.""" + mock_popen.return_value.wait.return_value = 1 manager = RuntimeEnvironmentManager() - - with pytest.raises(ValueError, match="Neither conda nor mamba"): + with pytest.raises(ValueError): manager._get_conda_exe() - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.check_output" - ) - @patch.object(RuntimeEnvironmentManager, "_get_conda_exe") + @patch("subprocess.check_output") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") def test_python_version_in_conda_env(self, mock_get_conda, mock_check_output): - """Test _python_version_in_conda_env""" + """Test gets Python version in conda environment.""" mock_get_conda.return_value = "conda" mock_check_output.return_value = b"Python 3.8.10" manager = RuntimeEnvironmentManager() - result = manager._python_version_in_conda_env("myenv") - assert result == "3.8" - def test_current_python_version(self): - """Test _current_python_version""" + @patch("subprocess.check_output") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._get_conda_exe") + def test_python_version_in_conda_env_raises_error(self, mock_get_conda, mock_check_output): + """Test raises error when getting Python version fails.""" + mock_get_conda.return_value = "conda" + mock_check_output.side_effect = subprocess.CalledProcessError(1, "conda", output=b"Error") manager = RuntimeEnvironmentManager() + with pytest.raises(RuntimeEnvironmentError): + manager._python_version_in_conda_env("myenv") + def test_current_python_version(self): + """Test gets current Python version.""" + manager = RuntimeEnvironmentManager() result = manager._current_python_version() + expected = f"{sys.version_info.major}.{sys.version_info.minor}" + assert result == expected - assert result == f"{sys.version_info.major}.{sys.version_info.minor}" - - @patch.object(RuntimeEnvironmentManager, "_python_version_in_conda_env") - def test_validate_python_version_match(self, mock_python_version): - """Test _validate_python_version when versions match""" + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._python_version_in_conda_env") + def test_validate_python_version_with_conda(self, mock_python_version): + """Test validates Python version with conda environment.""" mock_python_version.return_value = "3.8" manager = RuntimeEnvironmentManager() + # Should not raise exception + manager._validate_python_version("3.8", "myenv") - # Should not raise error - manager._validate_python_version("3.8", conda_env="myenv") - - @patch.object(RuntimeEnvironmentManager, "_python_version_in_conda_env") - def test_validate_python_version_mismatch(self, mock_python_version): - """Test _validate_python_version when versions don't match""" + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._python_version_in_conda_env") + def test_validate_python_version_mismatch_with_conda(self, mock_python_version): + """Test raises error on Python version mismatch with conda.""" mock_python_version.return_value = "3.9" manager = RuntimeEnvironmentManager() + with pytest.raises(RuntimeEnvironmentError): + manager._validate_python_version("3.8", "myenv") - with pytest.raises(RuntimeEnvironmentError, match="does not match"): - manager._validate_python_version("3.8", conda_env="myenv") - - @patch.object(RuntimeEnvironmentManager, "_current_sagemaker_pysdk_version") - def test_validate_sagemaker_pysdk_version_match(self, mock_version): - """Test _validate_sagemaker_pysdk_version when versions match""" - mock_version.return_value = "2.0.0" + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_python_version") + def test_validate_python_version_without_conda(self, mock_current_version): + """Test validates Python version without conda environment.""" + mock_current_version.return_value = "3.8" manager = RuntimeEnvironmentManager() + # Should not raise exception + manager._validate_python_version("3.8", None) - # Should not raise error, just log warning - manager._validate_sagemaker_pysdk_version("2.0.0") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_python_version") + def test_validate_python_version_mismatch_without_conda(self, mock_current_version): + """Test raises error on Python version mismatch without conda.""" + mock_current_version.return_value = "3.9" + manager = RuntimeEnvironmentManager() + with pytest.raises(RuntimeEnvironmentError): + manager._validate_python_version("3.8", None) - @patch.object(RuntimeEnvironmentManager, "_current_sagemaker_pysdk_version") - def test_validate_sagemaker_pysdk_version_mismatch(self, mock_version): - """Test _validate_sagemaker_pysdk_version when versions don't match""" - mock_version.return_value = "2.1.0" + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version") + def test_validate_sagemaker_pysdk_version_match(self, mock_current_version): + """Test validates matching SageMaker SDK version.""" + mock_current_version.return_value = "2.100.0" manager = RuntimeEnvironmentManager() + # Should not raise exception or warning + manager._validate_sagemaker_pysdk_version("2.100.0") - # Should log warning but not raise error - manager._validate_sagemaker_pysdk_version("2.0.0") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version") + def test_validate_sagemaker_pysdk_version_mismatch(self, mock_current_version): + """Test logs warning on SageMaker SDK version mismatch.""" + mock_current_version.return_value = "2.101.0" + manager = RuntimeEnvironmentManager() + # Should log warning but not raise exception + manager._validate_sagemaker_pysdk_version("2.100.0") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager._current_sagemaker_pysdk_version") + def test_validate_sagemaker_pysdk_version_none(self, mock_current_version): + """Test handles None client version.""" + mock_current_version.return_value = "2.100.0" + manager = RuntimeEnvironmentManager() + # Should not raise exception + manager._validate_sagemaker_pysdk_version(None) -class TestHelperFunctions: - """Test cases for helper functions""" - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.check_output" - ) - def test_run_and_get_output_shell_cmd(self, mock_check_output): - """Test _run_and_get_output_shell_cmd""" - mock_check_output.return_value = b"output" +class TestRunAndGetOutputShellCmd: + """Test _run_and_get_output_shell_cmd function.""" + @patch("subprocess.check_output") + def test_runs_command_successfully(self, mock_check_output): + """Test runs command and returns output.""" + mock_check_output.return_value = b"command output" result = _run_and_get_output_shell_cmd("echo test") + assert result == "command output" + - assert result == "output" - - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error" - ) - def test_run_pre_execution_command_script(self, mock_log_error, mock_log_output, mock_popen): - """Test _run_pre_execution_command_script""" - mock_process = Mock() +class TestRunPreExecutionCommandScript: + """Test _run_pre_execution_command_script function.""" + + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output") + @patch("subprocess.Popen") + @patch("os.path.dirname") + def test_runs_script_successfully(self, mock_dirname, mock_popen, mock_log_output, mock_log_error): + """Test runs script successfully.""" + mock_dirname.return_value = "/tmp" + mock_process = MagicMock() mock_process.wait.return_value = 0 mock_popen.return_value = mock_process mock_log_error.return_value = "" + + return_code, error_logs = _run_pre_execution_command_script("/tmp/script.sh") + + assert return_code == 0 + assert error_logs == "" + + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output") + @patch("subprocess.Popen") + @patch("os.path.dirname") + def test_runs_script_with_error(self, mock_dirname, mock_popen, mock_log_output, mock_log_error): + """Test runs script that returns error.""" + mock_dirname.return_value = "/tmp" + mock_process = MagicMock() + mock_process.wait.return_value = 1 + mock_popen.return_value = mock_process + mock_log_error.return_value = "Error message" + + return_code, error_logs = _run_pre_execution_command_script("/tmp/script.sh") + + assert return_code == 1 + assert error_logs == "Error message" - return_code, error_logs = _run_pre_execution_command_script("/path/to/script.sh") - assert return_code == 0 +class TestRunShellCmd: + """Test _run_shell_cmd function.""" - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error" - ) - def test_run_shell_cmd_success(self, mock_log_error, mock_log_output, mock_popen): - """Test _run_shell_cmd with successful command""" - mock_process = Mock() + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output") + @patch("subprocess.Popen") + def test_runs_command_successfully(self, mock_popen, mock_log_output, mock_log_error): + """Test runs command successfully.""" + mock_process = MagicMock() mock_process.wait.return_value = 0 mock_popen.return_value = mock_process mock_log_error.return_value = "" @@ -500,63 +502,71 @@ def test_run_shell_cmd_success(self, mock_log_error, mock_log_output, mock_popen mock_popen.assert_called_once() - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.subprocess.Popen" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output" - ) - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error" - ) - def test_run_shell_cmd_failure(self, mock_log_error, mock_log_output, mock_popen): - """Test _run_shell_cmd with failed command""" - mock_process = Mock() + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_error") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager._log_output") + @patch("subprocess.Popen") + def test_runs_command_raises_error_on_failure(self, mock_popen, mock_log_output, mock_log_error): + """Test raises error when command fails.""" + mock_process = MagicMock() mock_process.wait.return_value = 1 mock_popen.return_value = mock_process mock_log_error.return_value = "Error message" - - with pytest.raises(RuntimeEnvironmentError, match="Encountered error"): + + with pytest.raises(RuntimeEnvironmentError): _run_shell_cmd(["false"]) - def test_python_executable(self): - """Test _python_executable""" - result = _python_executable() - assert result == sys.executable +class TestLogOutput: + """Test _log_output function.""" - @patch( - "sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.sys.executable", - None, - ) - def test_python_executable_not_found(self): - """Test _python_executable when not found""" - with pytest.raises(RuntimeEnvironmentError, match="Failed to retrieve"): - _python_executable() + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.logger") + def test_logs_output(self, mock_logger): + """Test logs process output.""" + from io import BytesIO + mock_process = MagicMock() + mock_process.stdout = BytesIO(b"line1\nline2\n") + + _log_output(mock_process) + + assert mock_logger.info.call_count == 2 -class TestRuntimeEnvironmentError: - """Test cases for RuntimeEnvironmentError exception""" +class TestLogError: + """Test _log_error function.""" - def test_init(self): - """Test initialization""" - error = RuntimeEnvironmentError("Test error message") + @patch("sagemaker.core.remote_function.runtime_environment.runtime_environment_manager.logger") + def test_logs_error(self, mock_logger): + """Test logs process errors.""" + from io import BytesIO + mock_process = MagicMock() + mock_process.stderr = BytesIO(b"ERROR: error message\nwarning message\n") + + error_logs = _log_error(mock_process) + + assert "ERROR: error message" in error_logs + assert "warning message" in error_logs - assert error.message == "Test error message" - assert str(error) == "Test error message" - def test_raise(self): - """Test raising the exception""" - with pytest.raises(RuntimeEnvironmentError, match="Test error"): - raise RuntimeEnvironmentError("Test error") +class TestPythonExecutable: + """Test _python_executable function.""" + def test_returns_python_executable(self): + """Test returns Python executable path.""" + result = _python_executable() + assert result == sys.executable -class TestGetLogger: - """Test cases for get_logger function""" + @patch("sys.executable", None) + def test_raises_error_if_no_executable(self): + """Test raises error if no Python executable.""" + with pytest.raises(RuntimeEnvironmentError): + _python_executable() - def test_get_logger(self): - """Test get_logger returns logger""" - logger = get_logger() - assert logger is not None - assert logger.name == "sagemaker.remote_function" +class TestRuntimeEnvironmentError: + """Test RuntimeEnvironmentError class.""" + + def test_creates_error_with_message(self): + """Test creates error with message.""" + error = RuntimeEnvironmentError("Test error") + assert str(error) == "Test error" + assert error.message == "Test error" diff --git a/sagemaker-core/tests/unit/remote_function/test_checkpoint_location.py b/sagemaker-core/tests/unit/remote_function/test_checkpoint_location.py new file mode 100644 index 0000000000..98a5f8bcc8 --- /dev/null +++ b/sagemaker-core/tests/unit/remote_function/test_checkpoint_location.py @@ -0,0 +1,82 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tests for checkpoint_location module.""" +from __future__ import absolute_import + +import pytest +from sagemaker.core.remote_function.checkpoint_location import ( + CheckpointLocation, + _validate_s3_uri_for_checkpoint, + _JOB_CHECKPOINT_LOCATION, +) + + +class TestValidateS3Uri: + """Test _validate_s3_uri_for_checkpoint function.""" + + def test_valid_s3_uri(self): + """Test valid s3:// URI.""" + assert _validate_s3_uri_for_checkpoint("s3://my-bucket/path/to/checkpoints") + + def test_valid_https_uri(self): + """Test valid https:// URI.""" + assert _validate_s3_uri_for_checkpoint("https://my-bucket.s3.amazonaws.com/path") + + def test_valid_s3_uri_no_path(self): + """Test valid s3:// URI without path.""" + assert _validate_s3_uri_for_checkpoint("s3://my-bucket") + + def test_invalid_uri_no_protocol(self): + """Test invalid URI without protocol.""" + assert not _validate_s3_uri_for_checkpoint("my-bucket/path") + + def test_invalid_uri_wrong_protocol(self): + """Test invalid URI with wrong protocol.""" + assert not _validate_s3_uri_for_checkpoint("http://my-bucket/path") + + def test_invalid_uri_empty(self): + """Test invalid empty URI.""" + assert not _validate_s3_uri_for_checkpoint("") + + +class TestCheckpointLocation: + """Test CheckpointLocation class.""" + + def test_init_with_valid_s3_uri(self): + """Test initialization with valid s3 URI.""" + s3_uri = "s3://my-bucket/checkpoints" + checkpoint_loc = CheckpointLocation(s3_uri) + assert checkpoint_loc._s3_uri == s3_uri + + def test_init_with_valid_https_uri(self): + """Test initialization with valid https URI.""" + s3_uri = "https://my-bucket.s3.amazonaws.com/checkpoints" + checkpoint_loc = CheckpointLocation(s3_uri) + assert checkpoint_loc._s3_uri == s3_uri + + def test_init_with_invalid_uri_raises_error(self): + """Test initialization with invalid URI raises ValueError.""" + with pytest.raises(ValueError, match="CheckpointLocation should be specified with valid s3 URI"): + CheckpointLocation("invalid-uri") + + def test_fspath_returns_local_path(self): + """Test __fspath__ returns the job local path.""" + checkpoint_loc = CheckpointLocation("s3://my-bucket/checkpoints") + assert checkpoint_loc.__fspath__() == _JOB_CHECKPOINT_LOCATION + + def test_can_be_used_as_pathlike(self): + """Test CheckpointLocation can be used as os.PathLike.""" + import os + checkpoint_loc = CheckpointLocation("s3://my-bucket/checkpoints") + path = os.fspath(checkpoint_loc) + assert path == _JOB_CHECKPOINT_LOCATION diff --git a/sagemaker-core/tests/unit/remote_function/test_custom_file_filter.py b/sagemaker-core/tests/unit/remote_function/test_custom_file_filter.py new file mode 100644 index 0000000000..5145a77adf --- /dev/null +++ b/sagemaker-core/tests/unit/remote_function/test_custom_file_filter.py @@ -0,0 +1,169 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tests for custom_file_filter module.""" +from __future__ import absolute_import + +import os +import tempfile +import shutil +from unittest.mock import patch, MagicMock +import pytest + +from sagemaker.core.remote_function.custom_file_filter import ( + CustomFileFilter, + resolve_custom_file_filter_from_config_file, + copy_workdir, +) + + +class TestCustomFileFilter: + """Test CustomFileFilter class.""" + + def test_init_with_no_patterns(self): + """Test initialization without ignore patterns.""" + filter_obj = CustomFileFilter() + assert filter_obj.ignore_name_patterns == [] + assert filter_obj.workdir == os.getcwd() + + def test_init_with_patterns(self): + """Test initialization with ignore patterns.""" + patterns = ["*.pyc", "__pycache__", "*.log"] + filter_obj = CustomFileFilter(ignore_name_patterns=patterns) + assert filter_obj.ignore_name_patterns == patterns + + def test_ignore_name_patterns_property(self): + """Test ignore_name_patterns property.""" + patterns = ["*.txt", "temp*"] + filter_obj = CustomFileFilter(ignore_name_patterns=patterns) + assert filter_obj.ignore_name_patterns == patterns + + def test_workdir_property(self): + """Test workdir property.""" + filter_obj = CustomFileFilter() + assert filter_obj.workdir == os.getcwd() + + +class TestResolveCustomFileFilterFromConfigFile: + """Test resolve_custom_file_filter_from_config_file function.""" + + def test_returns_direct_input_when_provided_as_filter(self): + """Test returns direct input when CustomFileFilter is provided.""" + filter_obj = CustomFileFilter(ignore_name_patterns=["*.pyc"]) + result = resolve_custom_file_filter_from_config_file(direct_input=filter_obj) + assert result is filter_obj + + def test_returns_direct_input_when_provided_as_callable(self): + """Test returns direct input when callable is provided.""" + def custom_filter(path, names): + return [] + result = resolve_custom_file_filter_from_config_file(direct_input=custom_filter) + assert result is custom_filter + + @patch("sagemaker.core.remote_function.custom_file_filter.resolve_value_from_config") + def test_returns_none_when_no_config(self, mock_resolve): + """Test returns None when no config is found.""" + mock_resolve.return_value = None + result = resolve_custom_file_filter_from_config_file() + assert result is None + + @patch("sagemaker.core.remote_function.custom_file_filter.resolve_value_from_config") + def test_creates_filter_from_config(self, mock_resolve): + """Test creates CustomFileFilter from config.""" + patterns = ["*.pyc", "*.log"] + mock_resolve.return_value = patterns + result = resolve_custom_file_filter_from_config_file() + assert isinstance(result, CustomFileFilter) + assert result.ignore_name_patterns == patterns + + @patch("sagemaker.core.remote_function.custom_file_filter.resolve_value_from_config") + def test_passes_sagemaker_session_to_resolve(self, mock_resolve): + """Test passes sagemaker_session to resolve_value_from_config.""" + mock_session = MagicMock() + mock_resolve.return_value = None + resolve_custom_file_filter_from_config_file(sagemaker_session=mock_session) + mock_resolve.assert_called_once() + assert mock_resolve.call_args[1]["sagemaker_session"] == mock_session + + +class TestCopyWorkdir: + """Test copy_workdir function.""" + + def setup_method(self): + """Set up test fixtures.""" + self.temp_src = tempfile.mkdtemp() + self.temp_dst = tempfile.mkdtemp() + + # Create test files + with open(os.path.join(self.temp_src, "test.py"), "w") as f: + f.write("print('test')") + with open(os.path.join(self.temp_src, "test.txt"), "w") as f: + f.write("text file") + os.makedirs(os.path.join(self.temp_src, "__pycache__")) + with open(os.path.join(self.temp_src, "__pycache__", "test.pyc"), "w") as f: + f.write("compiled") + + def teardown_method(self): + """Clean up test fixtures.""" + if os.path.exists(self.temp_src): + shutil.rmtree(self.temp_src) + if os.path.exists(self.temp_dst): + shutil.rmtree(self.temp_dst) + + @patch("os.getcwd") + def test_copy_workdir_without_filter_only_python_files(self, mock_getcwd): + """Test copy_workdir without filter copies only Python files.""" + mock_getcwd.return_value = self.temp_src + dst = os.path.join(self.temp_dst, "output") + + copy_workdir(dst) + + assert os.path.exists(os.path.join(dst, "test.py")) + assert not os.path.exists(os.path.join(dst, "test.txt")) + assert not os.path.exists(os.path.join(dst, "__pycache__")) + + @patch("os.getcwd") + def test_copy_workdir_with_callable_filter(self, mock_getcwd): + """Test copy_workdir with callable filter.""" + mock_getcwd.return_value = self.temp_src + dst = os.path.join(self.temp_dst, "output") + + def custom_filter(path, names): + return ["test.txt"] + + copy_workdir(dst, custom_file_filter=custom_filter) + + assert os.path.exists(os.path.join(dst, "test.py")) + assert not os.path.exists(os.path.join(dst, "test.txt")) + + def test_copy_workdir_with_custom_file_filter_object(self): + """Test copy_workdir with CustomFileFilter object.""" + filter_obj = CustomFileFilter(ignore_name_patterns=["*.py"]) + filter_obj._workdir = self.temp_src + dst = os.path.join(self.temp_dst, "output") + + copy_workdir(dst, custom_file_filter=filter_obj) + + assert not os.path.exists(os.path.join(dst, "test.py")) + assert os.path.exists(os.path.join(dst, "test.txt")) + + def test_copy_workdir_with_pattern_matching(self): + """Test copy_workdir with pattern matching in CustomFileFilter.""" + filter_obj = CustomFileFilter(ignore_name_patterns=["*.txt", "__pycache__"]) + filter_obj._workdir = self.temp_src + dst = os.path.join(self.temp_dst, "output") + + copy_workdir(dst, custom_file_filter=filter_obj) + + assert os.path.exists(os.path.join(dst, "test.py")) + assert not os.path.exists(os.path.join(dst, "test.txt")) + assert not os.path.exists(os.path.join(dst, "__pycache__")) diff --git a/sagemaker-core/tests/unit/remote_function/test_invoke_function.py b/sagemaker-core/tests/unit/remote_function/test_invoke_function.py new file mode 100644 index 0000000000..4810eba2e0 --- /dev/null +++ b/sagemaker-core/tests/unit/remote_function/test_invoke_function.py @@ -0,0 +1,280 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tests for invoke_function module.""" +from __future__ import absolute_import + +import json +import pytest +from unittest.mock import patch, MagicMock, call + +from sagemaker.core.remote_function.invoke_function import ( + _parse_args, + _get_sagemaker_session, + _load_run_object, + _load_pipeline_context, + _execute_remote_function, + main, + SUCCESS_EXIT_CODE, +) +from sagemaker.core.remote_function.job import KEY_EXPERIMENT_NAME, KEY_RUN_NAME + + +class TestParseArgs: + """Test _parse_args function.""" + + def test_parse_required_args(self): + """Test parsing required arguments.""" + args = [ + "--region", "us-west-2", + "--s3_base_uri", "s3://my-bucket/path", + ] + parsed = _parse_args(args) + assert parsed.region == "us-west-2" + assert parsed.s3_base_uri == "s3://my-bucket/path" + + def test_parse_all_args(self): + """Test parsing all arguments.""" + args = [ + "--region", "us-east-1", + "--s3_base_uri", "s3://bucket/path", + "--s3_kms_key", "key-123", + "--run_in_context", '{"experiment": "exp1"}', + "--pipeline_step_name", "step1", + "--pipeline_execution_id", "exec-123", + "--property_references", "prop1", "val1", "prop2", "val2", + "--serialize_output_to_json", "true", + "--func_step_s3_dir", "s3://bucket/func", + ] + parsed = _parse_args(args) + assert parsed.region == "us-east-1" + assert parsed.s3_base_uri == "s3://bucket/path" + assert parsed.s3_kms_key == "key-123" + assert parsed.run_in_context == '{"experiment": "exp1"}' + assert parsed.pipeline_step_name == "step1" + assert parsed.pipeline_execution_id == "exec-123" + assert parsed.property_references == ["prop1", "val1", "prop2", "val2"] + assert parsed.serialize_output_to_json is True + assert parsed.func_step_s3_dir == "s3://bucket/func" + + def test_parse_serialize_output_false(self): + """Test parsing serialize_output_to_json as false.""" + args = [ + "--region", "us-west-2", + "--s3_base_uri", "s3://bucket/path", + "--serialize_output_to_json", "false", + ] + parsed = _parse_args(args) + assert parsed.serialize_output_to_json is False + + def test_parse_default_values(self): + """Test default values for optional arguments.""" + args = [ + "--region", "us-west-2", + "--s3_base_uri", "s3://bucket/path", + ] + parsed = _parse_args(args) + assert parsed.s3_kms_key is None + assert parsed.run_in_context is None + assert parsed.pipeline_step_name is None + assert parsed.pipeline_execution_id is None + assert parsed.property_references == [] + assert parsed.serialize_output_to_json is False + assert parsed.func_step_s3_dir is None + + +class TestGetSagemakerSession: + """Test _get_sagemaker_session function.""" + + @patch("sagemaker.core.remote_function.invoke_function.boto3.session.Session") + @patch("sagemaker.core.remote_function.invoke_function.Session") + def test_creates_session_with_region(self, mock_session_class, mock_boto_session): + """Test creates SageMaker session with correct region.""" + mock_boto = MagicMock() + mock_boto_session.return_value = mock_boto + + _get_sagemaker_session("us-west-2") + + mock_boto_session.assert_called_once_with(region_name="us-west-2") + mock_session_class.assert_called_once_with(boto_session=mock_boto) + + +class TestLoadRunObject: + """Test _load_run_object function.""" + + @patch("sagemaker.core.experiments.run.Run") + def test_loads_run_from_json(self, mock_run_class): + """Test loads Run object from JSON string.""" + run_dict = { + KEY_EXPERIMENT_NAME: "my-experiment", + KEY_RUN_NAME: "my-run", + } + run_json = json.dumps(run_dict) + mock_session = MagicMock() + + _load_run_object(run_json, mock_session) + + mock_run_class.assert_called_once_with( + experiment_name="my-experiment", + run_name="my-run", + sagemaker_session=mock_session, + ) + + +class TestLoadPipelineContext: + """Test _load_pipeline_context function.""" + + def test_loads_context_with_all_fields(self): + """Test loads pipeline context with all fields.""" + args = MagicMock() + args.pipeline_step_name = "step1" + args.pipeline_execution_id = "exec-123" + args.property_references = ["prop1", "val1", "prop2", "val2"] + args.serialize_output_to_json = True + args.func_step_s3_dir = "s3://bucket/func" + + context = _load_pipeline_context(args) + + assert context.step_name == "step1" + assert context.execution_id == "exec-123" + assert context.property_references == {"prop1": "val1", "prop2": "val2"} + assert context.serialize_output_to_json is True + assert context.func_step_s3_dir == "s3://bucket/func" + + def test_loads_context_with_empty_property_references(self): + """Test loads pipeline context with empty property references.""" + args = MagicMock() + args.pipeline_step_name = "step1" + args.pipeline_execution_id = "exec-123" + args.property_references = [] + args.serialize_output_to_json = False + args.func_step_s3_dir = None + + context = _load_pipeline_context(args) + + assert context.property_references == {} + + +class TestExecuteRemoteFunction: + """Test _execute_remote_function function.""" + + @patch("sagemaker.core.remote_function.core.stored_function.StoredFunction") + def test_executes_without_run_context(self, mock_stored_function_class): + """Test executes stored function without run context.""" + mock_stored_func = MagicMock() + mock_stored_function_class.return_value = mock_stored_func + mock_session = MagicMock() + mock_context = MagicMock() + + _execute_remote_function( + sagemaker_session=mock_session, + s3_base_uri="s3://bucket/path", + s3_kms_key="key-123", + run_in_context=None, + context=mock_context, + ) + + mock_stored_function_class.assert_called_once_with( + sagemaker_session=mock_session, + s3_base_uri="s3://bucket/path", + s3_kms_key="key-123", + context=mock_context, + ) + mock_stored_func.load_and_invoke.assert_called_once() + + @patch("sagemaker.core.remote_function.invoke_function._load_run_object") + @patch("sagemaker.core.remote_function.core.stored_function.StoredFunction") + def test_executes_with_run_context(self, mock_stored_function_class, mock_load_run): + """Test executes stored function with run context.""" + mock_stored_func = MagicMock() + mock_stored_function_class.return_value = mock_stored_func + mock_run = MagicMock() + mock_load_run.return_value = mock_run + mock_session = MagicMock() + mock_context = MagicMock() + run_json = '{"experiment": "exp1"}' + + _execute_remote_function( + sagemaker_session=mock_session, + s3_base_uri="s3://bucket/path", + s3_kms_key=None, + run_in_context=run_json, + context=mock_context, + ) + + # Verify run object was loaded and used as context manager + mock_load_run.assert_called_once_with(run_json, mock_session) + mock_run.__enter__.assert_called_once() + mock_run.__exit__.assert_called_once() + + +class TestMain: + """Test main function.""" + + @patch("sagemaker.core.remote_function.invoke_function._execute_remote_function") + @patch("sagemaker.core.remote_function.invoke_function._get_sagemaker_session") + @patch("sagemaker.core.remote_function.invoke_function._load_pipeline_context") + @patch("sagemaker.core.remote_function.invoke_function._parse_args") + def test_main_success(self, mock_parse, mock_load_context, mock_get_session, mock_execute): + """Test main function successful execution.""" + mock_args = MagicMock() + mock_args.region = "us-west-2" + mock_args.s3_base_uri = "s3://bucket/path" + mock_args.s3_kms_key = None + mock_args.run_in_context = None + mock_parse.return_value = mock_args + + mock_context = MagicMock() + mock_context.step_name = None + mock_load_context.return_value = mock_context + + mock_session = MagicMock() + mock_get_session.return_value = mock_session + + with pytest.raises(SystemExit) as exc_info: + main(["--region", "us-west-2", "--s3_base_uri", "s3://bucket/path"]) + + assert exc_info.value.code == SUCCESS_EXIT_CODE + mock_execute.assert_called_once() + + @patch("sagemaker.core.remote_function.invoke_function.handle_error") + @patch("sagemaker.core.remote_function.invoke_function._execute_remote_function") + @patch("sagemaker.core.remote_function.invoke_function._get_sagemaker_session") + @patch("sagemaker.core.remote_function.invoke_function._load_pipeline_context") + @patch("sagemaker.core.remote_function.invoke_function._parse_args") + def test_main_handles_exception( + self, mock_parse, mock_load_context, mock_get_session, mock_execute, mock_handle_error + ): + """Test main function handles exceptions.""" + mock_args = MagicMock() + mock_args.region = "us-west-2" + mock_args.s3_base_uri = "s3://bucket/path" + mock_args.s3_kms_key = None + mock_args.run_in_context = None + mock_parse.return_value = mock_args + + mock_context = MagicMock() + mock_context.step_name = None + mock_load_context.return_value = mock_context + + mock_session = MagicMock() + mock_get_session.return_value = mock_session + + test_exception = Exception("Test error") + mock_execute.side_effect = test_exception + mock_handle_error.return_value = 1 + + with pytest.raises(SystemExit) as exc_info: + main(["--region", "us-west-2", "--s3_base_uri", "s3://bucket/path"]) + + assert exc_info.value.code == 1 + mock_handle_error.assert_called_once() diff --git a/sagemaker-core/tests/unit/remote_function/test_job.py b/sagemaker-core/tests/unit/remote_function/test_job.py index abc5be68be..6f10016643 100644 --- a/sagemaker-core/tests/unit/remote_function/test_job.py +++ b/sagemaker-core/tests/unit/remote_function/test_job.py @@ -143,26 +143,23 @@ class TestJob: def test_init(self, mock_session): """Test _Job initialization.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) assert job.job_name == "test-job" assert job.s3_uri == "s3://bucket/output" - assert job.hmac_key == "test-key" def test_from_describe_response(self, mock_session): """Test creating _Job from describe response.""" response = { "TrainingJobName": "test-job", "OutputDataConfig": {"S3OutputPath": "s3://bucket/output"}, - "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "test-key"}, } job = _Job.from_describe_response(response, mock_session) assert job.job_name == "test-job" assert job.s3_uri == "s3://bucket/output" - assert job.hmac_key == "test-key" def test_describe_returns_cached_response(self, mock_session): """Test that describe returns cached response for completed jobs.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) job._last_describe_response = {"TrainingJobStatus": "Completed"} result = job.describe() @@ -171,7 +168,7 @@ def test_describe_returns_cached_response(self, mock_session): def test_describe_calls_api_for_in_progress_jobs(self, mock_session): """Test that describe calls API for in-progress jobs.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) mock_session.sagemaker_client.describe_training_job.return_value = { "TrainingJobStatus": "InProgress" } @@ -182,7 +179,7 @@ def test_describe_calls_api_for_in_progress_jobs(self, mock_session): def test_stop(self, mock_session): """Test stopping a job.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) job.stop() mock_session.sagemaker_client.stop_training_job.assert_called_once_with( TrainingJobName="test-job" @@ -191,7 +188,7 @@ def test_stop(self, mock_session): @patch("sagemaker.core.remote_function.job._logs_for_job") def test_wait(self, mock_logs, mock_session): """Test waiting for job completion.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) mock_logs.return_value = {"TrainingJobStatus": "Completed"} job.wait(timeout=100) @@ -882,7 +879,7 @@ def test_start(self, mock_get_name, mock_compile, mock_session): mock_get_name.return_value = "test-job" mock_compile.return_value = { "TrainingJobName": "test-job", - "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "test-key"}, + "Environment": {}, } job_settings = Mock() diff --git a/sagemaker-core/tests/unit/remote_function/test_job_comprehensive.py b/sagemaker-core/tests/unit/remote_function/test_job_comprehensive.py index 4069029685..bc8d5a8e56 100644 --- a/sagemaker-core/tests/unit/remote_function/test_job_comprehensive.py +++ b/sagemaker-core/tests/unit/remote_function/test_job_comprehensive.py @@ -144,17 +144,15 @@ def test_from_describe_response(self, mock_session): response = { "TrainingJobName": "test-job", "OutputDataConfig": {"S3OutputPath": "s3://bucket/output"}, - "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "test-key"}, } job = _Job.from_describe_response(response, mock_session) assert job.job_name == "test-job" assert job.s3_uri == "s3://bucket/output" - assert job.hmac_key == "test-key" assert job._last_describe_response == response def test_describe_cached_completed(self, mock_session): """Test lines 865-871: describe with cached completed job.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) job._last_describe_response = {"TrainingJobStatus": "Completed"} result = job.describe() @@ -163,7 +161,7 @@ def test_describe_cached_completed(self, mock_session): def test_describe_cached_failed(self, mock_session): """Test lines 865-871: describe with cached failed job.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) job._last_describe_response = {"TrainingJobStatus": "Failed"} result = job.describe() @@ -172,7 +170,7 @@ def test_describe_cached_failed(self, mock_session): def test_describe_cached_stopped(self, mock_session): """Test lines 865-871: describe with cached stopped job.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) job._last_describe_response = {"TrainingJobStatus": "Stopped"} result = job.describe() @@ -181,7 +179,7 @@ def test_describe_cached_stopped(self, mock_session): def test_stop(self, mock_session): """Test lines 886-887: stop method.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) job.stop() mock_session.sagemaker_client.stop_training_job.assert_called_once_with( TrainingJobName="test-job" @@ -190,7 +188,7 @@ def test_stop(self, mock_session): @patch("sagemaker.core.remote_function.job._logs_for_job") def test_wait(self, mock_logs, mock_session): """Test lines 889-903: wait method.""" - job = _Job("test-job", "s3://bucket/output", mock_session, "test-key") + job = _Job("test-job", "s3://bucket/output", mock_session) mock_logs.return_value = {"TrainingJobStatus": "Completed"} job.wait(timeout=100) diff --git a/sagemaker-core/tests/unit/remote_function/test_logging_config.py b/sagemaker-core/tests/unit/remote_function/test_logging_config.py new file mode 100644 index 0000000000..6454ea1071 --- /dev/null +++ b/sagemaker-core/tests/unit/remote_function/test_logging_config.py @@ -0,0 +1,86 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tests for logging_config module.""" +from __future__ import absolute_import + +import logging +import time +from unittest.mock import patch +from sagemaker.core.remote_function.logging_config import _UTCFormatter, get_logger + + +class TestUTCFormatter: + """Test _UTCFormatter class.""" + + def test_converter_is_gmtime(self): + """Test that converter is set to gmtime.""" + formatter = _UTCFormatter() + assert formatter.converter == time.gmtime + + def test_formats_time_in_utc(self): + """Test that time is formatted in UTC.""" + formatter = _UTCFormatter("%(asctime)s") + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="", + lineno=0, + msg="test message", + args=(), + exc_info=None, + ) + formatted = formatter.format(record) + # Should contain UTC time format + assert formatted + + +class TestGetLogger: + """Test get_logger function.""" + + def test_returns_logger_with_correct_name(self): + """Test that logger has correct name.""" + logger = get_logger() + assert logger.name == "sagemaker.remote_function" + + def test_logger_has_info_level(self): + """Test that logger is set to INFO level.""" + logger = get_logger() + assert logger.level == logging.INFO + + def test_logger_has_handler(self): + """Test that logger has at least one handler.""" + logger = get_logger() + assert len(logger.handlers) > 0 + + def test_logger_handler_has_utc_formatter(self): + """Test that logger handler uses UTC formatter.""" + logger = get_logger() + handler = logger.handlers[0] + # Check that formatter has gmtime converter (UTC formatter characteristic) + assert handler.formatter.converter == time.gmtime + + def test_logger_does_not_propagate(self): + """Test that logger does not propagate to root logger.""" + logger = get_logger() + assert logger.propagate == 0 + + def test_get_logger_is_idempotent(self): + """Test that calling get_logger multiple times returns same logger.""" + logger1 = get_logger() + logger2 = get_logger() + assert logger1 is logger2 + + def test_logger_handler_is_stream_handler(self): + """Test that logger uses StreamHandler.""" + logger = get_logger() + assert isinstance(logger.handlers[0], logging.StreamHandler) diff --git a/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py b/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py index 30fbba3639..9b4b9a191b 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py +++ b/sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py @@ -1036,7 +1036,6 @@ def get_function_step_result( return deserialize_obj_from_s3( sagemaker_session=sagemaker_session, s3_uri=s3_uri, - hmac_key=describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"], ) raise RemoteFunctionError(_ERROR_MSG_OF_STEP_INCOMPLETE) diff --git a/sagemaker-mlops/tests/unit/workflow/test_pipeline.py b/sagemaker-mlops/tests/unit/workflow/test_pipeline.py index 55922a66ca..9169f1ce7f 100644 --- a/sagemaker-mlops/tests/unit/workflow/test_pipeline.py +++ b/sagemaker-mlops/tests/unit/workflow/test_pipeline.py @@ -360,7 +360,6 @@ def test_get_function_step_result_incomplete_job(mock_session): "AlgorithmSpecification": {"ContainerEntrypoint": JOBS_CONTAINER_ENTRYPOINT}, "OutputDataConfig": {"S3OutputPath": "s3://bucket/path"}, "TrainingJobStatus": "Failed", - "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "key"} } with pytest.raises(RemoteFunctionError, match="not in Completed status"): @@ -376,7 +375,6 @@ def test_get_function_step_result_success(mock_session): "AlgorithmSpecification": {"ContainerEntrypoint": JOBS_CONTAINER_ENTRYPOINT}, "OutputDataConfig": {"S3OutputPath": "s3://bucket/path/exec-id/step1/results"}, "TrainingJobStatus": "Completed", - "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "key"} } with patch("sagemaker.mlops.workflow.pipeline.deserialize_obj_from_s3", return_value="result"): @@ -443,7 +441,6 @@ def test_pipeline_execution_result_terminal_failure(mock_session): "AlgorithmSpecification": {"ContainerEntrypoint": JOBS_CONTAINER_ENTRYPOINT}, "OutputDataConfig": {"S3OutputPath": "s3://bucket/path/exec-id/step1/results"}, "TrainingJobStatus": "Completed", - "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "key"} } with patch.object(execution, "wait", side_effect=WaiterError("name", "Waiter encountered a terminal failure state", {})): @@ -461,7 +458,6 @@ def test_get_function_step_result_obsolete_s3_path(mock_session): "AlgorithmSpecification": {"ContainerEntrypoint": JOBS_CONTAINER_ENTRYPOINT}, "OutputDataConfig": {"S3OutputPath": "s3://bucket/different/path"}, "TrainingJobStatus": "Completed", - "Environment": {"REMOTE_FUNCTION_SECRET_KEY": "key"} } with patch("sagemaker.mlops.workflow.pipeline.deserialize_obj_from_s3", return_value="result"):