diff --git a/CHANGELOG.md b/CHANGELOG.md index c84bf02a98..6e424f7191 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## v2.155.0 (2023-05-15) + +### Features + + * Add support for SageMaker Serverless inference Provisioned Concurrency feature + +### Bug Fixes and Other Changes + + * Revert "fix: make RemoteExecutor context manager non-blocking on pend… + * Add BOM to No P2 Availability region list + ## v2.154.0 (2023-05-11) ### Features diff --git a/VERSION b/VERSION index 58b57014e1..0322240d8f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.154.1.dev0 +2.155.1.dev0 diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index bf1ec6e2d9..fb671324d9 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -12,9 +12,9 @@ awslogs==0.14.0 black==22.3.0 stopit==1.1.2 # Update tox.ini to have correct version of airflow constraints file -apache-airflow==2.5.1 +apache-airflow==2.6.0 apache-airflow-providers-amazon==7.2.1 -attrs==22.1.0 +attrs>=23.1.0,<24 fabric==2.6.0 requests==2.27.1 sagemaker-experiments==0.1.35 @@ -23,3 +23,7 @@ pyvis==0.2.1 pandas>=1.3.5,<1.5 scikit-learn==1.0.2 cloudpickle==2.2.1 +scipy==1.7.3 +urllib3==1.26.8 +docker>=5.0.2,<7.0.0 +PyYAML==6.0 diff --git a/setup.py b/setup.py index ee7c8268e3..e313587bec 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ def read_requirements(filename): # Declare minimal set for installation required_packages = [ - "attrs>=20.3.0,<23", + "attrs>=23.1.0,<24", "boto3>=1.26.131,<2.0", "cloudpickle==2.2.1", "google-pasta", @@ -60,7 +60,7 @@ def read_requirements(filename): "pandas", "pathos", "schema", - "PyYAML==5.4.1", + "PyYAML==6.0", "jsonschema", "platformdirs", "tblib==1.7.0", @@ -75,7 +75,7 @@ def read_requirements(filename): # Meta dependency groups extras["all"] = [item for group in extras.values() for item in group] # Tests specific dependencies (do not need to be included in 'all') -extras["test"] = (extras["all"] + read_requirements("requirements/extras/test_requirements.txt"),) +extras["test"] = (read_requirements("requirements/extras/test_requirements.txt"),) setup( name="sagemaker", diff --git a/src/sagemaker/djl_inference/model.py b/src/sagemaker/djl_inference/model.py index b9828e7037..b91851576e 100644 --- a/src/sagemaker/djl_inference/model.py +++ b/src/sagemaker/djl_inference/model.py @@ -854,11 +854,13 @@ def generate_serving_properties(self, serving_properties=None) -> Dict[str, str] if self.low_cpu_mem_usage: serving_properties["option.low_cpu_mem_usage"] = self.low_cpu_mem_usage # This is a workaround due to a bug in our built in handler for huggingface - # TODO: This needs to be fixed when new dlc is published + # TODO: Remove this logic whenever 0.20.0 image is out of service if ( serving_properties["option.entryPoint"] == "djl_python.huggingface" and self.dtype and self.dtype != "auto" + and self.djl_version + and int(self.djl_version.split(".")[1]) < 21 ): serving_properties["option.dtype"] = "auto" serving_properties.pop("option.load_in_8bit", None) diff --git a/src/sagemaker/experiments/run.py b/src/sagemaker/experiments/run.py index 6202de858c..94d07a9655 100644 --- a/src/sagemaker/experiments/run.py +++ b/src/sagemaker/experiments/run.py @@ -633,7 +633,10 @@ def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: s Returns: str: The name of the Run object supplied by a user.
""" - return trial_component_name.replace("{}{}".format(experiment_name, DELIMITER), "", 1) + # TODO: we should revert the lower casting once backend fix reaches prod + return trial_component_name.replace( + "{}{}".format(experiment_name.lower(), DELIMITER), "", 1 + ) @staticmethod def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) -> list: @@ -869,6 +872,8 @@ def list_runs( Returns: list: A list of ``Run`` objects. """ + + # all trial components retrieved by default tc_summaries = _TrialComponent.list( experiment_name=experiment_name, created_before=created_before, diff --git a/src/sagemaker/inference_recommender/inference_recommender_mixin.py b/src/sagemaker/inference_recommender/inference_recommender_mixin.py index 90b460a23e..c4d7f77985 100644 --- a/src/sagemaker/inference_recommender/inference_recommender_mixin.py +++ b/src/sagemaker/inference_recommender/inference_recommender_mixin.py @@ -145,10 +145,10 @@ def right_size( ) if endpoint_configurations or traffic_pattern or stopping_conditions or resource_limit: - LOGGER.info("Advance Job parameters were specified. Running Advanced job...") + LOGGER.info("Advanced Job parameters were specified. Running Advanced job...") job_type = "Advanced" else: - LOGGER.info("Advance Job parameters were not specified. Running Default job...") + LOGGER.info("Advanced Job parameters were not specified. Running Default job...") job_type = "Default" self._init_sagemaker_session_if_does_not_exist() diff --git a/src/sagemaker/remote_function/client.py b/src/sagemaker/remote_function/client.py index 1785f15892..93a40c4114 100644 --- a/src/sagemaker/remote_function/client.py +++ b/src/sagemaker/remote_function/client.py @@ -301,6 +301,7 @@ def wrapper(*args, **kwargs): s3_uri=s3_path_join( job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER ), + hmac_key=job.hmac_key, ) except ServiceError as serr: chained_e = serr.__cause__ @@ -337,6 +338,7 @@ def wrapper(*args, **kwargs): return serialization.deserialize_obj_from_s3( sagemaker_session=job_settings.sagemaker_session, s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER), + hmac_key=job.hmac_key, ) if job.describe()["TrainingJobStatus"] == "Stopped": @@ -745,7 +747,7 @@ def map(self, func, *iterables): futures = map(self.submit, itertools.repeat(func), *iterables) return [future.result() for future in futures] - def shutdown(self, wait=True): + def shutdown(self): """Prevent more function executions to be submitted to this executor.""" with self._state_condition: self._shutdown = True @@ -756,7 +758,7 @@ def shutdown(self, wait=True): self._state_condition.notify_all() if self._workers is not None: - self._workers.shutdown(wait) + self._workers.shutdown(wait=True) def __enter__(self): """Create an executor instance and return it""" @@ -764,7 +766,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): """Make sure the executor instance is shutdown.""" - self.shutdown(wait=False) + self.shutdown() return False @staticmethod @@ -861,6 +863,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session): job_return = serialization.deserialize_obj_from_s3( sagemaker_session=sagemaker_session, s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER), + hmac_key=job.hmac_key, ) except DeserializationError as e: client_exception = e @@ -872,6 +875,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session): job_exception = serialization.deserialize_exception_from_s3( sagemaker_session=sagemaker_session, 
s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER), + hmac_key=job.hmac_key, ) except ServiceError as serr: chained_e = serr.__cause__ @@ -961,6 +965,7 @@ def result(self, timeout: float = None) -> Any: self._return = serialization.deserialize_obj_from_s3( sagemaker_session=self._job.sagemaker_session, s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER), + hmac_key=self._job.hmac_key, ) self._state = _FINISHED return self._return @@ -969,6 +974,7 @@ def result(self, timeout: float = None) -> Any: self._exception = serialization.deserialize_exception_from_s3( sagemaker_session=self._job.sagemaker_session, s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER), + hmac_key=self._job.hmac_key, ) except ServiceError as serr: chained_e = serr.__cause__ diff --git a/src/sagemaker/remote_function/core/serialization.py b/src/sagemaker/remote_function/core/serialization.py index 29b7f18bb1..989da71df9 100644 --- a/src/sagemaker/remote_function/core/serialization.py +++ b/src/sagemaker/remote_function/core/serialization.py @@ -17,12 +17,16 @@ import json import os import sys +import hmac +import hashlib import cloudpickle from typing import Any, Callable from sagemaker.remote_function.errors import ServiceError, SerializationError, DeserializationError from sagemaker.s3 import S3Downloader, S3Uploader +from sagemaker.session import Session + from tblib import pickling_support @@ -34,6 +38,7 @@ def _get_python_version(): class _MetaData: """Metadata about the serialized data or functions.""" + sha256_hash: str version: str = "2023-04-24" python_version: str = _get_python_version() serialization_module: str = "cloudpickle" @@ -48,11 +53,17 @@ def from_json(s): except json.decoder.JSONDecodeError: raise DeserializationError("Corrupt metadata file. It is not a valid json file.") - metadata = _MetaData() + sha256_hash = obj.get("sha256_hash") + metadata = _MetaData(sha256_hash=sha256_hash) metadata.version = obj.get("version") metadata.python_version = obj.get("python_version") metadata.serialization_module = obj.get("serialization_module") + if not sha256_hash: + raise DeserializationError( + "Corrupt metadata file. SHA256 hash for the serialized data does not exist" + ) + if not ( metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle" ): @@ -67,20 +78,16 @@ class CloudpickleSerializer: """Serializer using cloudpickle.""" @staticmethod - def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None): + def serialize(obj: Any) -> Any: """Serializes data object and uploads it to S3. Args: - sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service - calls are delegated to. - s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. obj: object to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. 
""" try: - bytes_to_upload = cloudpickle.dumps(obj) + return cloudpickle.dumps(obj) except Exception as e: if isinstance( e, NotImplementedError @@ -96,10 +103,8 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None): "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e)) ) from e - _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session) - @staticmethod - def deserialize(sagemaker_session, s3_uri) -> Any: + def deserialize(s3_uri: str, bytes_to_deserialize) -> Any: """Downloads from S3 and then deserializes data objects. Args: @@ -111,7 +116,6 @@ def deserialize(sagemaker_session, s3_uri) -> Any: Raises: DeserializationError: when fail to serialize object to bytes. """ - bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session) try: return cloudpickle.loads(bytes_to_deserialize) @@ -122,28 +126,39 @@ def deserialize(sagemaker_session, s3_uri) -> Any: # TODO: use dask serializer in case dask distributed is installed in users' environment. -def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None): +def serialize_func_to_s3( + func: Callable, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None +): """Serializes function and uploads it to S3. Args: sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. func: function to be serialized and persisted Raises: SerializationError: when fail to serialize function to bytes. """ + bytes_to_upload = CloudpickleSerializer.serialize(func) + _upload_bytes_to_s3( - _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session + bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session ) - CloudpickleSerializer.serialize( - func, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key + + sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key) + + _upload_bytes_to_s3( + _MetaData(sha256_hash).to_json(), + os.path.join(s3_uri, "metadata.json"), + s3_kms_key, + sagemaker_session, ) -def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable: +def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Callable: """Downloads from S3 and then deserializes data objects. This method downloads the serialized training job outputs to a temporary directory and @@ -153,19 +168,32 @@ def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable: sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func. Returns : The deserialized function. Raises: DeserializationError: when fail to serialize function to bytes. 
""" - _MetaData.from_json( + metadata = _MetaData.from_json( _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session) ) - return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl")) + bytes_to_deserialize = _read_bytes_from_s3( + os.path.join(s3_uri, "payload.pkl"), sagemaker_session + ) + + _perform_integrity_check( + expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize + ) + + return CloudpickleSerializer.deserialize( + os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize + ) -def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None): +def serialize_obj_to_s3( + obj: Any, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None +): """Serializes data object and uploads it to S3. Args: @@ -173,41 +201,61 @@ def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: st calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj. obj: object to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. """ + bytes_to_upload = CloudpickleSerializer.serialize(obj) + _upload_bytes_to_s3( - _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session + bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session ) - CloudpickleSerializer.serialize( - obj, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key + + sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key) + + _upload_bytes_to_s3( + _MetaData(sha256_hash).to_json(), + os.path.join(s3_uri, "metadata.json"), + s3_kms_key, + sagemaker_session, ) -def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any: +def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any: """Downloads from S3 and then deserializes data objects. Args: sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj. Returns : Deserialized python objects. Raises: DeserializationError: when fail to serialize object to bytes. """ - _MetaData.from_json( + metadata = _MetaData.from_json( _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session) ) - return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl")) + bytes_to_deserialize = _read_bytes_from_s3( + os.path.join(s3_uri, "payload.pkl"), sagemaker_session + ) + + _perform_integrity_check( + expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize + ) + + return CloudpickleSerializer.deserialize( + os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize + ) def serialize_exception_to_s3( - exc: Exception, sagemaker_session, s3_uri: str, s3_kms_key: str = None + exc: Exception, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None ): """Serializes exception with traceback and uploads it to S3. @@ -216,37 +264,58 @@ def serialize_exception_to_s3( calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. 
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception. exc: Exception to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. """ pickling_support.install() + + bytes_to_upload = CloudpickleSerializer.serialize(exc) + _upload_bytes_to_s3( - _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session + bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session ) - CloudpickleSerializer.serialize( - exc, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key + + sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key) + + _upload_bytes_to_s3( + _MetaData(sha256_hash).to_json(), + os.path.join(s3_uri, "metadata.json"), + s3_kms_key, + sagemaker_session, ) -def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any: +def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any: """Downloads from S3 and then deserializes exception. Args: sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception. Returns : Deserialized exception with traceback. Raises: DeserializationError: when fail to serialize object to bytes. """ - _MetaData.from_json( + metadata = _MetaData.from_json( _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session) ) - return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl")) + bytes_to_deserialize = _read_bytes_from_s3( + os.path.join(s3_uri, "payload.pkl"), sagemaker_session + ) + + _perform_integrity_check( + expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize + ) + + return CloudpickleSerializer.deserialize( + os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize + ) def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session): @@ -269,3 +338,22 @@ def _read_bytes_from_s3(s3_uri, sagemaker_session): raise ServiceError( "Failed to read serialized bytes from {}: {}".format(s3_uri, repr(e)) ) from e + + +def _compute_hash(buffer: bytes, secret_key: str) -> str: + """Compute the hmac-sha256 hash.""" + return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest() + + +def _perform_integrity_check(expected_hash_value: str, secret_key: str, buffer: bytes): + """Performs integrity checks for serialized code/arguments uploaded to s3. + + Verifies whether the hash read from s3 matches the hash calculated + during remote function execution. + """ + actual_hash_value = _compute_hash(buffer=buffer, secret_key=secret_key) + if not hmac.compare_digest(expected_hash_value, actual_hash_value): + raise DeserializationError( + "Integrity check for the serialized function or data failed. 
" + "Please restrict access to your S3 bucket" + ) diff --git a/src/sagemaker/remote_function/core/stored_function.py b/src/sagemaker/remote_function/core/stored_function.py index 0204cf3e51..7c3b0d2949 100644 --- a/src/sagemaker/remote_function/core/stored_function.py +++ b/src/sagemaker/remote_function/core/stored_function.py @@ -17,6 +17,7 @@ from sagemaker.remote_function import logging_config import sagemaker.remote_function.core.serialization as serialization +from sagemaker.session import Session logger = logging_config.get_logger() @@ -31,7 +32,9 @@ class StoredFunction: """Class representing a remote function stored in S3.""" - def __init__(self, sagemaker_session, s3_base_uri, s3_kms_key=None): + def __init__( + self, sagemaker_session: Session, s3_base_uri: str, hmac_key: str, s3_kms_key: str = None + ): """Construct a StoredFunction object. Args: @@ -39,10 +42,12 @@ def __init__(self, sagemaker_session, s3_base_uri, s3_kms_key=None): AWS service calls are delegated to. s3_base_uri: the base uri to which serialized artifacts will be uploaded. s3_kms_key: KMS key used to encrypt artifacts uploaded to S3. + hmac_key: Key used to encrypt serialized and deserialied function and arguments """ self.sagemaker_session = sagemaker_session self.s3_base_uri = s3_base_uri self.s3_kms_key = s3_kms_key + self.hmac_key = hmac_key def save(self, func, *args, **kwargs): """Serialize and persist the function and arguments. @@ -58,20 +63,22 @@ def save(self, func, *args, **kwargs): f"Serializing function code to {s3_path_join(self.s3_base_uri, FUNCTION_FOLDER)}" ) serialization.serialize_func_to_s3( - func, - self.sagemaker_session, - s3_path_join(self.s3_base_uri, FUNCTION_FOLDER), - self.s3_kms_key, + func=func, + sagemaker_session=self.sagemaker_session, + s3_uri=s3_path_join(self.s3_base_uri, FUNCTION_FOLDER), + s3_kms_key=self.s3_kms_key, + hmac_key=self.hmac_key, ) logger.info( f"Serializing function arguments to {s3_path_join(self.s3_base_uri, ARGUMENTS_FOLDER)}" ) serialization.serialize_obj_to_s3( - (args, kwargs), - self.sagemaker_session, - s3_path_join(self.s3_base_uri, ARGUMENTS_FOLDER), - self.s3_kms_key, + obj=(args, kwargs), + sagemaker_session=self.sagemaker_session, + s3_uri=s3_path_join(self.s3_base_uri, ARGUMENTS_FOLDER), + hmac_key=self.hmac_key, + s3_kms_key=self.s3_kms_key, ) def load_and_invoke(self) -> None: @@ -81,14 +88,18 @@ def load_and_invoke(self) -> None: f"Deserializing function code from {s3_path_join(self.s3_base_uri, FUNCTION_FOLDER)}" ) func = serialization.deserialize_func_from_s3( - self.sagemaker_session, s3_path_join(self.s3_base_uri, FUNCTION_FOLDER) + sagemaker_session=self.sagemaker_session, + s3_uri=s3_path_join(self.s3_base_uri, FUNCTION_FOLDER), + hmac_key=self.hmac_key, ) logger.info( f"Deserializing function arguments from {s3_path_join(self.s3_base_uri, ARGUMENTS_FOLDER)}" ) args, kwargs = serialization.deserialize_obj_from_s3( - self.sagemaker_session, s3_path_join(self.s3_base_uri, ARGUMENTS_FOLDER) + sagemaker_session=self.sagemaker_session, + s3_uri=s3_path_join(self.s3_base_uri, ARGUMENTS_FOLDER), + hmac_key=self.hmac_key, ) logger.info("Invoking the function") @@ -98,8 +109,9 @@ def load_and_invoke(self) -> None: f"Serializing the function return and uploading to {s3_path_join(self.s3_base_uri, RESULTS_FOLDER)}" ) serialization.serialize_obj_to_s3( - result, - self.sagemaker_session, - s3_path_join(self.s3_base_uri, RESULTS_FOLDER), - self.s3_kms_key, + obj=result, + sagemaker_session=self.sagemaker_session, + 
s3_uri=s3_path_join(self.s3_base_uri, RESULTS_FOLDER), + hmac_key=self.hmac_key, + s3_kms_key=self.s3_kms_key, ) diff --git a/src/sagemaker/remote_function/errors.py b/src/sagemaker/remote_function/errors.py index b0f1f7031c..9c91f46061 100644 --- a/src/sagemaker/remote_function/errors.py +++ b/src/sagemaker/remote_function/errors.py @@ -70,7 +70,7 @@ def _write_failure_reason_file(failure_msg): f.write(failure_msg) -def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: +def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, hmac_key) -> int: """Handle all exceptions raised during remote function execution. Args: @@ -79,6 +79,7 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: AWS service calls are delegated to. s3_base_uri (str): S3 root uri to which resulting serialized exception will be uploaded. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + hmac_key (str): Key used to calculate hmac hash of the serialized exception. Returns : exit_code (int): Exit code to terminate current job. """ @@ -93,7 +94,11 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: _write_failure_reason_file(failure_reason) serialization.serialize_exception_to_s3( - error, sagemaker_session, s3_path_join(s3_base_uri, "exception"), s3_kms_key + exc=error, + sagemaker_session=sagemaker_session, + s3_uri=s3_path_join(s3_base_uri, "exception"), + hmac_key=hmac_key, + s3_kms_key=s3_kms_key, ) return exit_code diff --git a/src/sagemaker/remote_function/invoke_function.py b/src/sagemaker/remote_function/invoke_function.py index 66c866a1b0..5963d77a42 100644 --- a/src/sagemaker/remote_function/invoke_function.py +++ b/src/sagemaker/remote_function/invoke_function.py @@ -17,6 +17,7 @@ import argparse import sys import json +import os import boto3 from sagemaker.experiments.run import Run @@ -61,11 +62,16 @@ def _load_run_object(run_in_context: str, sagemaker_session: Session) -> Run: ) -def _execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key, run_in_context): +def _execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, hmac_key): """Execute stored remote function""" from sagemaker.remote_function.core.stored_function import StoredFunction - stored_function = StoredFunction(sagemaker_session, s3_base_uri, s3_kms_key) + stored_function = StoredFunction( + sagemaker_session=sagemaker_session, + s3_base_uri=s3_base_uri, + s3_kms_key=s3_kms_key, + hmac_key=hmac_key, + ) if run_in_context: run_obj = _load_run_object(run_in_context, sagemaker_session) @@ -89,12 +95,26 @@ def main(): s3_kms_key = args.s3_kms_key run_in_context = args.run_in_context + hmac_key = os.getenv("REMOTE_FUNCTION_SECRET_KEY") + sagemaker_session = _get_sagemaker_session(region) - _execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key, run_in_context) + _execute_remote_function( + sagemaker_session=sagemaker_session, + s3_base_uri=s3_base_uri, + s3_kms_key=s3_kms_key, + run_in_context=run_in_context, + hmac_key=hmac_key, + ) except Exception as e: # pylint: disable=broad-except logger.exception("Error encountered while invoking the remote function.") - exit_code = handle_error(e, sagemaker_session, s3_base_uri, s3_kms_key) + exit_code = handle_error( + error=e, + sagemaker_session=sagemaker_session, + s3_base_uri=s3_base_uri, + s3_kms_key=s3_kms_key, + hmac_key=hmac_key, + ) finally: sys.exit(exit_code) diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py 
index 04ebfada13..a96e6f7146 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -19,6 +19,7 @@ import shutil import sys import json +import secrets from typing import Dict, List, Tuple from sagemaker.config.config_schema import ( @@ -166,6 +167,8 @@ def __init__( {"AWS_DEFAULT_REGION": self.sagemaker_session.boto_region_name} ) + self.environment_variables.update({"REMOTE_FUNCTION_SECRET_KEY": secrets.token_hex(32)}) + _image_uri = resolve_value_from_config( direct_input=image_uri, config_path=REMOTE_FUNCTION_IMAGE_URI, @@ -304,11 +307,12 @@ def _get_default_image(session): class _Job: """Helper class that interacts with the SageMaker training service.""" - def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session): + def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session, hmac_key: str): """Initialize a _Job object.""" self.job_name = job_name self.s3_uri = s3_uri self.sagemaker_session = sagemaker_session + self.hmac_key = hmac_key self._last_describe_response = None @staticmethod @@ -316,7 +320,9 @@ def from_describe_response(describe_training_job_response, sagemaker_session): """Construct a _Job from a describe_training_job_response object.""" job_name = describe_training_job_response["TrainingJobName"] s3_uri = describe_training_job_response["OutputDataConfig"]["S3OutputPath"] - job = _Job(job_name, s3_uri, sagemaker_session) + hmac_key = describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"] + + job = _Job(job_name, s3_uri, sagemaker_session, hmac_key) job._last_describe_response = describe_training_job_response return job @@ -334,6 +340,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non """ job_name = _Job._get_job_name(job_settings, func) s3_base_uri = s3_path_join(job_settings.s3_root_uri, job_name) + hmac_key = job_settings.environment_variables["REMOTE_FUNCTION_SECRET_KEY"] bootstrap_scripts_s3uri = _prepare_and_upload_runtime_scripts( s3_base_uri=s3_base_uri, @@ -355,6 +362,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non stored_function = StoredFunction( sagemaker_session=job_settings.sagemaker_session, s3_base_uri=s3_base_uri, + hmac_key=hmac_key, s3_kms_key=job_settings.s3_kms_key, ) @@ -454,13 +462,12 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non if job_settings.vpc_config: request_dict["VpcConfig"] = job_settings.vpc_config - if job_settings.environment_variables: - request_dict["Environment"] = job_settings.environment_variables + request_dict["Environment"] = job_settings.environment_variables logger.info("Creating job: %s", job_name) job_settings.sagemaker_session.sagemaker_client.create_training_job(**request_dict) - return _Job(job_name, s3_base_uri, job_settings.sagemaker_session) + return _Job(job_name, s3_base_uri, job_settings.sagemaker_session, hmac_key) def describe(self): """Describe the underlying sagemaker training job.""" diff --git a/tests/integ/sagemaker/experiments/test_run.py b/tests/integ/sagemaker/experiments/test_run.py index 40738e9360..96fc632ad7 100644 --- a/tests/integ/sagemaker/experiments/test_run.py +++ b/tests/integ/sagemaker/experiments/test_run.py @@ -642,6 +642,38 @@ def test_list(run_obj, sagemaker_session): assert run_tcs[0].experiment_config == run_obj.experiment_config +def test_list_twice(run_obj, sagemaker_session): + tc1 = _TrialComponent.create( + trial_component_name=f"non-run-tc1-{name()}", + 
sagemaker_session=sagemaker_session, + ) + tc2 = _TrialComponent.create( + trial_component_name=f"non-run-tc2-{name()}", + sagemaker_session=sagemaker_session, + tags=TAGS, + ) + run_obj._trial.add_trial_component(tc1) + run_obj._trial.add_trial_component(tc2) + + run_tcs = list_runs( + experiment_name=run_obj.experiment_name, sagemaker_session=sagemaker_session + ) + assert len(run_tcs) == 1 + assert run_tcs[0].run_name == run_obj.run_name + assert run_tcs[0].experiment_name == run_obj.experiment_name + assert run_tcs[0].experiment_config == run_obj.experiment_config + + # note the experiment name used by run_obj is already mixed case and so + # covers the mixed case experiment name double create issue + run_tcs_second_result = list_runs( + experiment_name=run_obj.experiment_name, sagemaker_session=sagemaker_session + ) + assert len(run_tcs_second_result) == 1 + assert run_tcs_second_result[0].run_name == run_obj.run_name + assert run_tcs_second_result[0].experiment_name == run_obj.experiment_name + assert run_tcs_second_result[0].experiment_config == run_obj.experiment_config + + def _generate_estimator( exp_name, sdk_tar, diff --git a/tests/integ/sagemaker/remote_function/test_decorator.py b/tests/integ/sagemaker/remote_function/test_decorator.py index 541ab1417c..5e7d0a9d91 100644 --- a/tests/integ/sagemaker/remote_function/test_decorator.py +++ b/tests/integ/sagemaker/remote_function/test_decorator.py @@ -600,6 +600,7 @@ def get_file_content(file_names): assert "line 2: bws: command not found" in str(e) +@pytest.mark.skip def test_decorator_auto_capture(sagemaker_session, auto_capture_test_container): """ This test runs a docker container. The Container invocation will execute a python script diff --git a/tests/integ/utils.py b/tests/integ/utils.py index d7891321f2..c13759b39b 100644 --- a/tests/integ/utils.py +++ b/tests/integ/utils.py @@ -19,12 +19,15 @@ from tests.conftest import NO_P3_REGIONS, NO_M4_REGIONS from sagemaker.exceptions import CapacityError +P2_INSTANCES = ["ml.p2.xlarge", "ml.p2.8xlarge", "ml.p2.16xlarge"] +P3_INSTANCES = ["ml.p3.2xlarge"] + def gpu_list(region): if region in NO_P3_REGIONS: - return ["ml.p2.xlarge"] + return P2_INSTANCES else: - return ["ml.p3.2xlarge", "ml.p2.xlarge"] + return [*P2_INSTANCES, *P3_INSTANCES] def cpu_list(region): diff --git a/tests/unit/sagemaker/experiments/helpers.py b/tests/unit/sagemaker/experiments/helpers.py index 0fec9f7fc3..d560462def 100644 --- a/tests/unit/sagemaker/experiments/helpers.py +++ b/tests/unit/sagemaker/experiments/helpers.py @@ -17,6 +17,7 @@ TEST_EXP_NAME = "my-experiment" +TEST_EXP_NAME_MIXED_CASE = "My-eXpeRiMeNt" TEST_RUN_NAME = "my-run" TEST_EXP_DISPLAY_NAME = "my-experiment-display-name" TEST_RUN_DISPLAY_NAME = "my-run-display-name" diff --git a/tests/unit/sagemaker/experiments/test_run.py b/tests/unit/sagemaker/experiments/test_run.py index 7f54fe8d6f..3820d7e4f6 100644 --- a/tests/unit/sagemaker/experiments/test_run.py +++ b/tests/unit/sagemaker/experiments/test_run.py @@ -48,6 +48,7 @@ mock_trial_load_or_create_func, mock_tc_load_or_create_func, TEST_EXP_NAME, + TEST_EXP_NAME_MIXED_CASE, TEST_RUN_NAME, TEST_EXP_DISPLAY_NAME, TEST_RUN_DISPLAY_NAME, @@ -779,7 +780,9 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses ] mock_tc_list.return_value = [ TrialComponentSummary( - trial_component_name=Run._generate_trial_component_name("A" + str(i), TEST_EXP_NAME), + trial_component_name=Run._generate_trial_component_name( + "A" + str(i), TEST_EXP_NAME_MIXED_CASE + ), trial_component_arn="b" +
str(i), display_name="C" + str(i), source_arn="D" + str(i), @@ -798,7 +801,7 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses ( _TrialComponent( trial_component_name=Run._generate_trial_component_name( - "a" + str(i), TEST_EXP_NAME + "a" + str(i), TEST_EXP_NAME_MIXED_CASE ), trial_component_arn="b" + str(i), display_name="C" + str(i), @@ -818,14 +821,14 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses ] run_list = list_runs( - experiment_name=TEST_EXP_NAME, + experiment_name=TEST_EXP_NAME_MIXED_CASE, sort_by=SortByType.CREATION_TIME, sort_order=SortOrderType.ASCENDING, sagemaker_session=sagemaker_session, ) mock_tc_list.assert_called_once_with( - experiment_name=TEST_EXP_NAME, + experiment_name=TEST_EXP_NAME_MIXED_CASE, created_before=None, created_after=None, sort_by="CreationTime", diff --git a/tests/unit/sagemaker/remote_function/core/test_serialization.py b/tests/unit/sagemaker/remote_function/core/test_serialization.py index eb06cf5cc4..28f5b215e8 100644 --- a/tests/unit/sagemaker/remote_function/core/test_serialization.py +++ b/tests/unit/sagemaker/remote_function/core/test_serialization.py @@ -31,6 +31,7 @@ from tblib import pickling_support KMS_KEY = "kms-key" +HMAC_KEY = "some-hmac-key" mock_s3 = {} @@ -64,11 +65,15 @@ def square(x): return x * x s3_uri = random_s3_uri() - serialize_func_to_s3(func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_func_to_s3( + func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) del square - deserialized = deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + deserialized = deserialize_func_from_s3( + sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + ) assert deserialized(3) == 9 @@ -79,10 +84,16 @@ def test_serialize_deserialize_lambda(): s3_uri = random_s3_uri() serialize_func_to_s3( - func=lambda x: x * x, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY + func=lambda x: x * x, + sagemaker_session=Mock(), + s3_uri=s3_uri, + s3_kms_key=KMS_KEY, + hmac_key=HMAC_KEY, ) - deserialized = deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + deserialized = deserialize_func_from_s3( + sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + ) assert deserialized(3) == 9 @@ -107,7 +118,11 @@ def train(x): match="or instantiate a new Run in the function.", ): serialize_func_to_s3( - func=train, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY + func=train, + sagemaker_session=Mock(), + s3_uri=s3_uri, + s3_kms_key=KMS_KEY, + hmac_key=HMAC_KEY, ) @@ -127,7 +142,11 @@ def square(x): match=r"Error when serializing object of type \[function\]: RuntimeError\('some failure when dumps'\)", ): serialize_func_to_s3( - func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY + func=square, + sagemaker_session=Mock(), + s3_uri=s3_uri, + s3_kms_key=KMS_KEY, + hmac_key=HMAC_KEY, ) @@ -142,7 +161,9 @@ def square(x): s3_uri = random_s3_uri() - serialize_func_to_s3(func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_func_to_s3( + func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) del square @@ -151,7 +172,7 @@ def square(x): match=rf"Error when deserializing bytes downloaded from {s3_uri}/payload.pkl: " + r"RuntimeError\('some failure when loads'\)", ): - deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + 
deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY) @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) @@ -162,13 +183,34 @@ def square(x): s3_uri = random_s3_uri() - serialize_func_to_s3(func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_func_to_s3( + func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) mock_s3[os.path.join(s3_uri, "metadata.json")] = b"not json serializable" del square with pytest.raises(DeserializationError, match=r"Corrupt metadata file."): - deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY) + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read) +def test_deserialize_integrity_check_failed(): + def square(x): + return x * x + + s3_uri = random_s3_uri() + serialize_func_to_s3( + func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) + + del square + + with pytest.raises( + DeserializationError, match=r"Integrity check for the serialized function or data failed." + ): + deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key="invalid_key") @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) @@ -181,12 +223,16 @@ def __init__(self, x): my_data = MyData(10) s3_uri = random_s3_uri() - serialize_obj_to_s3(my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_obj_to_s3( + my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) del my_data del MyData - deserialized = deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + deserialized = deserialize_obj_from_s3( + sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + ) assert deserialized.x == 10 @@ -198,11 +244,15 @@ def test_serialize_deserialize_data_built_in_types(): my_data = {"a": [10]} s3_uri = random_s3_uri() - serialize_obj_to_s3(my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_obj_to_s3( + my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) del my_data - deserialized = deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + deserialized = deserialize_obj_from_s3( + sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + ) assert deserialized == {"a": [10]} @@ -212,9 +262,13 @@ def test_serialize_deserialize_data_built_in_types(): def test_serialize_deserialize_none(): s3_uri = random_s3_uri() - serialize_obj_to_s3(None, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_obj_to_s3( + None, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) - deserialized = deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + deserialized = deserialize_obj_from_s3( + sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + ) assert deserialized is None @@ -234,7 +288,11 @@ def test_serialize_run(*args, **kwargs): match="or instantiate a new Run in the function.", ): serialize_obj_to_s3( - obj=run, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY + obj=run, + sagemaker_session=Mock(), + s3_uri=s3_uri, + s3_kms_key=KMS_KEY, + hmac_key=HMAC_KEY, ) @@ -256,7 +314,11 @@ def __init__(self, x): match=r"Error when serializing object of type \[MyData\]: RuntimeError\('some failure when dumps'\)", ): serialize_obj_to_s3( - obj=my_data, 
sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY + obj=my_data, + sagemaker_session=Mock(), + s3_uri=s3_uri, + s3_kms_key=KMS_KEY, + hmac_key=HMAC_KEY, ) @@ -273,7 +335,9 @@ def __init__(self, x): my_data = MyData(10) s3_uri = random_s3_uri() - serialize_obj_to_s3(obj=my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_obj_to_s3( + obj=my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) del my_data del MyData @@ -283,7 +347,7 @@ def __init__(self, x): match=rf"Error when deserializing bytes downloaded from {s3_uri}/payload.pkl: " + r"RuntimeError\('some failure when loads'\)", ): - deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY) @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload_error) @@ -295,11 +359,15 @@ def test_serialize_deserialize_service_error(): s3_uri = random_s3_uri() with pytest.raises( ServiceError, - match=rf"Failed to upload serialized bytes to {s3_uri}/metadata.json: " + match=rf"Failed to upload serialized bytes to {s3_uri}/payload.pkl: " + r"RuntimeError\('some failure when upload_bytes'\)", ): serialize_func_to_s3( - func=my_func, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY + func=my_func, + sagemaker_session=Mock(), + s3_uri=s3_uri, + s3_kms_key=KMS_KEY, + hmac_key=HMAC_KEY, ) del my_func @@ -309,7 +377,7 @@ def test_serialize_deserialize_service_error(): match=rf"Failed to read serialized bytes from {s3_uri}/metadata.json: " + r"RuntimeError\('some failure when read_bytes'\)", ): - deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY) @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) @@ -333,10 +401,12 @@ def func_b(): func_b() except Exception as e: pickling_support.install() - serialize_obj_to_s3(e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_obj_to_s3( + e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) with pytest.raises(CustomError, match="Some error") as exc_info: - raise deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + raise deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY) assert type(exc_info.value.__cause__) is TypeError @@ -360,10 +430,14 @@ def func_b(): try: func_b() except Exception as e: - serialize_exception_to_s3(e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_exception_to_s3( + e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) with pytest.raises(CustomError, match="Some error") as exc_info: - raise deserialize_exception_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + raise deserialize_exception_from_s3( + sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + ) assert type(exc_info.value.__cause__) is TypeError @@ -387,8 +461,12 @@ def func_b(): try: func_b() except Exception as e: - serialize_exception_to_s3(e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_exception_to_s3( + e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) with pytest.raises(ServiceError, match="Some error") as exc_info: - raise deserialize_exception_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + raise deserialize_exception_from_s3( + sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + ) assert 
type(exc_info.value.__cause__) is TypeError diff --git a/tests/unit/sagemaker/remote_function/core/test_stored_function.py b/tests/unit/sagemaker/remote_function/core/test_stored_function.py index 0b4008ef41..9833994c98 100644 --- a/tests/unit/sagemaker/remote_function/core/test_stored_function.py +++ b/tests/unit/sagemaker/remote_function/core/test_stored_function.py @@ -34,6 +34,7 @@ ) KMS_KEY = "kms-key" +HMAC_KEY = "some-hmac-key" mock_s3 = {} @@ -75,14 +76,14 @@ def test_save_and_load(s3_source_dir_download, s3_source_dir_upload, args, kwarg s3_base_uri = random_s3_uri() stored_function = StoredFunction( - sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY + sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY ) stored_function.save(quadratic, *args, **kwargs) stored_function.load_and_invoke() - assert deserialize_obj_from_s3(session, s3_uri=f"{s3_base_uri}/results") == quadratic( - *args, **kwargs - ) + assert deserialize_obj_from_s3( + session, s3_uri=f"{s3_base_uri}/results", hmac_key=HMAC_KEY + ) == quadratic(*args, **kwargs) @patch( @@ -117,7 +118,7 @@ def test_save_with_parameter_of_run_type( sagemaker_session=session, ) stored_function = StoredFunction( - sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY + sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY ) with pytest.raises(SerializationError) as e: stored_function.save(log_bigger, 1, 2, run) diff --git a/tests/unit/sagemaker/remote_function/test_client.py b/tests/unit/sagemaker/remote_function/test_client.py index fede42dab1..9c95fd96b5 100644 --- a/tests/unit/sagemaker/remote_function/test_client.py +++ b/tests/unit/sagemaker/remote_function/test_client.py @@ -48,6 +48,7 @@ S3_URI = f"s3://{BUCKET}/keyprefix" EXPECTED_JOB_RESULT = [1, 2, 3] PATH_TO_SRC_DIR = "path/to/src/dir" +HMAC_KEY = "some-hmac-key" def describe_training_job_response(job_status): @@ -61,6 +62,7 @@ def describe_training_job_response(job_status): "VolumeSizeInGB": 30, }, "OutputDataConfig": {"S3OutputPath": "s3://sagemaker-123/image_uri/output"}, + "Environment": {"REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, } @@ -518,11 +520,6 @@ def test_executor_submit_happy_case(mock_start, mock_job_settings, parallelism): future_3 = e.submit(job_function, 9, 10, c=11, d=12) future_4 = e.submit(job_function, 13, 14, c=15, d=16) - future_1.wait() - future_2.wait() - future_3.wait() - future_4.wait() - mock_start.assert_has_calls( [ call(ANY, job_function, (1, 2), {"c": 3, "d": 4}, None), @@ -531,6 +528,10 @@ def test_executor_submit_happy_case(mock_start, mock_job_settings, parallelism): call(ANY, job_function, (13, 14), {"c": 15, "d": 16}, None), ] ) + mock_job_1.describe.assert_called() + mock_job_2.describe.assert_called() + mock_job_3.describe.assert_called() + mock_job_4.describe.assert_called() assert future_1.done() assert future_2.done() @@ -555,15 +556,14 @@ def test_executor_submit_with_run(mock_start, mock_job_settings, run_obj): future_1 = e.submit(job_function, 1, 2, c=3, d=4) future_2 = e.submit(job_function, 5, 6, c=7, d=8) - future_1.wait() - future_2.wait() - mock_start.assert_has_calls( [ call(ANY, job_function, (1, 2), {"c": 3, "d": 4}, run_info), call(ANY, job_function, (5, 6), {"c": 7, "d": 8}, run_info), ] ) + mock_job_1.describe.assert_called() + mock_job_2.describe.assert_called() assert future_1.done() assert future_2.done() @@ -573,15 +573,14 @@ def test_executor_submit_with_run(mock_start, mock_job_settings, run_obj): future_3 = 
e.submit(job_function, 9, 10, c=11, d=12) future_4 = e.submit(job_function, 13, 14, c=15, d=16) - future_3.wait() - future_4.wait() - mock_start.assert_has_calls( [ call(ANY, job_function, (9, 10), {"c": 11, "d": 12}, run_info), call(ANY, job_function, (13, 14), {"c": 15, "d": 16}, run_info), ] ) + mock_job_3.describe.assert_called() + mock_job_4.describe.assert_called() assert future_3.done() assert future_4.done() @@ -633,7 +632,7 @@ def test_executor_fails_to_start_job(mock_start, *args): with pytest.raises(TypeError): future_1.result() - future_2.wait() + print(future_2._state) assert future_2.done() @@ -698,8 +697,6 @@ def test_executor_describe_job_throttled_temporarily(mock_start, *args): # submit second job future_2 = e.submit(job_function, 5, 6, c=7, d=8) - future_1.wait() - future_2.wait() assert future_1.done() assert future_2.done() @@ -719,9 +716,9 @@ def test_executor_describe_job_failed_permanently(mock_start, *args): future_2 = e.submit(job_function, 5, 6, c=7, d=8) with pytest.raises(RuntimeError): - future_1.result() + future_1.done() with pytest.raises(RuntimeError): - future_2.result() + future_2.done() @pytest.mark.parametrize( @@ -892,7 +889,7 @@ def test_future_get_result_from_completed_job(mock_start, mock_deserialize): def test_future_get_result_from_failed_job_remote_error_client_function( mock_start, mock_deserialize ): - mock_job = Mock(job_name=TRAINING_JOB_NAME, s3_uri=S3_URI) + mock_job = Mock(job_name=TRAINING_JOB_NAME, s3_uri=S3_URI, hmac_key=HMAC_KEY) mock_start.return_value = mock_job mock_job.describe.return_value = FAILED_TRAINING_JOB @@ -907,7 +904,9 @@ def test_future_get_result_from_failed_job_remote_error_client_function( assert future.done() mock_job.wait.assert_called_once() - mock_deserialize.assert_called_with(sagemaker_session=ANY, s3_uri=f"{S3_URI}/exception") + mock_deserialize.assert_called_with( + sagemaker_session=ANY, s3_uri=f"{S3_URI}/exception", hmac_key=HMAC_KEY + ) @patch("sagemaker.s3.S3Downloader.read_bytes") @@ -1235,7 +1234,9 @@ def test_get_future_completed_job_deserialization_error(mock_session, mock_deser future.result() mock_deserialize.assert_called_with( - sagemaker_session=ANY, s3_uri="s3://sagemaker-123/image_uri/output/results" + sagemaker_session=ANY, + s3_uri="s3://sagemaker-123/image_uri/output/results", + hmac_key=HMAC_KEY, ) diff --git a/tests/unit/sagemaker/remote_function/test_errors.py b/tests/unit/sagemaker/remote_function/test_errors.py index 78b864e784..399b1aed2e 100644 --- a/tests/unit/sagemaker/remote_function/test_errors.py +++ b/tests/unit/sagemaker/remote_function/test_errors.py @@ -20,6 +20,7 @@ TEST_S3_BASE_URI = "s3://my-bucket/" TEST_S3_KMS_KEY = "my-kms-key" +TEST_HMAC_KEY = "some-hmac-key" class _InvalidErrorNumberException(Exception): @@ -70,12 +71,22 @@ def test_handle_error( error_string, ): err = error - exit_code = handle_error(err, sagemaker_session, TEST_S3_BASE_URI, TEST_S3_KMS_KEY) + exit_code = handle_error( + error=err, + sagemaker_session=sagemaker_session, + s3_base_uri=TEST_S3_BASE_URI, + s3_kms_key=TEST_S3_KMS_KEY, + hmac_key=TEST_HMAC_KEY, + ) assert exit_code == expected_exit_code exists.assert_called_once_with("/opt/ml/output/failure") mock_open_file.assert_called_with("/opt/ml/output/failure", "w") mock_open_file.return_value.__enter__().write.assert_called_with(error_string) serialize_exception_to_s3.assert_called_with( - err, sagemaker_session, TEST_S3_BASE_URI + "exception", TEST_S3_KMS_KEY + exc=err, + sagemaker_session=sagemaker_session, + s3_uri=TEST_S3_BASE_URI + 
"exception", + hmac_key=TEST_HMAC_KEY, + s3_kms_key=TEST_S3_KMS_KEY, ) diff --git a/tests/unit/sagemaker/remote_function/test_invoke_function.py b/tests/unit/sagemaker/remote_function/test_invoke_function.py index 661e2138e3..a8f658234f 100644 --- a/tests/unit/sagemaker/remote_function/test_invoke_function.py +++ b/tests/unit/sagemaker/remote_function/test_invoke_function.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import os from mock import patch, Mock from sagemaker.remote_function import invoke_function from sagemaker.remote_function.errors import SerializationError @@ -20,6 +21,7 @@ TEST_S3_BASE_URI = "s3://my-bucket/" TEST_S3_KMS_KEY = "my-kms-key" TEST_RUN_IN_CONTEXT = '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}' +TEST_HMAC_KEY = "some-hmac-key" def mock_args(): @@ -55,6 +57,7 @@ def mock_session(): return_value=mock_session(), ) def test_main_success(_get_sagemaker_session, load_and_invoke, _exit_process, _load_run_object): + os.environ["REMOTE_FUNCTION_SECRET_KEY"] = TEST_HMAC_KEY invoke_function.main() _get_sagemaker_session.assert_called_with(TEST_REGION) @@ -74,6 +77,7 @@ def test_main_success(_get_sagemaker_session, load_and_invoke, _exit_process, _l def test_main_success_with_run( _get_sagemaker_session, load_and_invoke, _exit_process, _load_run_object ): + os.environ["REMOTE_FUNCTION_SECRET_KEY"] = TEST_HMAC_KEY invoke_function.main() _get_sagemaker_session.assert_called_with(TEST_REGION) @@ -94,6 +98,7 @@ def test_main_success_with_run( def test_main_failure( _get_sagemaker_session, load_and_invoke, _exit_process, handle_error, _load_run_object ): + os.environ["REMOTE_FUNCTION_SECRET_KEY"] = TEST_HMAC_KEY ser_err = SerializationError("some failure reason") load_and_invoke.side_effect = ser_err handle_error.return_value = 1 @@ -104,6 +109,10 @@ def test_main_failure( load_and_invoke.assert_called() _load_run_object.assert_not_called() handle_error.assert_called_with( - ser_err, _get_sagemaker_session(), TEST_S3_BASE_URI, TEST_S3_KMS_KEY + error=ser_err, + sagemaker_session=_get_sagemaker_session(), + s3_base_uri=TEST_S3_BASE_URI, + s3_kms_key=TEST_S3_KMS_KEY, + hmac_key=TEST_HMAC_KEY, ) _exit_process.assert_called_with(1) diff --git a/tests/unit/sagemaker/remote_function/test_job.py b/tests/unit/sagemaker/remote_function/test_job.py index fb019875ad..686862bcc7 100644 --- a/tests/unit/sagemaker/remote_function/test_job.py +++ b/tests/unit/sagemaker/remote_function/test_job.py @@ -41,6 +41,7 @@ TEST_REGION = "us-west-2" RUNTIME_SCRIPTS_CHANNEL_NAME = "sagemaker_remote_function_bootstrap" REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws" +HMAC_KEY = "some-hmac-key" EXPECTED_FUNCTION_URI = S3_URI + "/function.pkl" EXPECTED_OUTPUT_URI = S3_URI + "/output" @@ -111,22 +112,29 @@ def job_function(a, b=1, *, c, d=3): return a * b * c * d +@patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job.Session", return_value=mock_session()) @patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN) -def test_sagemaker_config_job_settings(get_execution_role, session): +def test_sagemaker_config_job_settings(get_execution_role, session, secret_token): job_settings = _JobSettings(image_uri="image_uri", instance_type="ml.m5.xlarge") assert job_settings.image_uri == "image_uri" assert job_settings.s3_root_uri == f"s3://{BUCKET}" assert job_settings.role == DEFAULT_ROLE_ARN - assert job_settings.environment_variables == 
{"AWS_DEFAULT_REGION": "us-west-2"} + assert job_settings.environment_variables == { + "AWS_DEFAULT_REGION": "us-west-2", + "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY, + } assert job_settings.include_local_workdir is False assert job_settings.instance_type == "ml.m5.xlarge" +@patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job.Session", return_value=mock_session()) @patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN) -def test_sagemaker_config_job_settings_with_configuration_file(get_execution_role, session): +def test_sagemaker_config_job_settings_with_configuration_file( + get_execution_role, session, secret_token +): config_tags = [ {"Key": "someTagKey", "Value": "someTagValue"}, {"Key": "someTagKey2", "Value": "someTagValue2"}, @@ -146,6 +154,7 @@ def test_sagemaker_config_job_settings_with_configuration_file(get_execution_rol assert job_settings.environment_variables == { "AWS_DEFAULT_REGION": "us-west-2", "EnvVarKey": "EnvVarValue", + "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY, } assert job_settings.job_conda_env == "my_conda_env" assert job_settings.include_local_workdir is True @@ -227,6 +236,7 @@ def test_sagemaker_config_job_settings_studio_image_uri(get_execution_role, sess @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) +@patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job._prepare_and_upload_dependencies", return_value="some_s3_uri") @patch( "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" @@ -235,7 +245,12 @@ def test_sagemaker_config_job_settings_studio_image_uri(get_execution_role, sess @patch("sagemaker.remote_function.job.StoredFunction") @patch("sagemaker.remote_function.job.Session", return_value=mock_session()) def test_start( - session, mock_stored_function, mock_runtime_manager, mock_script_upload, mock_dependency_upload + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, ): job_settings = _JobSettings( @@ -252,7 +267,10 @@ def test_start( assert job.job_name.startswith("job-function") assert mock_stored_function.called_once_with( - sagemaker_session=session(), s3_base_uri=f"{S3_URI}/{job.job_name}", s3_kms_key=None + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, ) local_dependencies_path = mock_runtime_manager().snapshot() @@ -326,10 +344,11 @@ def test_start( ), EnableNetworkIsolation=False, EnableInterContainerTrafficEncryption=True, - Environment={"AWS_DEFAULT_REGION": "us-west-2"}, + Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, ) +@patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job._prepare_and_upload_dependencies", return_value="some_s3_uri") @patch( "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" @@ -338,7 +357,12 @@ def test_start( @patch("sagemaker.remote_function.job.StoredFunction") @patch("sagemaker.remote_function.job.Session", return_value=mock_session()) def test_start_with_complete_job_settings( - session, mock_stored_function, mock_runtime_manager, mock_script_upload, mock_dependency_upload + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, ): job_settings = _JobSettings( @@ -363,7 +387,10 @@ def 
test_start_with_complete_job_settings( assert job.job_name.startswith("job-function") assert mock_stored_function.called_once_with( - sagemaker_session=session(), s3_base_uri=f"{S3_URI}/{job.job_name}", s3_kms_key=None + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, ) local_dependencies_path = mock_runtime_manager().snapshot() @@ -441,7 +468,7 @@ def test_start_with_complete_job_settings( EnableNetworkIsolation=False, EnableInterContainerTrafficEncryption=False, VpcConfig=dict(Subnets=["subnet"], SecurityGroupIds=["sg"]), - Environment={"AWS_DEFAULT_REGION": "us-west-2"}, + Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, ) diff --git a/tests/unit/test_djl_inference.py b/tests/unit/test_djl_inference.py index 93a1fba336..06adea8e76 100644 --- a/tests/unit/test_djl_inference.py +++ b/tests/unit/test_djl_inference.py @@ -454,7 +454,7 @@ def test_generate_serving_properties_with_valid_configurations( "option.entryPoint": "djl_python.huggingface", "option.s3url": VALID_UNCOMPRESSED_MODEL_DATA, "option.tensor_parallel_degree": 1, - "option.dtype": "auto", + "option.dtype": "fp32", "option.device_id": 4, "option.device_map": "balanced", } diff --git a/tox.ini b/tox.ini index 7ed9401b01..8a276b19b5 100644 --- a/tox.ini +++ b/tox.ini @@ -73,7 +73,7 @@ passenv = # Can be used to specify which tests to run, e.g.: tox -- -s commands = python -c "import os; os.system('install-custom-pkgs --install-boto-wheels')" - pip install 'apache-airflow==2.5.1' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.5.1/constraints-3.7.txt" + pip install 'apache-airflow==2.6.0' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.6.0/constraints-3.7.txt" pytest --cov=sagemaker --cov-append {posargs} {env:IGNORE_COVERAGE:} coverage report -i --fail-under=86
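
Note on the integrity-check mechanism this patch threads through serialization.py, stored_function.py, job.py, and invoke_function.py: the cloudpickled bytes are uploaded as payload.pkl, the hex HMAC-SHA256 of those bytes under a per-job secret is recorded in metadata.json, and every deserialization path recomputes the hash over the downloaded bytes before unpickling. Below is a minimal standalone sketch of that scheme, assuming only the Python standard library; the variable names are illustrative and not part of the patch.

import hashlib
import hmac
import secrets


def compute_hash(buffer: bytes, secret_key: str) -> str:
    # Same construction as the patch's _compute_hash helper: hex-encoded
    # HMAC-SHA256 of the payload bytes under a shared secret key.
    return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()


# job.py generates the per-job key with secrets.token_hex(32) and shares it
# with the training container through the REMOTE_FUNCTION_SECRET_KEY
# environment variable, so serializer and deserializer hold the same secret.
hmac_key = secrets.token_hex(32)
payload = b"cloudpickled function or arguments"  # stands in for payload.pkl

stored_hash = compute_hash(payload, hmac_key)  # recorded in metadata.json

# On download, _perform_integrity_check recomputes the hash over the fetched
# bytes and compares in constant time; a mismatch raises DeserializationError.
assert hmac.compare_digest(stored_hash, compute_hash(payload, hmac_key))
assert not hmac.compare_digest(stored_hash, compute_hash(b"tampered bytes", hmac_key))

Two design points worth noting: hmac.compare_digest avoids leaking hash prefixes through comparison timing, and because the expected hash is keyed, an attacker who can overwrite payload.pkl in the S3 bucket cannot forge a matching metadata.json without the per-job secret held in the training job environment.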