diff --git a/.gitignore b/.gitignore index 9829ed9781..cae8f890ea 100644 --- a/.gitignore +++ b/.gitignore @@ -30,5 +30,6 @@ env/ .vscode/ **/tmp .python-version -**/_repack_model.py -**/_repack_script_launcher.sh \ No newline at end of file +**/_repack_script_launcher.sh +tests/data/**/_repack_model.py +tests/data/experiment/sagemaker-dev-1.0.tar.gz diff --git a/CHANGELOG.md b/CHANGELOG.md index 95e4a7b9cf..37b3440f69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,119 @@ # Changelog +## v2.125.0 (2022-12-19) + +### Features + + * add RandomSeed to support reproducible HPO + +### Bug Fixes and Other Changes + + * Correct SageMaker Clarify API docstrings by changing JSONPath to JMESPath + +## v2.124.0 (2022-12-16) + +### Features + + * Doc update for TableFormatEnum + * Add p4de to smddp supported instance types + * Add disable_profiler field in config and propagate changes + * Added doc update for dataset builder + +### Bug Fixes and Other Changes + + * Use Async Inference Config when available for endpoint update + +### Documentation Changes + + * smdistributed libraries release notes + +## v2.123.0 (2022-12-15) + +### Features + + * Add support for TF2.9.2 training images + * Add SageMaker Experiment + +## v2.122.0 (2022-12-14) + +### Features + + * Feature Store dataset builder, delete_record, get_record, list_feature_group + * Add OSU region to frameworks for DLC + +### Bug Fixes and Other Changes + + * the Hyperband support fix for the HPO + * unpin packaging version + * Remove content type image/jpg from analysis configuration schema + +## v2.121.2 (2022-12-12) + +### Bug Fixes and Other Changes + + * Update for Tensorflow Serving 2.11 inference DLCs + * Revert "fix: type hint of PySparkProcessor __init__" + * Skip Bad Transform Test + +## v2.121.1 (2022-12-09) + +### Bug Fixes and Other Changes + + * Pop out ModelPackageName from pipeline definition + * Fix failing jumpstart cache unit tests + +## v2.121.0 (2022-12-08) + +### Features + + * Algorithms Region Expansion OSU/DXB + +### Bug Fixes and Other Changes + + * FrameworkProcessor S3 uploads + * Add constraints file for apache-airflow + +## v2.120.0 (2022-12-07) + +### Features + + * Add Neo image uri config for Pytorch 1.12 + * Adding support for SageMaker Training Compiler in PyTorch estimator starting 1.12 + * Update registries with new region account number mappings. + * Add DXB region to frameworks by DLC + +### Bug Fixes and Other Changes + + * support idempotency for framework and spark processors + +## v2.119.0 (2022-12-03) + +### Features + + * Add Code Owners file + * Added transform with monitoring pipeline step in transformer + * Update TF 2.9 and TF 2.10 inference DLCs + * make estimator accept json file as modelparallel config + * SageMaker Training Compiler does not support p4de instances + * Add support for SparkML v3.3 + +### Bug Fixes and Other Changes + + * Fix bug forcing uploaded tar to be named sourcedir + * Update local_requirements.txt PyYAML version + * refactoring : using with statement + * Allow Py 3.7 for MMS Test Docker env + * fix PySparkProcessor __init__ params type + * type hint of PySparkProcessor __init__ + * Return ARM XGB/SKLearn tags if `image_scope` is `inference_graviton` + * Update scipy to 1.7.3 to support M1 development envs + * Fixing type hints for Spark processor that has instance type/count params in reverse order + * Add DeepAR ap-northeast-3 repository. 
+ * Fix AsyncInferenceConfig documentation typo + * fix ml_inf to ml_inf1 in Neo multi-version support + * Fix type annotations + * add neo mvp region accounts + ## v2.118.0 (2022-12-01) ### Features diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000000..7f7ac28644 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @aws/sagemaker-ml-frameworks diff --git a/VERSION b/VERSION index 34d47b7f52..1e80f372b6 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.118.1.dev0 +2.125.1.dev0 diff --git a/doc/api/prep_data/feature_store.rst b/doc/api/prep_data/feature_store.rst index 1980a0b069..838558c0a4 100644 --- a/doc/api/prep_data/feature_store.rst +++ b/doc/api/prep_data/feature_store.rst @@ -72,3 +72,15 @@ Inputs .. autoclass:: sagemaker.feature_store.inputs.FeatureValue :members: :show-inheritance: + +.. autoclass:: sagemaker.feature_store.inputs.TableFormatEnum + :members: + :show-inheritance: + + +Dataset Builder +*************** + +.. autoclass:: sagemaker.feature_store.dataset_builder.DatasetBuilder + :members: + :show-inheritance: diff --git a/doc/api/training/sdp_versions/latest.rst b/doc/api/training/sdp_versions/latest.rst index c3fcc5f78e..461f58998f 100644 --- a/doc/api/training/sdp_versions/latest.rst +++ b/doc/api/training/sdp_versions/latest.rst @@ -26,8 +26,8 @@ depending on the version of the library you use. `_ for more information. -Version 1.4.0, 1.4.1, 1.5.0 (Latest) -==================================== +Version 1.4.0, 1.4.1, 1.5.0, 1.6.0 (Latest) +=========================================== .. toctree:: :maxdepth: 1 diff --git a/doc/api/training/smd_data_parallel_release_notes/smd_data_parallel_change_log.rst b/doc/api/training/smd_data_parallel_release_notes/smd_data_parallel_change_log.rst index 05eb7220e0..8ff7fabf1c 100644 --- a/doc/api/training/smd_data_parallel_release_notes/smd_data_parallel_change_log.rst +++ b/doc/api/training/smd_data_parallel_release_notes/smd_data_parallel_change_log.rst @@ -7,9 +7,51 @@ Release Notes New features, bug fixes, and improvements are regularly made to the SageMaker distributed data parallel library. -SageMaker Distributed Data Parallel 1.5.0 Release Notes +SageMaker Distributed Data Parallel 1.6.0 Release Notes ======================================================= +*Date: Dec. 15. 2022* + +**New Features** + +* New optimized SMDDP AllGather collective to complement the sharded data parallelism technique + in the SageMaker model parallelism library. For more information, see `Sharded data parallelism with SMDDP Collectives + `_ + in the *Amazon SageMaker Developer Guide*. +* Added support for Amazon EC2 ``ml.p4de.24xlarge`` instances. You can run data parallel training jobs + on ``ml.p4de.24xlarge`` instances with the SageMaker data parallelism library’s AllReduce collective. + +**Improvements** + +* General performance improvements of the SMDDP AllReduce collective communication operation. + +**Migration to AWS Deep Learning Containers** + +This version passed benchmark testing and is migrated to the following AWS Deep Learning Containers (DLC): + +- SageMaker training container for PyTorch v1.12.1 + + .. code:: + + 763104351884.dkr.ecr..amazonaws.com/pytorch-training:1.12.1-gpu-py38-cu113-ubuntu20.04-sagemaker + + +Binary file of this version of the library for `custom container +`_ users: + + .. 
code:: + + https://smdataparallel.s3.amazonaws.com/binary/pytorch/1.12.1/cu113/2022-12-05/smdistributed_dataparallel-1.6.0-cp38-cp38-linux_x86_64.whl + + +---- + +Release History +=============== + +SageMaker Distributed Data Parallel 1.5.0 Release Notes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + *Date: Jul. 26. 2022* **Currency Updates** @@ -38,12 +80,6 @@ Binary file of this version of the library for `custom container https://smdataparallel.s3.amazonaws.com/binary/pytorch/1.12.0/cu113/2022-07-01/smdistributed_dataparallel-1.5.0-cp38-cp38-linux_x86_64.whl - ----- - -Release History -=============== - SageMaker Distributed Data Parallel 1.4.1 Release Notes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst b/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst index 6f89fa45a5..92ccc8c407 100644 --- a/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst +++ b/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst @@ -6,9 +6,60 @@ New features, bug fixes, and improvements are regularly made to the SageMaker distributed model parallel library. -SageMaker Distributed Model Parallel 1.11.0 Release Notes +SageMaker Distributed Model Parallel 1.13.0 Release Notes ========================================================= +*Date: Dec. 15. 2022* + +**New Features** + +* Sharded data parallelism now supports a new backend for collectives called *SMDDP Collectives*. + For supported scenarios, SMDDP Collectives are on by default for the AllGather operation. + For more information, see + `Sharded data parallelism with SMDDP Collectives + `_ + in the *Amazon SageMaker Developer Guide*. +* Introduced FlashAttention for DistributedTransformer to improve memory usage and computational + performance of models such as GPT2, GPTNeo, GPTJ, GPTNeoX, BERT, and RoBERTa. + +**Bug Fixes** + +* Fixed initialization of ``lm_head`` in DistributedTransformer to use a provided range + for initialization, when weights are not tied with the embeddings. + +**Improvements** + +* When a module has no parameters, we have introduced an optimization to execute + such a module on the same rank as its parent during pipeline parallelism. + +**Migration to AWS Deep Learning Containers** + +This version passed benchmark testing and is migrated to the following AWS Deep Learning Containers (DLC): + +- SageMaker training container for PyTorch v1.12.1 + + .. code:: + + 763104351884.dkr.ecr..amazonaws.com/pytorch-training:1.12.1-gpu-py38-cu113-ubuntu20.04-sagemaker + + +Binary file of this version of the library for `custom container +`_ users: + +- For PyTorch 1.12.0 + + .. code:: + + https://sagemaker-distributed-model-parallel.s3.us-west-2.amazonaws.com/pytorch-1.12.1/build-artifacts/2022-12-08-21-34/smdistributed_modelparallel-1.13.0-cp38-cp38-linux_x86_64.whl + +---- + +Release History +=============== + +SageMaker Distributed Model Parallel 1.11.0 Release Notes +--------------------------------------------------------- + *Date: August. 17. 2022* **New Features** @@ -41,12 +92,7 @@ Binary file of this version of the library for `custom container .. 
code:: - https://sagemaker-distributed-model-parallel.s3.us-west-2.amazonaws.com/pytorch-1.12.0/build-artifacts/2022-08-12-16-58/smdistributed_modelparallel-1.11.0-cp38-cp38-linux_x86_64.whl - ----- - -Release History -=============== + https://sagemaker-distributed-model-parallel.s3.us-west-2.amazonaws.com/pytorch-1.12.0/build-artifacts/2022-08-12-16-58/smdistributed_modelparallel-1.11.0-cp38-cp38-linux_x86_64.whl SageMaker Distributed Model Parallel 1.10.1 Release Notes --------------------------------------------------------- diff --git a/doc/api/training/smp_versions/latest.rst b/doc/api/training/smp_versions/latest.rst index 1a2032c9ed..1eb358b2a3 100644 --- a/doc/api/training/smp_versions/latest.rst +++ b/doc/api/training/smp_versions/latest.rst @@ -10,8 +10,8 @@ depending on which version of the library you need to use. To use the library, reference the **Common API** documentation alongside the framework specific API documentation. -Version 1.11.0 (Latest) -=========================================== +Version 1.11.0, 1.13.0 (Latest) +=============================== To use the library, reference the Common API documentation alongside the framework specific API documentation. diff --git a/doc/experiments/index.rst b/doc/experiments/index.rst new file mode 100644 index 0000000000..8c12f30edc --- /dev/null +++ b/doc/experiments/index.rst @@ -0,0 +1,10 @@ +############################ +Amazon SageMaker Experiments +############################ + +The SageMaker Python SDK supports tracking and organizing your machine learning workflows across SageMaker jobs, such as Processing, Training, and Transform, or locally. + +.. toctree:: + :maxdepth: 2 + + sagemaker.experiments diff --git a/doc/experiments/sagemaker.experiments.rst b/doc/experiments/sagemaker.experiments.rst new file mode 100644 index 0000000000..f0776ec43b --- /dev/null +++ b/doc/experiments/sagemaker.experiments.rst @@ -0,0 +1,20 @@ +Experiments +============ + +Run +------------- + +.. autoclass:: sagemaker.experiments.Run + :members: + +.. automethod:: sagemaker.experiments.load_run + +.. automethod:: sagemaker.experiments.list_runs + +.. autoclass:: sagemaker.experiments.SortByType + :members: + :undoc-members: + +.. autoclass:: sagemaker.experiments.SortOrderType + :members: + :undoc-members: diff --git a/doc/index.rst b/doc/index.rst index 2d4ebe32c1..69038056b0 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -60,6 +60,16 @@ Orchestrate your SageMaker training and inference workflows with Airflow and Kub workflows/index +**************************** +Amazon SageMaker Experiments +**************************** +You can use Amazon SageMaker Experiments to track machine learning experiments. + +.. 
toctree:: + :maxdepth: 2 + + experiments/index + ************************* Amazon SageMaker Debugger ************************* diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index b52f394bd0..961b04914a 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -11,7 +11,8 @@ contextlib2==21.6.0 awslogs==0.14.0 black==22.3.0 stopit==1.1.2 -apache-airflow==2.4.1 +# Update tox.ini to have correct version of airflow constraints file +apache-airflow==2.4.3 apache-airflow-providers-amazon==4.0.0 attrs==22.1.0 fabric==2.6.0 @@ -19,3 +20,4 @@ requests==2.27.1 sagemaker-experiments==0.1.35 Jinja2==3.0.3 pandas>=1.3.5,<1.5 +scikit-learn==1.0.2 diff --git a/setup.py b/setup.py index 4327045760..e2adb6b433 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ def read_requirements(filename): # Declare minimal set for installation required_packages = [ "attrs>=20.3.0,<23", - "boto3>=1.26.20,<2.0", + "boto3>=1.26.28,<2.0", "google-pasta", "numpy>=1.9.0,<2.0", "protobuf>=3.1,<4.0", diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index b156f2e65f..1abea5e48c 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -27,7 +27,7 @@ from sagemaker.deprecations import renamed_warning from sagemaker.estimator import EstimatorBase, _TrainingJob from sagemaker.inputs import FileSystemInput, TrainingInput -from sagemaker.utils import sagemaker_timestamp +from sagemaker.utils import sagemaker_timestamp, check_and_get_run_experiment_config from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline from sagemaker.workflow import is_pipeline_variable @@ -242,8 +242,8 @@ def fit( generates a default job name, based on the training image name and current timestamp. experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. 
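The hunk below wires ``check_and_get_run_experiment_config`` into ``fit()``: when no ``experiment_config`` is passed explicitly, a job launched inside an active experiment run inherits that run's configuration. A minimal usage sketch, assuming the ``Run`` constructor takes ``experiment_name``/``run_name`` keyword arguments and exposes ``log_parameter`` (this diff only shows the ``sagemaker.experiments`` exports, not the full ``Run`` signature); the image URI, role, and S3 paths are placeholders:

.. code:: python

    from sagemaker.estimator import Estimator
    from sagemaker.experiments import Run

    estimator = Estimator(
        image_uri="<training-image-uri>",   # placeholder
        role="<execution-role-arn>",        # placeholder
        instance_count=1,
        instance_type="ml.m5.xlarge",
    )

    # Inside the context manager, fit() calls
    # check_and_get_run_experiment_config(None), which returns the active
    # run's experiment_config (ExperimentName/RunName), so the training
    # job's trial component is linked to this run automatically.
    with Run(experiment_name="my-experiment", run_name="baseline") as run:
        run.log_parameter("learning_rate", 0.1)
        estimator.fit({"training": "s3://my-bucket/train"})
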
@@ -255,6 +255,7 @@ def fit( """ self._prepare_for_training(records, job_name=job_name, mini_batch_size=mini_batch_size) + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_training_job = _TrainingJob.start_new( self, records, experiment_config=experiment_config ) diff --git a/src/sagemaker/apiutils/_base_types.py b/src/sagemaker/apiutils/_base_types.py index e920797b18..9a7359e12b 100644 --- a/src/sagemaker/apiutils/_base_types.py +++ b/src/sagemaker/apiutils/_base_types.py @@ -173,8 +173,10 @@ def _search( search_items = search_method_response.get("Results", []) next_token = search_method_response.get(boto_next_token_name) for item in search_items: - if cls.__name__ in item: - yield search_item_factory(item[cls.__name__]) + # _TrialComponent class in experiments module is not public currently + class_name = cls.__name__.lstrip("_") + if class_name in item: + yield search_item_factory(item[class_name]) if not next_token: break except StopIteration: diff --git a/src/sagemaker/apiutils/_boto_functions.py b/src/sagemaker/apiutils/_boto_functions.py index 1e29f2ebea..a227d30ca8 100644 --- a/src/sagemaker/apiutils/_boto_functions.py +++ b/src/sagemaker/apiutils/_boto_functions.py @@ -68,7 +68,9 @@ def from_boto(boto_dict, boto_name_to_member_name, member_name_to_type): api_type, is_collection = member_name_to_type[member_name] if is_collection: if isinstance(boto_value, dict): - member_value = api_type.from_boto(boto_value) + member_value = { + key: api_type.from_boto(value) for key, value in boto_value.items() + } else: member_value = [api_type.from_boto(item) for item in boto_value] else: diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 4765630ce8..18fed12042 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -282,7 +282,6 @@ "text/csv", "application/jsonlines", "image/jpeg", - "image/jpg", "image/png", "application/x-npy", ), @@ -331,11 +330,11 @@ def __init__( s3_analysis_config_output_path (str): S3 prefix to store the analysis config output. If this field is None, then the ``s3_output_path`` will be used to store the ``analysis_config`` output. - label (str): Target attribute of the model required by bias metrics. - Specified as column name or index for CSV dataset or as JSONPath for JSONLines. + label (str): Target attribute of the model required by bias metrics. Specified as + column name or index for CSV dataset or as JMESPath expression for JSONLines. *Required parameter* except for when the input dataset does not contain the label. - features (List[str]): JSONPath for locating the feature columns for bias metrics if the - dataset format is JSONLines. + features (List[str]): JMESPath expression to locate the feature columns for + bias metrics if the dataset format is JSONLines. dataset_type (str): Format of the dataset. Valid values are ``"text/csv"`` for CSV, ``"application/jsonlines"`` for JSONLines, and ``"application/x-parquet"`` for Parquet. @@ -717,11 +716,11 @@ def __init__( ``label_headers=['cat','dog','fish']`` and infer the predicted label to be ``'fish'``. Args: - label (str or int): Index or JSONPath location in the model output for the prediction. - In case, this is a predicted label of the same type as the label in the dataset, - no further arguments need to be specified. - probability (str or int): Index or JSONPath location in the model output - for the predicted score(s). + label (str or int): Index or JMESPath expression to locate the prediction + in the model output. 
In case this is a predicted label of the same type + as the label in the dataset, no further arguments need to be specified. + probability (str or int): Index or JMESPath expression to locate the predicted score(s) + in the model output. probability_threshold (float): An optional value for binary prediction tasks in which the model returns a probability, to indicate the threshold to convert the prediction to a boolean value. Default is ``0.5``. @@ -1646,9 +1645,9 @@ def run_explainability( You can request multiple methods at once by passing in a list of `~sagemaker.clarify.ExplainabilityConfig`. model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`): - Index or JSONPath to locate the predicted scores in the model output. This is not - required if the model output is a single score. Alternatively, it can be an instance - of :class:`~sagemaker.clarify.SageMakerClarifyProcessor` + Index or JMESPath expression to locate the predicted scores in the model output. + This is not required if the model output is a single score. Alternatively, + it can be an instance of :class:`~sagemaker.clarify.ModelPredictedLabelConfig` to provide more parameters like ``label_headers``. wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. @@ -1775,9 +1774,9 @@ def run_bias_and_explainability( str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig` ): - Index or JSONPath to locate the predicted scores in the model output. This is not - required if the model output is a single score. Alternatively, it can be an instance - of :class:`~sagemaker.clarify.SageMakerClarifyProcessor` + Index or JMESPath expression to locate the predicted scores in the model output. + This is not required if the model output is a single score. Alternatively, + it can be an instance of :class:`~sagemaker.clarify.ModelPredictedLabelConfig` to provide more parameters like ``label_headers``. wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. diff --git a/src/sagemaker/dataset_definition/inputs.py b/src/sagemaker/dataset_definition/inputs.py index 90a272c4d7..468be22ac3 100644 --- a/src/sagemaker/dataset_definition/inputs.py +++ b/src/sagemaker/dataset_definition/inputs.py @@ -124,8 +124,10 @@ class DatasetDefinition(ApiObject): """DatasetDefinition input.""" _custom_boto_types = { - "redshift_dataset_definition": (RedshiftDatasetDefinition, True), - "athena_dataset_definition": (AthenaDatasetDefinition, True), + # RedshiftDatasetDefinition and AthenaDatasetDefinition are not collections. + # Instead, they are singleton objects. Thus, set the is_collection flag to False. + "redshift_dataset_definition": (RedshiftDatasetDefinition, False), + "athena_dataset_definition": (AthenaDatasetDefinition, False), } def __init__( diff --git a/src/sagemaker/debugger/profiler_config.py b/src/sagemaker/debugger/profiler_config.py index 3d4a24e8d1..561de38b9f 100644 --- a/src/sagemaker/debugger/profiler_config.py +++ b/src/sagemaker/debugger/profiler_config.py @@ -32,6 +32,7 @@ def __init__( s3_output_path: Optional[Union[str, PipelineVariable]] = None, system_monitor_interval_millis: Optional[Union[int, PipelineVariable]] = None, framework_profile_params: Optional[FrameworkProfile] = None, + disable_profiler: Optional[Union[bool, PipelineVariable]] = False, ): """Initialize a ``ProfilerConfig`` instance. 
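Together with the ``_to_request_dict`` change in the next hunk, which always emits a ``DisableProfiler`` key, the new field behaves as sketched below; the bucket name is a placeholder and ``_to_request_dict`` (a private helper) is called only to show the generated request payload:

.. code:: python

    from sagemaker.debugger import ProfilerConfig

    # Profiler explicitly disabled: only the new flag is emitted.
    config = ProfilerConfig(disable_profiler=True)
    print(config._to_request_dict())
    # {'DisableProfiler': True}

    # Default (disable_profiler=False): the usual fields are emitted
    # alongside the flag.
    config = ProfilerConfig(
        s3_output_path="s3://my-bucket/profiler",
        system_monitor_interval_millis=500,
    )
    print(config._to_request_dict())
    # {'S3OutputPath': 's3://my-bucket/profiler',
    #  'DisableProfiler': False,
    #  'ProfilingIntervalInMilliseconds': 500}
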
@@ -78,6 +79,7 @@ class and SageMaker Framework estimators. self.s3_output_path = s3_output_path self.system_monitor_interval_millis = system_monitor_interval_millis self.framework_profile_params = framework_profile_params + self.disable_profiler = disable_profiler def _to_request_dict(self): """Generate a request dictionary using the parameters provided when initializing the object. @@ -91,6 +93,8 @@ def _to_request_dict(self): if self.s3_output_path is not None: profiler_config_request["S3OutputPath"] = self.s3_output_path + profiler_config_request["DisableProfiler"] = self.disable_profiler + if self.system_monitor_interval_millis is not None: profiler_config_request[ "ProfilingIntervalInMilliseconds" diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 6f729267de..8ed9b724a5 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -79,6 +79,7 @@ get_config_value, name_from_base, to_string, + check_and_get_run_experiment_config, ) from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable @@ -937,26 +938,29 @@ def _prepare_collection_configs(self): def _prepare_profiler_for_training(self): """Set necessary values and do basic validations in profiler config and profiler rules. - When user explicitly set rules to an empty list, default profiler rule won't be enabled. - Default profiler rule will be enabled in supported regions when either: - 1. user doesn't specify any rules, i.e., rules=None; or - 2. user only specify debugger rules, i.e., rules=[Rule.sagemaker(...)] + No default profiler rule will be used. The user needs to specify rules explicitly """ if self.disable_profiler: - if self.profiler_config: - raise RuntimeError("profiler_config cannot be set when disable_profiler is True.") + if self.profiler_config and not self.profiler_config.disable_profiler: + raise RuntimeError( + "profiler_config.disable_profiler cannot be False" + + " when disable_profiler is True." + ) if self.profiler_rules: raise RuntimeError("ProfilerRule cannot be set when disable_profiler is True.") elif _region_supports_profiler(self.sagemaker_session.boto_region_name): if self.profiler_config is None: self.profiler_config = ProfilerConfig(s3_output_path=self.output_path) if self.rules is None or (self.rules and not self.profiler_rules): - self.profiler_rules = [get_default_profiler_rule()] + self.profiler_rules = [] if self.profiler_config and not self.profiler_config.s3_output_path: self.profiler_config.s3_output_path = self.output_path self.profiler_rule_configs = self._prepare_profiler_rules() + # if profiler_config is still None, it means the job has profiler disabled + if self.profiler_config is None: + self.profiler_config = ProfilerConfig(disable_profiler=True) def _prepare_profiler_rules(self): """Set any necessary values in profiler rules, if they are provided.""" @@ -1047,7 +1051,7 @@ def latest_job_profiler_artifacts_path(self): error_message="""Cannot get the profiling output artifacts path. The Estimator is not associated with a training job.""" ) - if self.profiler_config is not None: + if self.profiler_config is not None and not self.profiler_config.disable_profiler: return os.path.join( self.profiler_config.s3_output_path, self.latest_training_job.name, @@ -1103,8 +1107,8 @@ def fit( job_name (str): Training job name. If not specified, the estimator generates a default job name based on the training image name and current timestamp. experiment_config (dict[str, str]): Experiment management configuration. 
- Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. @@ -1122,6 +1126,7 @@ """ self._prepare_for_training(job_name=job_name) + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config) self.jobs.append(self.latest_training_job) if wait: @@ -1893,8 +1898,8 @@ def enable_default_profiling(self): else: self.profiler_config = ProfilerConfig(s3_output_path=self.output_path) - self.profiler_rules = [get_default_profiler_rule()] - self.profiler_rule_configs = self._prepare_profiler_rules() + self.profiler_rules = [] + self.profiler_rule_configs = [] _TrainingJob.update( self, self.profiler_rule_configs, self.profiler_config._to_request_dict() @@ -2023,8 +2028,8 @@ def start_new(cls, estimator, inputs, experiment_config): inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`. experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. * If `TrialName` is supplied and the Trial already exists the job's Trial Component will be associated with the Trial. * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. Returns: sagemaker.estimator._TrainingJob: Constructed object that captures all information about the started training job. @@ -2053,8 +2059,8 @@ def _get_train_args(cls, estimator, inputs, experiment_config): inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`. experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. * If `TrialName` is supplied and the Trial already exists the job's Trial Component will be associated with the Trial. * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. Returns: Dict: dict for `sagemaker.session.Session.train` method diff --git a/src/sagemaker/experiments/__init__.py b/src/sagemaker/experiments/__init__.py new file mode 100644 index 0000000000..b87656b1ab --- /dev/null +++ b/src/sagemaker/experiments/__init__.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""SageMaker Experiment Module""" +from __future__ import absolute_import + +from sagemaker.experiments.run import Run # noqa: F401 +from sagemaker.experiments.run import load_run # noqa: F401 +from sagemaker.experiments.run import list_runs # noqa: F401 +from sagemaker.experiments.run import SortOrderType # noqa: F401 +from sagemaker.experiments.run import SortByType # noqa: F401 diff --git a/src/sagemaker/experiments/_api_types.py b/src/sagemaker/experiments/_api_types.py new file mode 100644 index 0000000000..78f82565aa --- /dev/null +++ b/src/sagemaker/experiments/_api_types.py @@ -0,0 +1,251 @@ +# Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains API objects for SageMaker experiments.""" +from __future__ import absolute_import + +import enum +import numbers + +from sagemaker.apiutils import _base_types + + +class TrialComponentMetricSummary(_base_types.ApiObject): + """Summary model of a trial component metric. + + Attributes: + metric_name (str): The name of the metric. + source_arn (str): The ARN of the source. + time_stamp (datetime): When the metric was last updated. + max (float): The max value of the metric. + min (float): The min value of the metric. + last (float): The last value of the metric. + count (float): The number of samples used to generate the metric. + avg (float): The average value of the metric. + std_dev (float): The standard deviation of the metric. + """ + + metric_name = None + source_arn = None + time_stamp = None + max = None + min = None + last = None + count = None + avg = None + std_dev = None + + def __init__(self, metric_name=None, source_arn=None, **kwargs): + super(TrialComponentMetricSummary, self).__init__( + metric_name=metric_name, source_arn=source_arn, **kwargs + ) + + +class TrialComponentParameters(_base_types.ApiObject): + """A dictionary of TrialComponentParameterValues""" + + @classmethod + def from_boto(cls, boto_dict, **kwargs): + """Converts a boto dict to a dictionary of TrialComponentParameterValues + + Args: + boto_dict (dict): boto response dictionary. + **kwargs: Arbitrary keyword arguments. + + Returns: + dict: Dictionary of parameter values. + """ + return_map = {} + for key, value in boto_dict.items(): + return_map[key] = value.get("NumberValue", value.get("StringValue", None)) + return return_map + + @classmethod + def to_boto(cls, parameters): + """Converts TrialComponentParameters to dict. + + Args: + parameters (TrialComponentParameters): Dictionary to convert. 
+ + Returns: + dict: Dictionary of trial component parameters in boto format. + """ + boto_map = {} + for key, value in parameters.items(): + if isinstance(value, numbers.Number): + boto_map[key] = {"NumberValue": value} + else: + boto_map[key] = {"StringValue": str(value)} + return boto_map + + +class TrialComponentArtifact(_base_types.ApiObject): + """Trial component artifact. + + Attributes: + value (str): The artifact value. + media_type (str): The media type. + """ + + value = None + media_type = None + + def __init__(self, value=None, media_type=None, **kwargs): + super(TrialComponentArtifact, self).__init__(value=value, media_type=media_type, **kwargs) + + +class _TrialComponentStatusType(enum.Enum): + """The type of trial component status""" + + InProgress = "InProgress" + Completed = "Completed" + Failed = "Failed" + + +class TrialComponentStatus(_base_types.ApiObject): + """Status of the trial component. + + Attributes: + primary_status (str): The status of a trial component. + message (str): Status message. + """ + + primary_status = None + message = None + + def __init__(self, primary_status=None, message=None, **kwargs): + super(TrialComponentStatus, self).__init__( + primary_status=primary_status, message=message, **kwargs + ) + + +class TrialComponentSummary(_base_types.ApiObject): + """Summary model of a trial component. + + Attributes: + trial_component_name (str): Name of trial component. + trial_component_arn (str): ARN of the trial component. + display_name (str): Friendly display name in UI. + source_arn (str): ARN of the trial component source. + status (str): Status. + start_time (datetime): Start time. + end_time (datetime): End time. + creation_time (datetime): Creation time. + created_by (str): Created by. + last_modified_time (datetime): Date last modified. + last_modified_by (datetime): User last modified. + """ + + _custom_boto_types = { + "status": (TrialComponentStatus, False), + } + trial_component_name = None + trial_component_arn = None + display_name = None + source_arn = None + status = None + start_time = None + end_time = None + creation_time = None + created_by = None + last_modified_time = None + last_modified_by = None + + +class TrialComponentSource(_base_types.ApiObject): + """Trial Component Source + + Attributes: + source_arn (str): The ARN of the source. + """ + + source_arn = None + + def __init__(self, source_arn=None, **kwargs): + super(TrialComponentSource, self).__init__(source_arn=source_arn, **kwargs) + + +class Parent(_base_types.ApiObject): + """The trial/experiment/run that a trial component is associated with. + + Attributes: + trial_name (str): Name of the trial. + experiment_name (str): Name of the experiment. + run_name (str): Name of the run. + """ + + trial_name = None + experiment_name = None + run_name = None + + +class TrialComponentSearchResult(_base_types.ApiObject): + """Summary model of a Trial Component search result. + + Attributes: + trial_component_arn (str): ARN of the trial component. + trial_component_name (str): Name of the trial component. + display_name (str): Display name of the trial component for UI display. + source (dict): The source of the trial component. + status (dict): The status of the trial component. + start_time (datetime): Start time. + end_time (datetime): End time. + creation_time (datetime): Creation time. + created_by (str): Created by. + last_modified_time (datetime): Date last modified. + last_modified_by (datetime): User last modified. 
+ parameters (dict): The hyperparameters of the component. + input_artifacts (dict): The input artifacts of the component. + output_artifacts (dict): The output artifacts of the component. + metrics (list): The metrics for the component. + source_detail (dict): The source of the trial component. + tags (list): The list of tags that are associated with the trial component. + parents (list[Parent]): The parent of trial component. + """ + + _custom_boto_types = { + "parents": (Parent, True), # parents is a collection (list) of Parent objects + } + trial_component_arn = None + trial_component_name = None + display_name = None + source = None + status = None + start_time = None + end_time = None + creation_time = None + created_by = None + last_modified_time = None + last_modified_by = None + parameters = None + input_artifacts = None + output_artifacts = None + metrics = None + source_detail = None + tags = None + parents = None + + +class TrialSummary(_base_types.ApiObject): + """Summary model of a trial. + + Attributes: + trial_arn (str): The ARN of the trial. + trial_name (str): The name of the trial. + creation_time (datetime): When the trial was created. + last_modified_time (datetime): When the trial was last modified. + """ + + trial_arn = None + trial_name = None + creation_time = None + last_modified_time = None diff --git a/src/sagemaker/experiments/_environment.py b/src/sagemaker/experiments/_environment.py new file mode 100644 index 0000000000..441661ae5a --- /dev/null +++ b/src/sagemaker/experiments/_environment.py @@ -0,0 +1,132 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the _RunEnvironment class.""" +from __future__ import absolute_import + +import enum +import json +import logging +import os + +from sagemaker.experiments import trial_component +from sagemaker.utils import retry_with_backoff + +TRAINING_JOB_ARN_ENV = "TRAINING_JOB_ARN" +PROCESSING_JOB_CONFIG_PATH = "/opt/ml/config/processingjobconfig.json" +TRANSFORM_JOB_ENV_BATCH_VAR = "SAGEMAKER_BATCH" +MAX_RETRY_ATTEMPTS = 7 + +logger = logging.getLogger(__name__) + + +class _EnvironmentType(enum.Enum): + """SageMaker jobs which data can be pulled from the environment.""" + + SageMakerTrainingJob = 1 + SageMakerProcessingJob = 2 + SageMakerTransformJob = 3 + + +class _RunEnvironment(object): + """Retrieves job specific data from the environment.""" + + def __init__(self, environment_type, source_arn): + """Init for _RunEnvironment. + + Args: + environment_type (_EnvironmentType): The environment type. + source_arn (str): The ARN of the current job. + """ + self.environment_type = environment_type + self.source_arn = source_arn + + @classmethod + def load( + cls, + training_job_arn_env=TRAINING_JOB_ARN_ENV, + processing_job_config_path=PROCESSING_JOB_CONFIG_PATH, + transform_job_batch_var=TRANSFORM_JOB_ENV_BATCH_VAR, + ): + """Loads source arn of current job from environment. + + Args: + training_job_arn_env (str): The environment key for training job ARN + (default: `TRAINING_JOB_ARN`). 
+ processing_job_config_path (str): The processing job config path + (default: `/opt/ml/config/processingjobconfig.json`). + transform_job_batch_var (str): The environment variable indicating if + it is a transform job (default: `SAGEMAKER_BATCH`). + + Returns: + _RunEnvironment: Job data loaded from the environment. None if config does not exist. + """ + if training_job_arn_env in os.environ: + environment_type = _EnvironmentType.SageMakerTrainingJob + source_arn = os.environ.get(training_job_arn_env) + return _RunEnvironment(environment_type, source_arn) + if os.path.exists(processing_job_config_path): + environment_type = _EnvironmentType.SageMakerProcessingJob + source_arn = json.loads(open(processing_job_config_path).read())["ProcessingJobArn"] + return _RunEnvironment(environment_type, source_arn) + if transform_job_batch_var in os.environ and os.environ[transform_job_batch_var] == "true": + environment_type = _EnvironmentType.SageMakerTransformJob + # TODO: need to figure out how to get source_arn from job env + # with Transform team's help. + source_arn = "" + return _RunEnvironment(environment_type, source_arn) + + return None + + def get_trial_component(self, sagemaker_session): + """Retrieves the trial component from the job in the environment. + + Args: + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + _TrialComponent: The trial component created from the job. None if not found. + """ + # TODO: Remove this condition check once we have a way to retrieve source ARN + # from transform job env + if self.environment_type == _EnvironmentType.SageMakerTransformJob: + logger.error( + "Currently getting the job trial component from the transform job environment " + "is not supported. Returning None." + ) + return None + + def _get_trial_component(): + summaries = list( + trial_component._TrialComponent.list( + source_arn=self.source_arn.lower(), sagemaker_session=sagemaker_session + ) + ) + if summaries: + summary = summaries[0] + return trial_component._TrialComponent.load( + trial_component_name=summary.trial_component_name, + sagemaker_session=sagemaker_session, + ) + return None + + job_tc = None + try: + job_tc = retry_with_backoff(_get_trial_component, MAX_RETRY_ATTEMPTS) + except Exception as ex: # pylint: disable=broad-except + logger.error( + "Failed to get trial component in the current environment due to %s", str(ex) + ) + return job_tc diff --git a/src/sagemaker/experiments/_helper.py b/src/sagemaker/experiments/_helper.py new file mode 100644 index 0000000000..0c689b1125 --- /dev/null +++ b/src/sagemaker/experiments/_helper.py @@ -0,0 +1,266 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains the helper classes for SageMaker Experiment.""" +from __future__ import absolute_import + +import json +import logging +import os + +import botocore + +from sagemaker.experiments._utils import is_already_exist_error + +logger = logging.getLogger(__name__) + + +_DEFAULT_ARTIFACT_PREFIX = "trial-component-artifacts" +_DEFAULT_ARTIFACT_TYPE = "Tracker" + + +class _ArtifactUploader(object): + """Artifact uploader""" + + def __init__( + self, + trial_component_name, + sagemaker_session, + artifact_bucket=None, + artifact_prefix=_DEFAULT_ARTIFACT_PREFIX, + ): + """Initialize a `_ArtifactUploader` instance. + + Args: + trial_component_name (str): The name of the trial component, + which is used to generate the S3 path to upload the artifact to. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + artifact_bucket (str): The S3 bucket to upload the artifact to. + If not specified, the default bucket defined in `sagemaker_session` + will be used. + artifact_prefix (str): The S3 key prefix used to generate the S3 path + to upload the artifact to (default: "trial-component-artifacts"). + """ + self.sagemaker_session = sagemaker_session + self.trial_component_name = trial_component_name + self.artifact_bucket = artifact_bucket + self.artifact_prefix = artifact_prefix + self._s3_client = self.sagemaker_session.boto_session.client("s3") + + def upload_artifact(self, file_path): + """Upload an artifact file to S3. + + Args: + file_path (str): the file path of the artifact + + Returns: + (str, str): The s3 URI of the uploaded file and the etag of the file. + + Raises: + ValueError: If file does not exist. + """ + file_path = os.path.expanduser(file_path) + if not os.path.isfile(file_path): + raise ValueError( + "{} does not exist or is not a file. Please supply a file path.".format(file_path) + ) + if not self.artifact_bucket: + self.artifact_bucket = self.sagemaker_session.default_bucket() + artifact_name = os.path.basename(file_path) + artifact_s3_key = "{}/{}/{}".format( + self.artifact_prefix, self.trial_component_name, artifact_name + ) + self._s3_client.upload_file(file_path, self.artifact_bucket, artifact_s3_key) + etag = self._try_get_etag(artifact_s3_key) + return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag + + def upload_object_artifact(self, artifact_name, artifact_object, file_extension=None): + """Upload an artifact object to S3. + + Args: + artifact_name (str): the name of the artifact. + artifact_object (obj): the object of the artifact + file_extension (str): Optional file extension. + + Returns: + str: The s3 URI of the uploaded file and the version of the file. + """ + if not self.artifact_bucket: + self.artifact_bucket = self.sagemaker_session.default_bucket() + if file_extension: + artifact_name = ( + artifact_name + ("" if file_extension.startswith(".") else ".") + file_extension + ) + artifact_s3_key = "{}/{}/{}".format( + self.artifact_prefix, self.trial_component_name, artifact_name + ) + self._s3_client.put_object( + Body=json.dumps(artifact_object), Bucket=self.artifact_bucket, Key=artifact_s3_key + ) + etag = self._try_get_etag(artifact_s3_key) + return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag + + def _try_get_etag(self, key): + """Get ETag of given key and return None if not allowed + + Args: + key (str): The S3 object key. + + Returns: + str: The S3 object ETag if it allows, otherwise return None. 
+ """ + try: + response = self._s3_client.head_object(Bucket=self.artifact_bucket, Key=key) + return response["ETag"] + except botocore.exceptions.ClientError as error: + # requires read permissions + logger.warning("Failed to get ETag of %s due to %s", key, error) + return None + + +class _LineageArtifactManager(object): + """A helper class to manage Lineage Artifacts""" + + def __init__( + self, + name, + source_uri, + etag, + source_arn=None, + dest_arn=None, + artifact_type=_DEFAULT_ARTIFACT_TYPE, + ): + """Initialize a `_LineageArtifactManager` instance. + + Args: + name (str): The name of the Lineage artifact to be created. + source_uri (str): The source URI used to create the Lineage artifact. + etag (str): The S3 Etag used to create the Lineage artifact. + source_arn (str): The source ARN of a trail component to associate + this Lineage artifact with (default: None). + dest_arn (str): The destination ARN of a trial component to associate + this Lineage artifact with (default: None). + artifact_type (str): The type of the Lineage artifact (default: "Tracker"). + """ + self.name = name + self.source_uri = source_uri + self.etag = etag + self.source_arn = source_arn + self.dest_arn = dest_arn + self.artifact_arn = None + self.artifact_type = artifact_type + + def create_artifact(self, sagemaker_session): + """Create the artifact by calling `CreateArtifact` API + + Args: + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + """ + source_ids = [] + if self.etag: + source_ids.append({"SourceIdType": "S3ETag", "Value": self.etag}) + + try: + response = sagemaker_session.sagemaker_client.create_artifact( + ArtifactName=self.name, + ArtifactType=self.artifact_type, + Source={"SourceUri": self.source_uri, "SourceTypes": source_ids}, + ) + self.artifact_arn = response["ArtifactArn"] + except botocore.exceptions.ClientError as err: + err_info = err.response["Error"] + if not is_already_exist_error(err_info): + raise + logger.warning( + "Skip creating the artifact since it already exists: %s", err_info["Message"] + ) + + def add_association(self, sagemaker_session): + """Associate the artifact with a source/destination ARN (e.g. trial component arn) + + Args: + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + """ + source_arn = self.source_arn if self.source_arn else self.artifact_arn + dest_arn = self.dest_arn if self.dest_arn else self.artifact_arn + # if the trial component (job) is the source then it produced the artifact, + # otherwise the artifact contributed to the trial component (job) + association_edge_type = "Produced" if self.source_arn else "ContributedTo" + try: + sagemaker_session.sagemaker_client.add_association( + SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_edge_type + ) + except botocore.exceptions.ClientError as err: + err_info = err.response["Error"] + if not is_already_exist_error(err_info): + raise + logger.warning( + "Skip associating since the association already exists: %s", err_info["Message"] + ) + + +class _LineageArtifactTracker(object): + """Lineage Artifact Tracker""" + + def __init__(self, trial_component_arn, sagemaker_session): + """Initialize a `_LineageArtifactTracker` instance. + + Args: + trial_component_arn (str): The ARN of the trial component to be + associated with the input/output artifacts. 
+ sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + """ + self.trial_component_arn = trial_component_arn + self.sagemaker_session = sagemaker_session + self.artifacts = [] + + def add_input_artifact(self, name, source_uri, etag, artifact_type): + """Add a Lineage input artifact locally + + Args: + name (str): The name of the Lineage input artifact to be added. + source_uri (str): The source URI used to create the Lineage input artifact. + etag (str): The S3 Etag used to create the Lineage input artifact. + artifact_type (str): The type of the Lineage input artifact. + """ + artifact = _LineageArtifactManager( + name, source_uri, etag, dest_arn=self.trial_component_arn, artifact_type=artifact_type + ) + self.artifacts.append(artifact) + + def add_output_artifact(self, name, source_uri, etag, artifact_type): + """Add a Lineage output artifact locally + + Args: + name (str): The name of the Lineage output artifact to be added. + source_uri (str): The source URI used to create the Lineage output artifact. + etag (str): The S3 Etag used to create the Lineage output artifact. + artifact_type (str): The type of the Lineage output artifact. + """ + artifact = _LineageArtifactManager( + name, source_uri, etag, source_arn=self.trial_component_arn, artifact_type=artifact_type + ) + self.artifacts.append(artifact) + + def save(self): + """Persist any artifact data saved locally""" + for artifact in self.artifacts: + artifact.create_artifact(self.sagemaker_session) + artifact.add_association(self.sagemaker_session) diff --git a/src/sagemaker/experiments/_metrics.py b/src/sagemaker/experiments/_metrics.py new file mode 100644 index 0000000000..f80c43f337 --- /dev/null +++ b/src/sagemaker/experiments/_metrics.py @@ -0,0 +1,413 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes to manage metrics for SageMaker Experiment""" +from __future__ import absolute_import + +import datetime +import json +import logging +import os +import time +import threading +import queue + +import dateutil.tz + +from sagemaker.session import Session + +METRICS_DIR = os.environ.get("SAGEMAKER_METRICS_DIRECTORY", ".") +METRIC_TS_LOWER_BOUND_TO_NOW = 1209600 # in seconds +METRIC_TS_UPPER_BOUND_FROM_NOW = 7200 # in seconds + +BATCH_SIZE = 10 + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# TODO: remove this _SageMakerFileMetricsWriter class +# when _MetricsManager is fully ready +class _SageMakerFileMetricsWriter(object): + """Write metric data to file.""" + + def __init__(self, metrics_file_path=None): + """Construct a `_SageMakerFileMetricsWriter` object""" + self._metrics_file_path = metrics_file_path + self._file = None + self._closed = False + + def log_metric(self, metric_name, value, timestamp=None, step=None): + """Write a metric to file. + + Args: + metric_name (str): The name of the metric. + value (float): The value of the metric. 
+ timestamp (datetime.datetime): Timestamp of the metric. + If not specified, the current UTC time will be used. + step (int): Iteration number of the metric (default: None). + + Raises: + SageMakerMetricsWriterException: If the metrics file is closed. + AttributeError: If file has been initialized and the writer hasn't been closed. + """ + raw_metric_data = _RawMetricData( + metric_name=metric_name, value=value, timestamp=timestamp, step=step + ) + try: + logger.debug("Writing metric: %s", raw_metric_data) + self._file.write(json.dumps(raw_metric_data.to_record())) + self._file.write("\n") + except AttributeError as attr_err: + if self._closed: + raise SageMakerMetricsWriterException("log_metric called on a closed writer") + if not self._file: + self._file = open(self._get_metrics_file_path(), "a", buffering=1) + self._file.write(json.dumps(raw_metric_data.to_record())) + self._file.write("\n") + else: + raise attr_err + + def close(self): + """Closes the metric file.""" + if not self._closed and self._file: + self._file.close() + self._file = None # invalidate reference, causing subsequent log_metric to fail. + self._closed = True + + def __enter__(self): + """Return self""" + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Execute self.close()""" + self.close() + + def __del__(self): + """Execute self.close()""" + self.close() + + def _get_metrics_file_path(self): + """Get file path to store metrics""" + pid_filename = "{}.json".format(str(os.getpid())) + metrics_file_path = self._metrics_file_path or os.path.join(METRICS_DIR, pid_filename) + logger.debug("metrics_file_path = %s", metrics_file_path) + return metrics_file_path + + +class SageMakerMetricsWriterException(Exception): + """SageMakerMetricsWriterException""" + + def __init__(self, message, errors=None): + """Construct a `SageMakerMetricsWriterException` instance""" + super().__init__(message) + if errors: + self.errors = errors + + +class _RawMetricData(object): + """A Raw Metric Data Object""" + + MetricName = None + Value = None + Timestamp = None + Step = None + + def __init__(self, metric_name, value, timestamp=None, step=None): + """Construct a `_RawMetricData` instance. + + Args: + metric_name (str): The name of the metric. + value (float): The value of the metric. + timestamp (datetime.datetime or float or str): Timestamp of the metric. + If not specified, the current UTC time will be used. + step (int): Iteration number of the metric (default: None). + """ + if timestamp is None: + timestamp = time.time() + elif isinstance(timestamp, datetime.datetime): + # If the input is a datetime then convert it to UTC time. + # Assume a naive datetime is in local timezone + if not timestamp.tzinfo: + timestamp = timestamp.replace(tzinfo=dateutil.tz.tzlocal()) + timestamp = (timestamp - timestamp.utcoffset()).replace(tzinfo=datetime.timezone.utc) + timestamp = timestamp.timestamp() + else: + timestamp = float(timestamp) + + if timestamp < (time.time() - METRIC_TS_LOWER_BOUND_TO_NOW) or timestamp > ( + time.time() + METRIC_TS_UPPER_BOUND_FROM_NOW + ): + raise ValueError( + "Supplied timestamp %f is invalid." + " Timestamps must be between two weeks before and two hours from now." 
% timestamp + ) + value = float(value) + + self.MetricName = metric_name + self.Value = float(value) + self.Timestamp = timestamp + if step is not None: + if not isinstance(step, int): + raise ValueError("step must be int.") + self.Step = step + + def to_record(self): + """Convert the `_RawMetricData` object to dict""" + return self.__dict__ + + def to_raw_metric_data(self): + """Converts the metric data to a BatchPutMetrics RawMetricData item""" + # Convert timestamp from float to timestamp str. + # Otherwise will get ParamValidationError + raw_metric_data = { + "MetricName": self.MetricName, + "Value": self.Value, + "Timestamp": str(int(self.Timestamp)), + } + if self.Step is not None: + raw_metric_data["Step"] = int(self.Step) + return raw_metric_data + + def __str__(self): + """String representation of the `_RawMetricData` object.""" + return repr(self) + + def __repr__(self): + """Return a string representation of this _RawMetricData` object.""" + return "{}({})".format( + type(self).__name__, + ",".join(["{}={}".format(k, repr(v)) for k, v in vars(self).items()]), + ) + + +class _MetricsManager(object): + """Collects metrics and sends them directly to SageMaker Metrics data plane APIs.""" + + def __init__(self, trial_component_name: str, sagemaker_session: Session, sink=None) -> None: + """Initialize a `_MetricsManager` instance + + Args: + trial_component_name (str): The Name of the Trial Component to log metrics to + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + sink (object): The metrics sink to use. + """ + if sink is None: + self.sink = _SyncMetricsSink( + trial_component_name, sagemaker_session.sagemaker_metrics_client + ) + else: + self.sink = sink + + def log_metric(self, metric_name, value, timestamp=None, step=None): + """Sends a metric to metrics service.""" + + metric_data = _RawMetricData(metric_name, value, timestamp, step) + self.sink.log_metric(metric_data) + + def __enter__(self): + """Return self""" + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Execute self.close()""" + self.sink.close() + + def close(self): + """Close the metrics object.""" + self.sink.close() + + +class _SyncMetricsSink(object): + """Collects metrics and sends them directly to metrics service.""" + + def __init__(self, trial_component_name, metrics_client) -> None: + """Initialize a `_SyncMetricsSink` instance + + Args: + trial_component_name (str): The Name of the Trial Component to log metrics. 
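As a quick illustration of the timestamp handling above: normalizing a naive local datetime the way `_RawMetricData` does is equivalent to calling ``timestamp()`` on the timezone-aware value. A small sketch:

.. code:: python

    import datetime
    import time

    import dateutil.tz

    naive = datetime.datetime.now()  # naive datetime, assumed to be local time
    aware = naive.replace(tzinfo=dateutil.tz.tzlocal())
    epoch = aware.timestamp()  # float epoch seconds, as stored in Timestamp
    assert abs(epoch - time.time()) < 5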
+ metrics_client (boto3.client): boto client for metrics service
+ """
+ self._trial_component_name = trial_component_name
+ self._metrics_client = metrics_client
+ self._buffer = []
+
+ def log_metric(self, metric_data):
+ """Sends a metric to metrics service."""
+
+ # this is a simplistic solution which calls BatchPutMetrics
+ # on the same thread as the client code
+ self._buffer.append(metric_data)
+ self._drain()
+
+ def _drain(self, close=False):
+ """Pops off all metrics in the buffer and starts sending them to metrics service."""
+
+ if not self._buffer:
+ return
+
+ if len(self._buffer) < BATCH_SIZE and not close:
+ return
+
+ # pop all the available metrics
+ available_metrics, self._buffer = self._buffer, []
+
+ self._send_metrics(available_metrics)
+
+ def _send_metrics(self, metrics):
+ """Calls BatchPutMetrics directly on the metrics service."""
+ while metrics:
+ batch, metrics = (
+ metrics[:BATCH_SIZE],
+ metrics[BATCH_SIZE:],
+ )
+ request = self._construct_batch_put_metrics_request(batch)
+ response = self._metrics_client.batch_put_metrics(**request)
+ errors = response["Errors"] if "Errors" in response else None
+ if errors:
+ message = errors[0]["Message"]
+ raise Exception(f'{len(errors)} errors with message "{message}"')
+
+ def _construct_batch_put_metrics_request(self, batch):
+ """Creates dictionary object used as request to metrics service."""
+ return {
+ "TrialComponentName": self._trial_component_name.lower(),
+ "MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)),
+ }
+
+ def close(self):
+ """Drains any remaining metrics."""
+ self._drain(close=True)
+
+
+class _MetricQueue(object):
+ """A thread-safe queue for sending metrics to SageMaker.
+
+ Args:
+ trial_component_name (str): the name of the trial component to log metrics to
+ metric_name (str): the name of the metric
+ metrics_client (boto_client): the boto client for SageMaker Metrics service
+ """
+
+ _CONSUMER_SLEEP_SECONDS = 5
+
+ def __init__(self, trial_component_name, metric_name, metrics_client):
+ # infinite queue size
+ self._queue = queue.Queue()
+ self._buffer = []
+ self._thread = threading.Thread(target=self._run)
+ self._started = False
+ self._finished = False
+ self._trial_component_name = trial_component_name
+ self._metrics_client = metrics_client
+ self._metric_name = metric_name
+ self._logged_metrics = 0
+
+ def log_metric(self, metric_data):
+ """Adds a metric data point to the queue"""
+ self._buffer.append(metric_data)
+
+ if len(self._buffer) < BATCH_SIZE:
+ return
+
+ self._enqueue_all()
+
+ if not self._started:
+ self._thread.start()
+ self._started = True
+
+ def _run(self):
+ """Starts the metric thread which sends metrics to SageMaker in batches"""
+
+ while not self._queue.empty() or not self._finished:
+ if self._queue.empty():
+ time.sleep(self._CONSUMER_SLEEP_SECONDS)
+ else:
+ batch = self._queue.get()
+ self._send_metrics(batch)
+
+ def _send_metrics(self, metrics_batch):
+ """Calls BatchPutMetrics directly on the metrics service."""
+ request = self._construct_batch_put_metrics_request(metrics_batch)
+ self._logged_metrics += len(metrics_batch)
+ self._metrics_client.batch_put_metrics(**request)
+
+ def _construct_batch_put_metrics_request(self, batch):
+ """Creates dictionary object used as request to metrics service."""
+
+ return {
+ "TrialComponentName": self._trial_component_name,
+ "MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)),
+ }
+
+ def _enqueue_all(self):
+ """Enqueue all buffered metrics to be sent to SageMaker"""
+
+ available_metrics, self._buffer = self._buffer, []
+ if available_metrics:
+ self._queue.put(available_metrics)
+
+ def close(self):
+ """Flushes any buffered metrics"""
+
+ self._enqueue_all()
+ self._finished = True
+
+ def is_active(self):
+ """Is the thread active (still draining metrics to SageMaker)"""
+
+ return self._thread.is_alive()
+
+
+class _AsyncMetricsSink(object):
+ """Collects metrics and sends them to the metrics service asynchronously."""
+
+ _COMPLETE_SLEEP_SECONDS = 1.0
+
+ def __init__(self, trial_component_name, metrics_client) -> None:
+ """Initialize an `_AsyncMetricsSink` instance
+
+ Args:
+ trial_component_name (str): The Name of the Trial Component to log metrics to.
+ metrics_client (boto3.client): boto client for metrics service
+ """
+ self._trial_component_name = trial_component_name
+ self._metrics_client = metrics_client
+ self._buffer = []
+ self._is_draining = False
+ self._metric_queues = {}
+
+ def log_metric(self, metric_data):
+ """Sends a metric to metrics service."""
+
+ if metric_data.MetricName in self._metric_queues:
+ self._metric_queues[metric_data.MetricName].log_metric(metric_data)
+ else:
+ cur_metric_queue = _MetricQueue(
+ self._trial_component_name, metric_data.MetricName, self._metrics_client
+ )
+ self._metric_queues[metric_data.MetricName] = cur_metric_queue
+ cur_metric_queue.log_metric(metric_data)
+
+ def close(self):
+ """Close the sink, draining all metric queues."""
+ logger.debug("Closing")
+ for q in self._metric_queues.values():
+ q.close()
+
+ # TODO should probably use join
+ while any(map(lambda x: x.is_active(), self._metric_queues.values())):
+ time.sleep(self._COMPLETE_SLEEP_SECONDS)
+ logger.debug("Closed")
diff --git a/src/sagemaker/experiments/_run_context.py b/src/sagemaker/experiments/_run_context.py
new file mode 100644
index 0000000000..9a7dada5f4
--- /dev/null
+++ b/src/sagemaker/experiments/_run_context.py
@@ -0,0 +1,58 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains the SageMaker Experiment _RunContext class."""
+from __future__ import absolute_import
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from sagemaker.experiments import Run
+
+
+class _RunContext:
+ """A static context variable to keep track of the current Run object"""
+
+ _context_run = None
+
+ @classmethod
+ def add_run_object(cls, run: "Run"):
+ """Keep track of the current executing Run object
+
+ by adding it to a class static variable.
+
+ Args:
+ run (Run): The current Run object to be tracked.
+ """
+ cls._context_run = run
+
+ @classmethod
+ def drop_current_run(cls) -> "Run":
+ """Drop the Run object tracked in the global static variable
+
+ as its execution finishes (its "with" block ends).
+
+ Return:
+ Run: the dropped Run object.
+ """
+ current_run = cls._context_run
+ cls._context_run = None
+ return current_run
+
+ @classmethod
+ def get_current_run(cls) -> "Run":
+ """Return the current Run object without dropping it.
+
+ Return:
+ Run: the current Run object to be returned.
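The class above follows a simple class-level context pattern; a self-contained sketch of the same behavior (names here are hypothetical, not part of the SDK):

.. code:: python

    # Illustrative sketch of the class-level context pattern used by _RunContext.
    class _Context:
        _current = None

        @classmethod
        def add(cls, obj):
            cls._current = obj

        @classmethod
        def drop(cls):
            # Return the tracked object and clear the slot, as drop_current_run does.
            current, cls._current = cls._current, None
            return current

    _Context.add("run-1")
    assert _Context.drop() == "run-1"
    assert _Context._current is None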
+ """ + return cls._context_run diff --git a/src/sagemaker/experiments/_utils.py b/src/sagemaker/experiments/_utils.py new file mode 100644 index 0000000000..5ef5d99dad --- /dev/null +++ b/src/sagemaker/experiments/_utils.py @@ -0,0 +1,218 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the SageMaker Experiment utility methods.""" +from __future__ import absolute_import + +import logging +import os + +import mimetypes +import urllib +from functools import wraps +from typing import Optional + +from sagemaker import Session +from sagemaker.apiutils import _utils +from sagemaker.experiments._environment import _RunEnvironment, _EnvironmentType +from sagemaker.experiments.trial_component import _TrialComponent +from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression +from sagemaker.utils import retry_with_backoff + + +def resolve_artifact_name(file_path): + """Resolve artifact name from given file path. + + If not specified, will auto create one. + + Args: + file_path (str): Path to the file. + + Returns: + str: The resolved artifact name. + """ + _, filename = os.path.split(file_path) + if filename: + return filename + + return _utils.name("artifact") + + +def guess_media_type(file_path): + """Infer the media type of a file based on its file name. + + Args: + file_path (str): Path to the file. + + Returns: + str: The guessed media type. + """ + file_url = urllib.parse.urljoin("file:", urllib.request.pathname2url(file_path)) + guessed_media_type, _ = mimetypes.guess_type(file_url, strict=False) + return guessed_media_type + + +def verify_length_of_true_and_predicted(true_labels, predicted_attrs, predicted_attrs_name): + """Verify if lengths match between lists of true labels and predicted attributes. + + Args: + true_labels (list or array): The list of the true labels. + predicted_attrs (list or array): The list of the predicted labels/probabilities/scores. + predicted_attrs_name (str): The name of the predicted attributes. + + Raises: + ValueError: If lengths mismatch between true labels and predicted attributes. 
+ """ + if len(true_labels) != len(predicted_attrs): + raise ValueError( + "Lengths mismatch between true labels and {}: " + "({} vs {}).".format(predicted_attrs_name, len(true_labels), len(predicted_attrs)) + ) + + +def validate_invoked_inside_run_context(func): + """A Decorator to force the decorated method called under Run context.""" + + @wraps(func) + def wrapper(*args, **kwargs): + self_instance = args[0] + if not self_instance._inside_load_context and not self_instance._inside_init_context: + raise RuntimeError("This method should be called inside context of 'with' statement.") + return func(*args, **kwargs) + + return wrapper + + +def is_already_exist_error(error): + """Check if the error indicates resource already exists + + Args: + error (dict): The "Error" field in the response of the + `botocore.exceptions.ClientError` + """ + return error["Code"] == "ValidationException" and "already exists" in error["Message"] + + +def get_tc_and_exp_config_from_job_env( + environment: _RunEnvironment, + sagemaker_session: Session, +) -> dict: + """Retrieve an experiment config from the job environment. + + Args: + environment (_RunEnvironment): The run environment object with job specific data. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + """ + job_name = environment.source_arn.split("/")[-1] + if environment.environment_type == _EnvironmentType.SageMakerTrainingJob: + job_response = retry_with_backoff( + callable_func=lambda: sagemaker_session.describe_training_job(job_name), + num_attempts=4, + ) + elif environment.environment_type == _EnvironmentType.SageMakerProcessingJob: + job_response = retry_with_backoff( + callable_func=lambda: sagemaker_session.describe_processing_job(job_name), + num_attempts=4, + ) + else: # environment.environment_type == _EnvironmentType.SageMakerTransformJob + raise RuntimeError( + "Failed to load the Run as loading experiment config " + "from transform job environment is not currently supported. " + "As a workaround, please explicitly pass in " + "the experiment_name and run_name in load_run." + ) + + job_exp_config = job_response.get("ExperimentConfig", dict()) + from sagemaker.experiments.run import RUN_NAME + + if job_exp_config.get(RUN_NAME, None): + return job_exp_config + raise RuntimeError( + "Not able to fetch RunName in ExperimentConfig of the sagemaker job. " + "Please make sure the ExperimentConfig is correctly set." + ) + + +def verify_load_input_names( + run_name: Optional[str] = None, + experiment_name: Optional[str] = None, +): + """Verify the run_name and the experiment_name inputs in load_run. + + Args: + run_name (str): The run_name supplied by the user (default: None). + experiment_name (str): The experiment_name supplied by the user + (default: None). + + Raises: + ValueError: If run_name is supplied while experiment_name is not. + """ + if not run_name and experiment_name: + logging.warning( + "No run_name is supplied. Ignoring the provided experiment_name " + "since it only takes effect along with run_name. " + "Will load the Run object from the job environment or current Run context." + ) + if run_name and not experiment_name: + raise ValueError( + "Invalid input: experiment_name is missing when run_name is supplied. " + "Please supply a valid experiment_name when the run_name is not None." 
+ ) + + +def is_run_trial_component(trial_component_name: str, sagemaker_session: Session) -> bool: + """Check if a trial component is generated by `sagemaker.experiments.Run` + + Args: + trial_component_name (str): The name of the trial component. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + bool: Indicate whether the trial component is created by + `sagemaker.experiments.Run` or not. + """ + search_filter = Filter( + name="TrialComponentName", + operator=Operator.EQUALS, + value=trial_component_name, + ) + search_expression = SearchExpression(filters=[search_filter]) + + def search(): + return list( + _TrialComponent.search( + search_expression=search_expression, + max_results=1, # TrialComponentName is unique in an account + sagemaker_session=sagemaker_session, + ) + )[0] + + try: + tc_search_res = retry_with_backoff(search, 4) + from sagemaker.experiments.run import RUN_TC_TAG + + if not tc_search_res.tags or RUN_TC_TAG not in tc_search_res.tags: + return False + return True + except Exception as ex: # pylint: disable=broad-except + logging.warning( + "Failed to inspect the type of the trial component (%s), due to (%s)", + trial_component_name, + str(ex), + ) + return False diff --git a/src/sagemaker/experiments/experiment.py b/src/sagemaker/experiments/experiment.py new file mode 100644 index 0000000000..8f59ff36b3 --- /dev/null +++ b/src/sagemaker/experiments/experiment.py @@ -0,0 +1,237 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the SageMaker Experiment class.""" +from __future__ import absolute_import + +import time + +from sagemaker.apiutils import _base_types +from sagemaker.experiments.trial import _Trial +from sagemaker.experiments.trial_component import _TrialComponent + + +class _Experiment(_base_types.Record): + """An Amazon SageMaker experiment, which is a collection of related trials. + + New experiments are created by calling `experiments.experiment._Experiment.create`. + Existing experiments can be reloaded by calling `experiments.experiment._Experiment.load`. + + Attributes: + experiment_name (str): The name of the experiment. The name must be unique + within an account. + display_name (str): Name of the experiment that will appear in UI, + such as SageMaker Studio. + description (str): A description of the experiment. + tags (List[Dict[str, str]]): A list of tags to associate with the experiment. 
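Tying back to ``is_run_trial_component`` above: the decisive check is simply whether the run tag is present on the trial component's tag list. A minimal sketch with a hypothetical tag payload (the tag key and value match the ``RUN_TC_TAG`` constants defined in ``sagemaker.experiments.run``):

.. code:: python

    # RUN_TC_TAG as defined in sagemaker.experiments.run
    RUN_TC_TAG = {"Key": "sagemaker:trial-component-source", "Value": "run"}

    tags = [{"Key": "sagemaker:trial-component-source", "Value": "run"}]  # hypothetical
    assert tags and RUN_TC_TAG in tags  # the trial component was created by a Run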
+ """ + + experiment_name = None + display_name = None + description = None + tags = None + + _boto_create_method = "create_experiment" + _boto_load_method = "describe_experiment" + _boto_update_method = "update_experiment" + _boto_delete_method = "delete_experiment" + + _boto_update_members = ["experiment_name", "description", "display_name"] + _boto_delete_members = ["experiment_name"] + + _MAX_DELETE_ALL_ATTEMPTS = 3 + + def save(self): + """Save the state of this Experiment to SageMaker. + + Returns: + dict: Update experiment API response. + """ + return self._invoke_api(self._boto_update_method, self._boto_update_members) + + def delete(self): + """Delete this Experiment from SageMaker. + + Deleting an Experiment does not delete associated Trials and their Trial Components. + It requires that each Trial in the Experiment is first deleted. + + Returns: + dict: Delete experiment API response. + """ + return self._invoke_api(self._boto_delete_method, self._boto_delete_members) + + @classmethod + def load(cls, experiment_name, sagemaker_session=None): + """Load an existing experiment and return an `_Experiment` object representing it. + + Args: + experiment_name: (str): Name of the experiment + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.experiment._Experiment: A SageMaker `_Experiment` object + """ + return cls._construct( + cls._boto_load_method, + experiment_name=experiment_name, + sagemaker_session=sagemaker_session, + ) + + @classmethod + def create( + cls, + experiment_name, + display_name=None, + description=None, + tags=None, + sagemaker_session=None, + ): + """Create a new experiment in SageMaker and return an `_Experiment` object. + + Args: + experiment_name: (str): Name of the experiment. Must be unique. Required. + display_name: (str): Name of the experiment that will appear in UI, + such as SageMaker Studio (default: None). + description: (str): Description of the experiment (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + tags (List[Dict[str, str]]): A list of tags to associate with the experiment + (default: None). + + Returns: + experiments.experiment._Experiment: A SageMaker `_Experiment` object + """ + return cls._construct( + cls._boto_create_method, + experiment_name=experiment_name, + display_name=display_name, + description=description, + tags=tags, + sagemaker_session=sagemaker_session, + ) + + @classmethod + def _load_or_create( + cls, + experiment_name, + display_name=None, + description=None, + tags=None, + sagemaker_session=None, + ): + """Load an experiment by name and create a new one if it does not exist. + + Args: + experiment_name: (str): Name of the experiment. Must be unique. Required. + display_name: (str): Name of the experiment that will appear in UI, + such as SageMaker Studio (default: None). This is used only when the + given `experiment_name` does not exist and a new experiment has to be created. + description: (str): Description of the experiment (default: None). + This is used only when the given `experiment_name` does not exist and + a new experiment has to be created. 
+ sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + tags (List[Dict[str, str]]): A list of tags to associate with the experiment + (default: None). This is used only when the given `experiment_name` does not + exist and a new experiment has to be created. + + Returns: + experiments.experiment._Experiment: A SageMaker `_Experiment` object + """ + sagemaker_client = sagemaker_session.sagemaker_client + try: + experiment = _Experiment.load(experiment_name, sagemaker_session) + except sagemaker_client.exceptions.ResourceNotFound: + experiment = _Experiment.create( + experiment_name=experiment_name, + display_name=display_name, + description=description, + tags=tags, + sagemaker_session=sagemaker_session, + ) + return experiment + + def list_trials(self, created_before=None, created_after=None, sort_by=None, sort_order=None): + """List trials in this experiment matching the specified criteria. + + Args: + created_before (datetime.datetime): Return trials created before this instant + (default: None). + created_after (datetime.datetime): Return trials created after this instant + (default: None). + sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime' + (default: None). + sort_order (str): One of 'Ascending', or 'Descending' (default: None). + + Returns: + collections.Iterator[experiments._api_types.TrialSummary] : + An iterator over trials matching the criteria. + """ + return _Trial.list( + experiment_name=self.experiment_name, + created_before=created_before, + created_after=created_after, + sort_by=sort_by, + sort_order=sort_order, + sagemaker_session=self.sagemaker_session, + ) + + def _delete_all(self, action): + """Force to delete the experiment and associated trials, trial components. + + Args: + action (str): The string '--force' is required to pass in to confirm recursively + delete the experiments, and all its trials and trial components. + """ + if action != "--force": + raise ValueError( + "Must confirm with string '--force' in order to delete the experiment and " + "associated trials, trial components." + ) + + delete_attempt_count = 0 + last_exception = None + while True: + if delete_attempt_count == self._MAX_DELETE_ALL_ATTEMPTS: + raise Exception("Failed to delete, please try again.") from last_exception + try: + for trial_summary in self.list_trials(): + trial = _Trial.load( + sagemaker_session=self.sagemaker_session, + trial_name=trial_summary.trial_name, + ) + for ( + trial_component_summary + ) in trial.list_trial_components(): # pylint: disable=no-member + tc = _TrialComponent.load( + sagemaker_session=self.sagemaker_session, + trial_component_name=trial_component_summary.trial_component_name, + ) + tc.delete(force_disassociate=True) + # to prevent throttling + time.sleep(1.2) + trial.delete() # pylint: disable=no-member + # to prevent throttling + time.sleep(1.2) + self.delete() + break + except Exception as ex: # pylint: disable=broad-except + last_exception = ex + finally: + delete_attempt_count = delete_attempt_count + 1 diff --git a/src/sagemaker/experiments/run.py b/src/sagemaker/experiments/run.py new file mode 100644 index 0000000000..1492b6bafa --- /dev/null +++ b/src/sagemaker/experiments/run.py @@ -0,0 +1,882 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the SageMaker Experiment Run class.""" +from __future__ import absolute_import + +import datetime +import logging +from enum import Enum +from math import isnan, isinf +from numbers import Number +from typing import Optional, List, Dict, TYPE_CHECKING, Union + +import dateutil +from numpy import array + +from sagemaker.apiutils import _utils +from sagemaker.experiments import _api_types +from sagemaker.experiments._api_types import TrialComponentArtifact, _TrialComponentStatusType +from sagemaker.experiments._helper import ( + _ArtifactUploader, + _LineageArtifactTracker, +) +from sagemaker.experiments._environment import _RunEnvironment +from sagemaker.experiments._run_context import _RunContext +from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments._metrics import _MetricsManager +from sagemaker.experiments.trial import _Trial +from sagemaker.experiments.trial_component import _TrialComponent + +from sagemaker.utils import ( + get_module, + unique_name_from_base, +) + +from sagemaker.experiments._utils import ( + guess_media_type, + resolve_artifact_name, + verify_length_of_true_and_predicted, + validate_invoked_inside_run_context, + get_tc_and_exp_config_from_job_env, + verify_load_input_names, + is_run_trial_component, +) + +if TYPE_CHECKING: + from sagemaker import Session + +logger = logging.getLogger(__name__) + +RUN_NAME_BASE = "Sagemaker-Run".lower() +TRIAL_NAME_TEMPLATE = "Default-Run-Group-{}" +MAX_RUN_TC_ARTIFACTS_LEN = 30 +MAX_NAME_LEN_IN_BACKEND = 120 +EXPERIMENT_NAME = "ExperimentName" +TRIAL_NAME = "TrialName" +RUN_NAME = "RunName" +DELIMITER = "-" +RUN_TC_TAG_KEY = "sagemaker:trial-component-source" +RUN_TC_TAG_VALUE = "run" +RUN_TC_TAG = {"Key": RUN_TC_TAG_KEY, "Value": RUN_TC_TAG_VALUE} + + +class SortByType(Enum): + """The type of property by which to sort the `list_runs` results.""" + + CREATION_TIME = "CreationTime" + NAME = "Name" + + +class SortOrderType(Enum): + """The type of order to sort the list or search results.""" + + ASCENDING = "Ascending" + DESCENDING = "Descending" + + +class Run(object): + """A collection of parameters, metrics, and artifacts to create a ML model.""" + + def __init__( + self, + experiment_name: str, + run_name: Optional[str] = None, + experiment_display_name: Optional[str] = None, + run_display_name: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, + sagemaker_session: Optional["Session"] = None, + ): + """Construct a `Run` instance. + + SageMaker Experiments automatically tracks the inputs, parameters, configurations, + and results of your iterations as runs. + You can assign, group, and organize these runs into experiments. + You can also create, compare, and evaluate runs. + + The code sample below shows how to initialize a run, log parameters to the Run object + and invoke a training job under the context of this Run object, which automatically + passes the run's ``experiment_config`` (including the experiment name, run name etc.) + to the training job. + + Note: + All log methods (e.g. ``log_parameter``, ``log_metric``, etc.) 
have to be called within + the run context (i.e. the ``with`` statement). Otherwise, a ``RuntimeError`` is thrown. + + .. code:: python + + with Run(experiment_name="my-exp", run_name="my-run", ...) as run: + run.log_parameter(...) + ... + estimator.fit(job_name="my-job") # Create a training job + + In order to reuse an existing run to log extra data, ``load_run`` is recommended. + The code snippet below displays how to load the run initialized above + in a custom training job script, where no ``run_name`` or ``experiment_name`` + is presented as they are automatically retrieved from the experiment config + in the job environment. + + Note: + Instead of the ``Run`` constructor, the ``load_run`` is recommended to use + in a job script to load the existing run created before the job launch. + Otherwise, a new run may be created each time you launch a job. + + .. code:: python + + with load_run() as run: + run.log_metric(...) + ... + + Args: + experiment_name (str): The name of the experiment. The name must be unique + within an account. + run_name (str): The name of the run. If it is not specified, one is auto generated. + experiment_display_name (str): Name of the experiment that will appear in UI, + such as SageMaker Studio. (default: None). This display name is used in + a create experiment call. If an experiment with the specified name already exists, + this display name won't take effect. + run_display_name (str): The display name of the run used in UI (default: None). + This display name is used in a create run call. If a run with the + specified name already exists, this display name won't take effect. + tags (List[Dict[str, str]]): A list of tags to be used for all create calls, + e.g. to create an experiment, a run group, etc. (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + """ + # TODO: we should revert the lower casting once backend fix reaches prod + self.experiment_name = experiment_name.lower() + sagemaker_session = sagemaker_session or _utils.default_session() + self.run_name = run_name or unique_name_from_base(RUN_NAME_BASE) + + # avoid confusion due to mis-match in casing between run name and TC name + self.run_name = self.run_name.lower() + + trial_component_name = Run._generate_trial_component_name( + run_name=self.run_name, experiment_name=self.experiment_name + ) + self.run_group_name = Run._generate_trial_name(self.experiment_name) + + self._experiment = _Experiment._load_or_create( + experiment_name=self.experiment_name, + display_name=experiment_display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + + self._trial = _Trial._load_or_create( + experiment_name=self.experiment_name, + trial_name=self.run_group_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + + self._trial_component, is_existed = _TrialComponent._load_or_create( + trial_component_name=trial_component_name, + display_name=run_display_name, + tags=Run._append_run_tc_label_to_tags(tags), + sagemaker_session=sagemaker_session, + ) + if is_existed: + logger.info( + "The run (%s) under experiment (%s) already exists. Loading it. 
" + "Note: sagemaker.experiments.load_run is recommended to use when " + "the desired run already exists.", + self.run_name, + self.experiment_name, + ) + self._trial.add_trial_component(self._trial_component) + + self._artifact_uploader = _ArtifactUploader( + trial_component_name=self._trial_component.trial_component_name, + sagemaker_session=sagemaker_session, + ) + self._lineage_artifact_tracker = _LineageArtifactTracker( + trial_component_arn=self._trial_component.trial_component_arn, + sagemaker_session=sagemaker_session, + ) + self._metrics_manager = _MetricsManager( + trial_component_name=self._trial_component.trial_component_name, + sagemaker_session=sagemaker_session, + ) + self._inside_init_context = False + self._inside_load_context = False + self._in_load = False + + @property + def experiment_config(self) -> dict: + """Get experiment config from run attributes.""" + return { + EXPERIMENT_NAME: self.experiment_name, + TRIAL_NAME: self.run_group_name, + RUN_NAME: self._trial_component.trial_component_name, + } + + @validate_invoked_inside_run_context + def log_parameter(self, name: str, value: Union[str, int, float]): + """Record a single parameter value for this run. + + Overwrites any previous value recorded for the specified parameter name. + + Args: + name (str): The name of the parameter. + value (str or int or float): The value of the parameter. + """ + if self._is_input_valid("parameter", name, value): + self._trial_component.parameters[name] = value + + @validate_invoked_inside_run_context + def log_parameters(self, parameters: Dict[str, Union[str, int, float]]): + """Record a collection of parameter values for this run. + + Args: + parameters (dict[str, str or int or float]): The parameters to record. + """ + filtered_parameters = { + key: value + for (key, value) in parameters.items() + if self._is_input_valid("parameter", key, value) + } + self._trial_component.parameters.update(filtered_parameters) + + @validate_invoked_inside_run_context + def log_metric( + self, + name: str, + value: float, + timestamp: Optional[datetime.datetime] = None, + step: Optional[int] = None, + ): + """Record a custom scalar metric value for this run. + + Note: + This method is for manual custom metrics, for automatic metrics see the + ``enable_sagemaker_metrics`` parameter on the ``estimator`` class. + + Args: + name (str): The name of the metric. + value (float): The value of the metric. + timestamp (datetime.datetime): The timestamp of the metric. + If not specified, the current UTC time will be used. + step (int): The integer iteration number of the metric value (default: None). + """ + if self._is_input_valid("metric", name, value): + self._metrics_manager.log_metric( + metric_name=name, value=value, timestamp=timestamp, step=step + ) + + @validate_invoked_inside_run_context + def log_precision_recall( + self, + y_true: Union[list, array], + predicted_probabilities: Union[list, array], + positive_label: Optional[Union[str, int]] = None, + title: Optional[str] = None, + is_output: bool = True, + no_skill: Optional[int] = None, + ): + """Create and log a precision recall graph artifact for Studio UI to render. + + The artifact is stored in S3 and represented as a lineage artifact + with an association with the run. + + You can view the artifact in the UI. + If your job is created by a pipeline execution you can view the artifact + by selecting the corresponding step in the pipelines UI. + See also `SageMaker Pipelines `_ + + This method requires sklearn library. 
+ + Args: + y_true (list or array): True labels. If labels are not binary + then positive_label should be given. + predicted_probabilities (list or array): Estimated/predicted probabilities. + positive_label (str or int): Label of the positive class (default: None). + title (str): Title of the graph (default: None). + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. + no_skill (int): The precision threshold under which the classifier cannot discriminate + between the classes and would predict a random class or a constant class in + all cases (default: None). + """ + + verify_length_of_true_and_predicted( + true_labels=y_true, + predicted_attrs=predicted_probabilities, + predicted_attrs_name="predicted probabilities", + ) + + get_module("sklearn") + from sklearn.metrics import precision_recall_curve, average_precision_score + + kwargs = {} + if positive_label is not None: + kwargs["pos_label"] = positive_label + + precision, recall, _ = precision_recall_curve(y_true, predicted_probabilities, **kwargs) + + kwargs["average"] = "micro" + ap = average_precision_score(y_true, predicted_probabilities, **kwargs) + + data = { + "type": "PrecisionRecallCurve", + "version": 0, + "title": title, + "precision": precision.tolist(), + "recall": recall.tolist(), + "averagePrecisionScore": ap, + "noSkill": no_skill, + } + self._log_graph_artifact( + artifact_name=title, data=data, graph_type="PrecisionRecallCurve", is_output=is_output + ) + + @validate_invoked_inside_run_context + def log_roc_curve( + self, + y_true: Union[list, array], + y_score: Union[list, array], + title: Optional[str] = None, + is_output: bool = True, + ): + """Create and log a receiver operating characteristic (ROC curve) artifact. + + The artifact is stored in S3 and represented as a lineage artifact + with an association with the run. + + You can view the artifact in the UI. + If your job is created by a pipeline execution you can view the artifact + by selecting the corresponding step in the pipelines UI. + See also `SageMaker Pipelines `_ + + This method requires sklearn library. + + Args: + y_true (list or array): True labels. If labels are not binary + then positive_label should be given. + y_score (list or array): Estimated/predicted probabilities. + title (str): Title of the graph (default: None). + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. + """ + verify_length_of_true_and_predicted( + true_labels=y_true, predicted_attrs=y_score, predicted_attrs_name="predicted scores" + ) + + get_module("sklearn") + from sklearn.metrics import roc_curve, auc + + fpr, tpr, _ = roc_curve(y_true, y_score) + + auc = auc(fpr, tpr) + + data = { + "type": "ROCCurve", + "version": 0, + "title": title, + "falsePositiveRate": fpr.tolist(), + "truePositiveRate": tpr.tolist(), + "areaUnderCurve": auc, + } + self._log_graph_artifact( + artifact_name=title, data=data, graph_type="ROCCurve", is_output=is_output + ) + + @validate_invoked_inside_run_context + def log_confusion_matrix( + self, + y_true: Union[list, array], + y_pred: Union[list, array], + title: Optional[str] = None, + is_output: bool = True, + ): + """Create and log a confusion matrix artifact. + + The artifact is stored in S3 and represented as a lineage artifact + with an association with the run. + + You can view the artifact in the UI. 
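For example (toy data and illustrative names; assumes scikit-learn is installed):

.. code:: python

    from sagemaker.experiments.run import Run

    with Run(experiment_name="my-exp", run_name="my-run") as run:
        y_true = [0, 0, 1, 1]
        y_score = [0.1, 0.4, 0.35, 0.8]
        run.log_roc_curve(y_true, y_score, title="validation-roc")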
+ If your job is created by a pipeline execution you can view the + artifact by selecting the corresponding step in the pipelines UI. + See also `SageMaker Pipelines `_ + This method requires sklearn library. + + Args: + y_true (list or array): True labels. If labels are not binary + then positive_label should be given. + y_pred (list or array): Predicted labels. + title (str): Title of the graph (default: None). + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. + """ + verify_length_of_true_and_predicted( + true_labels=y_true, predicted_attrs=y_pred, predicted_attrs_name="predicted labels" + ) + + get_module("sklearn") + from sklearn.metrics import confusion_matrix + + matrix = confusion_matrix(y_true, y_pred) + + data = { + "type": "ConfusionMatrix", + "version": 0, + "title": title, + "confusionMatrix": matrix.tolist(), + } + self._log_graph_artifact( + artifact_name=title, data=data, graph_type="ConfusionMatrix", is_output=is_output + ) + + @validate_invoked_inside_run_context + def log_artifact( + self, name: str, value: str, media_type: Optional[str] = None, is_output: bool = True + ): + """Record a single artifact for this run. + + Overwrites any previous value recorded for the specified name. + + Args: + name (str): The name of the artifact. + value (str): The value. + media_type (str): The MediaType (MIME type) of the value (default: None). + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. + """ + self._verify_trial_component_artifacts_length(is_output=is_output) + if is_output: + self._trial_component.output_artifacts[name] = TrialComponentArtifact( + value, media_type=media_type + ) + else: + self._trial_component.input_artifacts[name] = TrialComponentArtifact( + value, media_type=media_type + ) + + @validate_invoked_inside_run_context + def log_file( + self, + file_path: str, + name: Optional[str] = None, + media_type: Optional[str] = None, + is_output: bool = True, + ): + """Upload a file to s3 and store it as an input/output artifact in this run. + + Args: + file_path (str): The path of the local file to upload. + name (str): The name of the artifact (default: None). + media_type (str): The MediaType (MIME type) of the file. + If not specified, this library will attempt to infer the media type + from the file extension of ``file_path``. + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. 
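A brief sketch of both artifact methods (URIs and paths are illustrative):

.. code:: python

    from sagemaker.experiments.run import Run

    with Run(experiment_name="my-exp", run_name="my-run") as run:
        run.log_artifact(name="model-uri", value="s3://my-bucket/model.tar.gz")
        run.log_file(file_path="reports/metrics.json")  # name and media type are inferred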
+ """ + self._verify_trial_component_artifacts_length(is_output) + media_type = media_type or guess_media_type(file_path) + name = name or resolve_artifact_name(file_path) + s3_uri, _ = self._artifact_uploader.upload_artifact(file_path) + if is_output: + self._trial_component.output_artifacts[name] = TrialComponentArtifact( + value=s3_uri, media_type=media_type + ) + else: + self._trial_component.input_artifacts[name] = TrialComponentArtifact( + value=s3_uri, media_type=media_type + ) + + def close(self): + """Persist any data saved locally.""" + try: + # Update the trial component with additions from the Run object + self._trial_component.save() + # Create Lineage entities for the artifacts + self._lineage_artifact_tracker.save() + finally: + if self._metrics_manager: + self._metrics_manager.close() + + @staticmethod + def _generate_trial_name(base_name) -> str: + """Generate the reserved trial name based on experiment name + + Args: + base_name (str): The ``experiment_name`` of this ``Run`` object. + """ + available_length = MAX_NAME_LEN_IN_BACKEND - len(TRIAL_NAME_TEMPLATE) + return TRIAL_NAME_TEMPLATE.format(base_name[:available_length]) + + @staticmethod + def _is_input_valid(input_type, field_name, field_value) -> bool: + """Check if the input is valid or not + + Args: + input_type (str): The type of the input, one of ``parameter``, ``metric``. + field_name (str): The name of the field to be checked. + field_value (str or int or float): The value of the field to be checked. + """ + if isinstance(field_value, Number) and (isnan(field_value) or isinf(field_value)): + logger.warning( + "Failed to log %s %s. Received invalid value: %s.", + input_type, + field_name, + field_value, + ) + return False + return True + + def _log_graph_artifact(self, data, graph_type, is_output, artifact_name=None): + """Log an artifact. + + Logs an artifact by uploading data to S3, creating an artifact, and associating that + artifact with the run trial component. + + Args: + data (dict): Artifacts data that will be saved to S3. + graph_type (str): The type of the artifact. + is_output (bool): Determines direction of association to the + trial component. Defaults to True (output artifact). + If set to False then represented as input association. + artifact_name (str): Name of the artifact (default: None). + """ + # generate an artifact name + if not artifact_name: + unique_name_from_base(graph_type) + + # create a json file in S3 + s3_uri, etag = self._artifact_uploader.upload_object_artifact( + artifact_name, data, file_extension="json" + ) + + # create an artifact and association for the table + if is_output: + self._lineage_artifact_tracker.add_output_artifact( + name=artifact_name, source_uri=s3_uri, etag=etag, artifact_type=graph_type + ) + else: + self._lineage_artifact_tracker.add_input_artifact( + name=artifact_name, source_uri=s3_uri, etag=etag, artifact_type=graph_type + ) + + def _verify_trial_component_artifacts_length(self, is_output): + """Verify the length of trial component artifacts + + Args: + is_output (bool): Determines direction of association to the + trial component. + + Raises: + ValueError: If the length of trial component artifacts exceeds the limit. 
+ """ + err_msg_template = "Cannot add more than {} {}_artifacts under run" + if is_output: + if len(self._trial_component.output_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN: + raise ValueError(err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "output")) + else: + if len(self._trial_component.input_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN: + raise ValueError(err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "input")) + + @staticmethod + def _generate_trial_component_name(run_name: str, experiment_name: str) -> str: + """Generate the TrialComponentName based on run_name and experiment_name + + Args: + run_name (str): The run_name supplied by the user. + experiment_name (str): The experiment_name supplied by the user, + which is prepended to the run_name to generate the TrialComponentName. + + Returns: + str: The TrialComponentName used to create a trial component + which is unique in an account. + + Raises: + ValueError: If either the run_name or the experiment_name exceeds + the length limit. + """ + buffer = 1 # leave length buffers for delimiters + max_len = int(MAX_NAME_LEN_IN_BACKEND / 2) - buffer + err_msg_template = "The {} (length: {}) must have length less than or equal to {}" + if len(run_name) > max_len: + raise ValueError(err_msg_template.format("run_name", len(run_name), max_len)) + if len(experiment_name) > max_len: + raise ValueError( + err_msg_template.format("experiment_name", len(experiment_name), max_len) + ) + trial_component_name = "{}{}{}".format(experiment_name, DELIMITER, run_name) + # due to mixed-case concerns on the backend + trial_component_name = trial_component_name.lower() + return trial_component_name + + @staticmethod + def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: str) -> str: + """Extract the user supplied run name from a trial component name. + + Args: + trial_component_name (str): The name of a run trial component. + experiment_name (str): The experiment_name supplied by the user, + which was prepended to the run_name to generate the trial_component_name. + + Returns: + str: The name of the Run object supplied by a user. + """ + return trial_component_name.replace("{}{}".format(experiment_name, DELIMITER), "", 1) + + @staticmethod + def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) -> list: + """Append the run trial component label to tags used to create a trial component. + + Args: + tags (List[Dict[str, str]]): The tags supplied by users to initialize a Run object. + + Returns: + list: The updated tags with the appended run trial component label. + """ + if not tags: + tags = [] + tags.append(RUN_TC_TAG) + return tags + + def __enter__(self): + """Updates the start time of the run. + + Returns: + object: self. + """ + nested_with_err_msg_template = ( + "It is not allowed to use nested 'with' statements on the {}." 
+ ) + if self._in_load: + if self._inside_load_context: + raise RuntimeError(nested_with_err_msg_template.format("load_run")) + self._inside_load_context = True + else: + if _RunContext.get_current_run(): + raise RuntimeError(nested_with_err_msg_template.format("Run")) + self._inside_init_context = True + _RunContext.add_run_object(self) + + if not self._trial_component.start_time: + start_time = datetime.datetime.now(dateutil.tz.tzlocal()) + self._trial_component.start_time = start_time + self._trial_component.status = _api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, + message="Within a run context", + ) + # Save the start_time and status changes to backend + self._trial_component.save() + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Updates the end time of the run. + + Args: + exc_type (str): The exception type. + exc_value (str): The exception value. + exc_traceback (str): The stack trace of the exception. + """ + if self._in_load: + self._inside_load_context = False + self._in_load = False + else: + self._inside_init_context = False + _RunContext.drop_current_run() + + end_time = datetime.datetime.now(dateutil.tz.tzlocal()) + self._trial_component.end_time = end_time + if exc_value: + self._trial_component.status = _api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.Failed.value, message=str(exc_value) + ) + else: + self._trial_component.status = _api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.Completed.value + ) + + self.close() + + +def load_run( + run_name: Optional[str] = None, + experiment_name: Optional[str] = None, + sagemaker_session: Optional["Session"] = None, +) -> Run: + """Load an existing run. + + In order to reuse an existing run to log extra data, ``load_run`` is recommended. + It can be used in several ways: + + 1. Use ``load_run`` by explicitly passing in ``run_name`` and ``experiment_name``. + + If ``run_name`` and ``experiment_name`` are passed in, they are honored over + the default experiment config in the job environment or the run context + (i.e. within the ``with`` block). + + Note: + Both ``run_name`` and ``experiment_name`` should be supplied to make this usage work. + Otherwise, you may get a ``ValueError``. + + .. code:: python + + with load_run(experiment_name="my-exp", run_name="my-run") as run: + run.log_metric(...) + ... + + 2. Use the ``load_run`` in a job script without supplying ``run_name`` and ``experiment_name``. + + In this case, the default experiment config (specified when creating the job) is fetched + from the job environment to load the run. + + .. code:: python + + # In a job script + with load_run() as run: + run.log_metric(...) + ... + + 3. Use the ``load_run`` in a notebook within a run context (i.e. the ``with`` block) + but without supplying ``run_name`` and ``experiment_name``. + + Every time we call ``with Run(...) as run1:``, the initialized ``run1`` is tracked + in the run context. Then when we call ``load_run()`` under this with statement, the ``run1`` + in the context is loaded by default. + + .. code:: python + + # In a notebook + with Run(experiment_name="my-exp", run_name="my-run", ...) as run1: + run1.log_parameter(...) + + with load_run() as run2: # run2 is the same object as run1 + run2.log_metric(...) + ... + + Args: + run_name (str): The name of the run to be loaded (default: None). + If it is None, the ``RunName`` in the ``ExperimentConfig`` of the job will be + fetched to load the run. 
+ experiment_name (str): The name of the Experiment that the to be loaded run + is associated with (default: None). + Note: the experiment_name must be supplied along with a valid run_name. + Otherwise, it will be ignored. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + Run: The loaded Run object. + """ + sagemaker_session = sagemaker_session or _utils.default_session() + environment = _RunEnvironment.load() + + verify_load_input_names(run_name=run_name, experiment_name=experiment_name) + + if run_name or environment: + if run_name: + logger.warning( + "run_name is explicitly supplied in load_run, " + "which will be prioritized to load the Run object. " + "In other words, the run name in the experiment config, fetched from the " + "job environment or the current run context, will be ignored." + ) + else: + exp_config = get_tc_and_exp_config_from_job_env( + environment=environment, sagemaker_session=sagemaker_session + ) + run_name = Run._extract_run_name_from_tc_name( + trial_component_name=exp_config[RUN_NAME], + experiment_name=exp_config[EXPERIMENT_NAME], + ) + experiment_name = exp_config[EXPERIMENT_NAME] + + run_instance = Run( + experiment_name=experiment_name, + run_name=run_name, + sagemaker_session=sagemaker_session, + ) + elif _RunContext.get_current_run(): + run_instance = _RunContext.get_current_run() + else: + raise RuntimeError( + "Failed to load a Run object. " + "Please make sure a Run object has been initialized already." + ) + + run_instance._in_load = True + return run_instance + + +def list_runs( + experiment_name: str, + created_before: Optional[datetime.datetime] = None, + created_after: Optional[datetime.datetime] = None, + sagemaker_session: Optional["Session"] = None, + max_results: Optional[int] = None, + next_token: Optional[str] = None, + sort_by: SortByType = SortByType.CREATION_TIME, + sort_order: SortOrderType = SortOrderType.DESCENDING, +) -> list: + """Return a list of ``Run`` objects matching the given criteria. + + Args: + experiment_name (str): Only Run objects related to the specified experiment + are returned. + created_before (datetime.datetime): Return Run objects created before this instant + (default: None). + created_after (datetime.datetime): Return Run objects created after this instant + (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + max_results (int): Maximum number of Run objects to retrieve (default: None). + next_token (str): Token for next page of results (default: None). + sort_by (SortByType): The property to sort results by. One of NAME, CREATION_TIME + (default: CREATION_TIME). + sort_order (SortOrderType): One of ASCENDING, or DESCENDING (default: DESCENDING). + + Returns: + list: A list of ``Run`` objects. 
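A usage sketch (the experiment name is illustrative):

.. code:: python

    from sagemaker.experiments.run import list_runs

    # Newest runs first: the default sort is CreationTime, descending.
    for run in list_runs(experiment_name="my-exp"):
        print(run.run_name, run.experiment_config)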
+ """ + tc_summaries = _TrialComponent.list( + experiment_name=experiment_name, + created_before=created_before, + created_after=created_after, + sort_by=sort_by.value, + sort_order=sort_order.value, + sagemaker_session=sagemaker_session, + max_results=max_results, + next_token=next_token, + ) + run_list = [] + for tc_summary in tc_summaries: + if not is_run_trial_component( + trial_component_name=tc_summary.trial_component_name, + sagemaker_session=sagemaker_session, + ): + continue + run_instance = Run( + experiment_name=experiment_name, + run_name=Run._extract_run_name_from_tc_name( + trial_component_name=tc_summary.trial_component_name, + experiment_name=experiment_name, + ), + sagemaker_session=sagemaker_session, + ) + run_list.append(run_instance) + return run_list diff --git a/src/sagemaker/experiments/trial.py b/src/sagemaker/experiments/trial.py new file mode 100644 index 0000000000..146b24f18b --- /dev/null +++ b/src/sagemaker/experiments/trial.py @@ -0,0 +1,289 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the Trial class.""" +from __future__ import absolute_import + +from sagemaker.apiutils import _base_types +from sagemaker.experiments import _api_types +from sagemaker.experiments.trial_component import _TrialComponent + + +class _Trial(_base_types.Record): + """An execution of a data-science workflow with an experiment. + + Consists of a list of trial component objects, which document individual + activities within the workflow. + + Attributes: + trial_name (str): The name of the trial. + experiment_name (str): The name of the trial's experiment. + display_name (str): The name of the trial that will appear in UI, + such as SageMaker Studio. + tags (List[Dict[str, str]]): A list of tags to associate with the trial. + """ + + trial_name = None + experiment_name = None + display_name = None + tags = None + + _boto_create_method = "create_trial" + _boto_load_method = "describe_trial" + _boto_delete_method = "delete_trial" + _boto_update_method = "update_trial" + + _boto_update_members = ["trial_name", "display_name"] + _boto_delete_members = ["trial_name"] + + @classmethod + def _boto_ignore(cls): + """Response fields to ignore by default.""" + return super(_Trial, cls)._boto_ignore() + ["CreatedBy"] + + def save(self): + """Save the state of this Trial to SageMaker. + + Returns: + dict: Update trial response. + """ + return self._invoke_api(self._boto_update_method, self._boto_update_members) + + def delete(self): + """Delete this Trial from SageMaker. + + Does not delete associated Trial Components. + + Returns: + dict: Delete trial response. + """ + return self._invoke_api(self._boto_delete_method, self._boto_delete_members) + + @classmethod + def load(cls, trial_name, sagemaker_session=None): + """Load an existing trial and return a `_Trial` object. + + Args: + trial_name: (str): Name of the Trial. 
+ sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.trial._Trial: A SageMaker `_Trial` object + """ + return super(_Trial, cls)._construct( + cls._boto_load_method, + trial_name=trial_name, + sagemaker_session=sagemaker_session, + ) + + @classmethod + def create( + cls, experiment_name, trial_name, display_name=None, tags=None, sagemaker_session=None + ): + """Create a new trial and return a `_Trial` object. + + Args: + experiment_name: (str): Name of the experiment to create this trial in. + trial_name: (str): Name of the Trial. + display_name (str): Name of the trial that will appear in UI, + such as SageMaker Studio (default: None). + tags (List[dict]): A list of tags to associate with the trial (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.trial._Trial: A SageMaker `_Trial` object + """ + trial = super(_Trial, cls)._construct( + cls._boto_create_method, + trial_name=trial_name, + experiment_name=experiment_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + return trial + + @classmethod + def list( + cls, + experiment_name=None, + trial_component_name=None, + created_before=None, + created_after=None, + sort_by=None, + sort_order=None, + sagemaker_session=None, + ): + """List all trials matching the specified criteria. + + Args: + experiment_name (str): Name of the experiment. If specified, only trials in + the experiment will be returned (default: None). + trial_component_name (str): Name of the trial component. If specified, only + trials with this trial component name will be returned (default: None). + created_before (datetime.datetime): Return trials created before this instant + (default: None). + created_after (datetime.datetime): Return trials created after this instant + (default: None). + sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime' + (default: None). + sort_order (str): One of 'Ascending', or 'Descending' (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + Returns: + collections.Iterator[experiments._api_types.TrialSummary]: An iterator over trials + matching the specified criteria. + """ + return super(_Trial, cls)._list( + "list_trials", + _api_types.TrialSummary.from_boto, + "TrialSummaries", + experiment_name=experiment_name, + trial_component_name=trial_component_name, + created_before=created_before, + created_after=created_after, + sort_by=sort_by, + sort_order=sort_order, + sagemaker_session=sagemaker_session, + ) + + def add_trial_component(self, trial_component): + """Add the specified trial component to this trial. + + A trial component may belong to many trials and a trial may have many trial components. + + Args: + trial_component (str or _TrialComponent): The trial component to add. + Can be one of a _TrialComponent instance, or a string containing + the name of the trial component to add. 
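+
+        Example:
+            An illustrative call (the component name below is a placeholder)::
+
+                trial.add_trial_component("my-trial-component")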
+        """
+        if isinstance(trial_component, _TrialComponent):
+            trial_component_name = trial_component.trial_component_name
+        elif isinstance(trial_component, str):
+            trial_component_name = trial_component
+        else:
+            raise TypeError(
+                "Unsupported type of trial component {}. "
+                "It must be either a _TrialComponent or a str".format(trial_component)
+            )
+        self.sagemaker_session.sagemaker_client.associate_trial_component(
+            TrialName=self.trial_name, TrialComponentName=trial_component_name
+        )
+
+    def remove_trial_component(self, trial_component):
+        """Remove the specified trial component from this trial.
+
+        Args:
+            trial_component (str or _TrialComponent): The trial component to remove.
+                Can be one of a _TrialComponent instance, or a string containing
+                the name of the trial component to remove.
+        """
+        if isinstance(trial_component, _TrialComponent):
+            trial_component_name = trial_component.trial_component_name
+        elif isinstance(trial_component, str):
+            trial_component_name = trial_component
+        else:
+            raise TypeError(
+                "Unsupported type of trial component {}. "
+                "It must be either a _TrialComponent or a str".format(trial_component)
+            )
+        self.sagemaker_session.sagemaker_client.disassociate_trial_component(
+            TrialName=self.trial_name, TrialComponentName=trial_component_name
+        )
+
+    def list_trial_components(
+        self,
+        created_before=None,
+        created_after=None,
+        sort_by=None,
+        sort_order=None,
+        max_results=None,
+        next_token=None,
+    ):
+        """List trial components in this trial matching the specified criteria.
+
+        Args:
+            created_before (datetime.datetime): Return trial components created before this
+                instant (default: None).
+            created_after (datetime.datetime): Return trial components created after this
+                instant (default: None).
+            sort_by (str): Which property to sort results by. One of 'Name',
+                'CreationTime' (default: None).
+            sort_order (str): One of 'Ascending', or 'Descending' (default: None).
+            max_results (int): maximum number of trial components to retrieve (default: None).
+            next_token (str): token for next page of results (default: None).
+
+        Returns:
+            collections.Iterator[experiments._api_types.TrialComponentSummary]: An iterator over
+                trial components matching the criteria.
+        """
+        return _TrialComponent.list(
+            trial_name=self.trial_name,
+            created_before=created_before,
+            created_after=created_after,
+            sort_by=sort_by,
+            sort_order=sort_order,
+            max_results=max_results,
+            next_token=next_token,
+            sagemaker_session=self.sagemaker_session,
+        )
+
+    @classmethod
+    def _load_or_create(
+        cls, experiment_name, trial_name, display_name=None, tags=None, sagemaker_session=None
+    ):
+        """Load a trial by name and create a new one if it does not exist.
+
+        Args:
+            experiment_name: (str): Name of the experiment to create this trial in.
+            trial_name: (str): Name of the Trial.
+            display_name (str): Name of the trial that will appear in UI,
+                such as SageMaker Studio (default: None). This is used only when the given
+                `trial_name` does not exist and a new trial has to be created.
+            tags (List[dict]): A list of tags to associate with the trial (default: None).
+                This is used only when the given `trial_name` does not exist and
+                a new trial has to be created.
+            sagemaker_session (sagemaker.session.Session): Session object which
+                manages interactions with Amazon SageMaker APIs and any other
+                AWS services needed. If not specified, one is created using the
+                default AWS configuration chain.
+ + Returns: + experiments.trial._Trial: A SageMaker `_Trial` object + """ + sagemaker_client = sagemaker_session.sagemaker_client + try: + trial = _Trial.load(trial_name, sagemaker_session) + if trial.experiment_name != experiment_name: # pylint: disable=no-member + raise ValueError( + "The given experiment_name {} ".format(experiment_name) + + "does not match that in the loaded trial {}".format( + trial.experiment_name # pylint: disable=no-member + ) + ) + except sagemaker_client.exceptions.ResourceNotFound: + trial = _Trial.create( + experiment_name=experiment_name, + trial_name=trial_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + return trial diff --git a/src/sagemaker/experiments/trial_component.py b/src/sagemaker/experiments/trial_component.py new file mode 100644 index 0000000000..e5701b2119 --- /dev/null +++ b/src/sagemaker/experiments/trial_component.py @@ -0,0 +1,341 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the TrialComponent class.""" +from __future__ import absolute_import + +import time + +from sagemaker.apiutils import _base_types +from sagemaker.experiments import _api_types +from sagemaker.experiments._api_types import TrialComponentSearchResult + + +class _TrialComponent(_base_types.Record): + """This class represents a SageMaker trial component object. + + A trial component is a stage in a trial. + Trial components are created automatically within the SageMaker runtime and + may not be created directly. To automatically associate trial components with + a trial and experiment, supply an experiment config when creating a job. + For example: https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html + + Attributes: + trial_component_name (str): The name of the trial component. Generated by SageMaker + from the name of the source job with a suffix specific to the type of source job. + trial_component_arn (str): The ARN of the trial component. + display_name (str): The name of the trial component that will appear in UI, + such as SageMaker Studio. + source (TrialComponentSource): A TrialComponentSource object with a source_arn attribute. + status (str): Status of the source job. + start_time (datetime): When the source job started. + end_time (datetime): When the source job ended. + creation_time (datetime): When the source job was created. + created_by (obj): Contextual info on which account created the trial component. + last_modified_time (datetime): When the trial component was last modified. + last_modified_by (obj): Contextual info on which account last modified the trial component. + parameters (dict): Dictionary of parameters to the source job. + input_artifacts (dict): Dictionary of input artifacts. + output_artifacts (dict): Dictionary of output artifacts. + metrics (obj): Aggregated metrics for the job. + parameters_to_remove (list): The hyperparameters to remove from the component. + input_artifacts_to_remove (list): The input artifacts to remove from the component. 
+        output_artifacts_to_remove (list): The output artifacts to remove from the component.
+        tags (List[Dict[str, str]]): A list of tags to associate with the trial component.
+    """
+
+    trial_component_name = None
+    trial_component_arn = None
+    display_name = None
+    source = None
+    status = None
+    start_time = None
+    end_time = None
+    creation_time = None
+    created_by = None
+    last_modified_time = None
+    last_modified_by = None
+    parameters = None
+    input_artifacts = None
+    output_artifacts = None
+    metrics = None
+    parameters_to_remove = None
+    input_artifacts_to_remove = None
+    output_artifacts_to_remove = None
+    tags = None
+
+    _boto_load_method = "describe_trial_component"
+    _boto_create_method = "create_trial_component"
+    _boto_update_method = "update_trial_component"
+    _boto_delete_method = "delete_trial_component"
+
+    _custom_boto_types = {
+        "source": (_api_types.TrialComponentSource, False),
+        "status": (_api_types.TrialComponentStatus, False),
+        "parameters": (_api_types.TrialComponentParameters, False),
+        "input_artifacts": (_api_types.TrialComponentArtifact, True),
+        "output_artifacts": (_api_types.TrialComponentArtifact, True),
+        "metrics": (_api_types.TrialComponentMetricSummary, True),
+    }
+
+    _boto_update_members = [
+        "trial_component_name",
+        "display_name",
+        "status",
+        "start_time",
+        "end_time",
+        "parameters",
+        "input_artifacts",
+        "output_artifacts",
+        "parameters_to_remove",
+        "input_artifacts_to_remove",
+        "output_artifacts_to_remove",
+    ]
+    _boto_delete_members = ["trial_component_name"]
+
+    def __init__(self, sagemaker_session=None, **kwargs):
+        """Init for _TrialComponent"""
+        super().__init__(sagemaker_session, **kwargs)
+        self.parameters = self.parameters or {}
+        self.input_artifacts = self.input_artifacts or {}
+        self.output_artifacts = self.output_artifacts or {}
+
+    @classmethod
+    def _boto_ignore(cls):
+        """Response fields to ignore by default."""
+        return super(_TrialComponent, cls)._boto_ignore() + ["CreatedBy"]
+
+    def save(self):
+        """Save the state of this TrialComponent to SageMaker."""
+        return self._invoke_api(self._boto_update_method, self._boto_update_members)
+
+    def delete(self, force_disassociate=False):
+        """Delete this TrialComponent from SageMaker.
+
+        Args:
+            force_disassociate (boolean): Indicates whether to force disassociate the
+                trial component from its trials before deletion (default: False).
+                If set to True, the trial component is first disassociated from all
+                associated trials, then deleted.
+                If False, the trial component is deleted directly, without disassociation.
+
+        Returns:
+            dict: Delete trial component response.
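+
+        Example:
+            An illustrative call that disassociates the trial component from its
+            trials before deleting it::
+
+                trial_component.delete(force_disassociate=True)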
+ """ + if force_disassociate: + next_token = None + + while True: + if next_token: + list_trials_response = self.sagemaker_session.sagemaker_client.list_trials( + TrialComponentName=self.trial_component_name, NextToken=next_token + ) + else: + list_trials_response = self.sagemaker_session.sagemaker_client.list_trials( + TrialComponentName=self.trial_component_name + ) + + # Disassociate the trials and trial components + for per_trial in list_trials_response["TrialSummaries"]: + # to prevent DisassociateTrialComponent throttling + time.sleep(1.2) + self.sagemaker_session.sagemaker_client.disassociate_trial_component( + TrialName=per_trial["TrialName"], + TrialComponentName=self.trial_component_name, + ) + + if "NextToken" in list_trials_response: + next_token = list_trials_response["NextToken"] + else: + break + + return self._invoke_api(self._boto_delete_method, self._boto_delete_members) + + @classmethod + def load(cls, trial_component_name, sagemaker_session=None): + """Load an existing trial component and return an `_TrialComponent` object representing it. + + Args: + trial_component_name (str): Name of the trial component + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object + """ + trial_component = cls._construct( + cls._boto_load_method, + trial_component_name=trial_component_name, + sagemaker_session=sagemaker_session, + ) + return trial_component + + @classmethod + def create(cls, trial_component_name, display_name=None, tags=None, sagemaker_session=None): + """Create a trial component and return a `_TrialComponent` object representing it. + + Args: + trial_component_name (str): The name of the trial component. + display_name (str): Display name of the trial component used by Studio (default: None). + tags (List[Dict[str, str]]): Tags to add to the trial component (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object. + """ + return super(_TrialComponent, cls)._construct( + cls._boto_create_method, + trial_component_name=trial_component_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + + @classmethod + def list( + cls, + source_arn=None, + created_before=None, + created_after=None, + sort_by=None, + sort_order=None, + sagemaker_session=None, + trial_name=None, + experiment_name=None, + max_results=None, + next_token=None, + ): + """Return a list of trial component summaries. + + Args: + source_arn (str): A SageMaker Training or Processing Job ARN (default: None). + created_before (datetime.datetime): Return trial components created before this instant + (default: None). + created_after (datetime.datetime): Return trial components created after this instant + (default: None). + sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime' + (default: None). + sort_order (str): One of 'Ascending', or 'Descending' (default: None). 
+            sagemaker_session (sagemaker.session.Session): Session object which
+                manages interactions with Amazon SageMaker APIs and any other
+                AWS services needed. If not specified, one is created using the
+                default AWS configuration chain.
+            trial_name (str): If provided, only trial components related to the trial are
+                returned (default: None).
+            experiment_name (str): If provided, only trial components related to the experiment
+                are returned (default: None).
+            max_results (int): maximum number of trial components to retrieve (default: None).
+            next_token (str): token for next page of results (default: None).
+        Returns:
+            collections.Iterator[experiments._api_types.TrialComponentSummary]: An iterator
+                over `TrialComponentSummary` objects.
+        """
+        return super(_TrialComponent, cls)._list(
+            "list_trial_components",
+            _api_types.TrialComponentSummary.from_boto,
+            "TrialComponentSummaries",
+            source_arn=source_arn,
+            created_before=created_before,
+            created_after=created_after,
+            sort_by=sort_by,
+            sort_order=sort_order,
+            sagemaker_session=sagemaker_session,
+            trial_name=trial_name,
+            experiment_name=experiment_name,
+            max_results=max_results,
+            next_token=next_token,
+        )
+
+    @classmethod
+    def search(
+        cls,
+        search_expression=None,
+        sort_by=None,
+        sort_order=None,
+        max_results=None,
+        sagemaker_session=None,
+    ):
+        """Search experiment trial components.
+
+        Returns SearchResults in the account matching the search criteria.
+
+        Args:
+            search_expression (SearchExpression): A Boolean conditional statement (default: None).
+                Resource objects must satisfy this condition to be included in search results.
+                You must provide at least one subexpression, filter, or nested filter.
+            sort_by (str): The name of the resource property used to sort the SearchResults
+                (default: None).
+            sort_order (str): How SearchResults are ordered. Valid values are Ascending or
+                Descending (default: None).
+            max_results (int): The maximum number of results to return in a SearchResponse
+                (default: None).
+            sagemaker_session (sagemaker.session.Session): Session object which
+                manages interactions with Amazon SageMaker APIs and any other
+                AWS services needed. If not specified, one is created using the
+                default AWS configuration chain.
+
+        Returns:
+            collections.Iterator[SearchResult]: An iterator over search results matching the
+                search criteria.
+        """
+        return super(_TrialComponent, cls)._search(
+            search_resource="ExperimentTrialComponent",
+            search_item_factory=TrialComponentSearchResult.from_boto,
+            search_expression=None if search_expression is None else search_expression.to_boto(),
+            sort_by=sort_by,
+            sort_order=sort_order,
+            max_results=max_results,
+            sagemaker_session=sagemaker_session,
+        )
+
+    @classmethod
+    def _load_or_create(
+        cls, trial_component_name, display_name=None, tags=None, sagemaker_session=None
+    ):
+        """Load a trial component by name and create a new one if it does not exist.
+
+        Args:
+            trial_component_name (str): The name of the trial component.
+            display_name (str): Display name of the trial component used by Studio (default: None).
+                This is used only when the given `trial_component_name` does not
+                exist and a new trial component has to be created.
+            tags (List[Dict[str, str]]): Tags to add to the trial component (default: None).
+                This is used only when the given `trial_component_name` does not
+                exist and a new trial component has to be created.
+            sagemaker_session (sagemaker.session.Session): Session object which
+                manages interactions with Amazon SageMaker APIs and any other
+                AWS services needed. If not specified, one is created using the
+                default AWS configuration chain.
+
+        Returns:
+            experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object.
+            bool: A boolean indicating whether the trial component already exists.
+        """
+        sagemaker_client = sagemaker_session.sagemaker_client
+        is_existed = False
+        try:
+            run_tc = _TrialComponent.load(trial_component_name, sagemaker_session)
+            is_existed = True
+        except sagemaker_client.exceptions.ResourceNotFound:
+            run_tc = _TrialComponent.create(
+                trial_component_name=trial_component_name,
+                display_name=display_name,
+                tags=tags,
+                sagemaker_session=sagemaker_session,
+            )
+        return run_tc, is_existed
diff --git a/src/sagemaker/feature_store/dataset_builder.py b/src/sagemaker/feature_store/dataset_builder.py
new file mode 100644
index 0000000000..fc82997379
--- /dev/null
+++ b/src/sagemaker/feature_store/dataset_builder.py
@@ -0,0 +1,990 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Dataset Builder
+
+A Dataset Builder is a builder class for generating a dataset by providing conditions.
+"""
+from __future__ import absolute_import
+
+import datetime
+from enum import Enum
+import os
+from typing import Any, Dict, List, Tuple, Union
+
+import attr
+import pandas as pd
+
+from sagemaker import Session, s3, utils
+from sagemaker.feature_store.feature_group import FeatureDefinition, FeatureGroup, FeatureTypeEnum
+
+
+_DEFAULT_CATALOG = "AwsDataCatalog"
+_DEFAULT_DATABASE = "sagemaker_featurestore"
+
+
+@attr.s
+class TableType(Enum):
+    """Enum of Table types.
+
+    The data type of a table can be FeatureGroup or DataFrame.
+    """
+
+    FEATURE_GROUP = "FeatureGroup"
+    DATA_FRAME = "DataFrame"
+
+
+@attr.s
+class FeatureGroupToBeMerged:
+    """FeatureGroup metadata which will be used for the SQL join.
+
+    This class instantiates a FeatureGroupToBeMerged object that comprises a list of feature names,
+    a list of feature names to be included in the SQL query, a database, an Athena table name,
+    the record identifier feature name, the event time identifier feature, and the feature name
+    in the base table that is used as the target join key.
+
+    Attributes:
+        features (List[str]): A list of strings representing feature names of this FeatureGroup.
+        included_feature_names (List[str]): A list of strings representing features to be
+            included in the SQL join.
+        projected_feature_names (List[str]): A list of strings representing features to be
+            included for final projection in output.
+        catalog (str): A string representing the catalog.
+        database (str): A string representing the database.
+        table_name (str): A string representing the Athena table name of this FeatureGroup.
+        record_identifier_feature_name (str): A string representing the record identifier feature.
+        event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the
+            event time identifier feature.
+        target_feature_name_in_base (str): A string representing the feature name in the base
+            which will be used as the target join key (default: None).
+        table_type (TableType): A TableType representing the type of table, either FeatureGroup
+            or pandas DataFrame (default: None).
+    """
+
+    features: List[str] = attr.ib()
+    included_feature_names: List[str] = attr.ib()
+    projected_feature_names: List[str] = attr.ib()
+    catalog: str = attr.ib()
+    database: str = attr.ib()
+    table_name: str = attr.ib()
+    record_identifier_feature_name: str = attr.ib()
+    event_time_identifier_feature: FeatureDefinition = attr.ib()
+    target_feature_name_in_base: str = attr.ib(default=None)
+    table_type: TableType = attr.ib(default=None)
+
+
+def construct_feature_group_to_be_merged(
+    feature_group: FeatureGroup,
+    included_feature_names: List[str],
+    target_feature_name_in_base: str = None,
+) -> FeatureGroupToBeMerged:
+    """Construct a FeatureGroupToBeMerged object from the provided parameters.
+
+    Args:
+        feature_group (FeatureGroup): A FeatureGroup object.
+        included_feature_names (List[str]): A list of strings representing features to be
+            included in the output.
+        target_feature_name_in_base (str): A string representing the feature name in the base
+            which will be used as the target join key (default: None).
+    Returns:
+        A FeatureGroupToBeMerged object.
+
+    Raises:
+        ValueError: Invalid feature name(s) in included_feature_names.
+    """
+    feature_group_metadata = feature_group.describe()
+    data_catalog_config = feature_group_metadata.get("OfflineStoreConfig", {}).get(
+        "DataCatalogConfig", None
+    )
+    if not data_catalog_config:
+        raise RuntimeError(f"No metastore is configured with FeatureGroup {feature_group.name}.")
+
+    record_identifier_feature_name = feature_group_metadata.get("RecordIdentifierFeatureName", None)
+    feature_definitions = feature_group_metadata.get("FeatureDefinitions", [])
+    event_time_identifier_feature_name = feature_group_metadata.get("EventTimeFeatureName", None)
+    event_time_identifier_feature_type = FeatureTypeEnum(
+        next(
+            filter(
+                lambda f: f.get("FeatureName", None) == event_time_identifier_feature_name,
+                feature_definitions,
+            ),
+            {},
+        ).get("FeatureType", None)
+    )
+    table_name = data_catalog_config.get("TableName", None)
+    database = data_catalog_config.get("Database", None)
+    disable_glue = feature_group_metadata.get("DisableGlueTableCreation", False)
+    catalog = data_catalog_config.get("Catalog", None) if disable_glue else _DEFAULT_CATALOG
+    features = [feature.get("FeatureName", None) for feature in feature_definitions]
+
+    for included_feature in included_feature_names or []:
+        if included_feature not in features:
+            raise ValueError(
+                f"Feature {included_feature} not found in FeatureGroup {feature_group.name}"
+            )
+    if not included_feature_names:
+        included_feature_names = features
+        projected_feature_names = features.copy()
+    else:
+        projected_feature_names = included_feature_names.copy()
+        if record_identifier_feature_name not in included_feature_names:
+            included_feature_names.append(record_identifier_feature_name)
+        if event_time_identifier_feature_name not in included_feature_names:
+            included_feature_names.append(event_time_identifier_feature_name)
+    return FeatureGroupToBeMerged(
+        features,
+        included_feature_names,
+        projected_feature_names,
+        catalog,
+        database,
+        table_name,
+        record_identifier_feature_name,
+        FeatureDefinition(event_time_identifier_feature_name, event_time_identifier_feature_type),
+        target_feature_name_in_base,
+        TableType.FEATURE_GROUP,
+    )
+
+
+@attr.s
+class DatasetBuilder:
+    """DatasetBuilder definition.
+
+    This class instantiates a DatasetBuilder object that comprises a base, a list of feature names,
+    an output path and a KMS key ID.
+
+    Attributes:
+        _sagemaker_session (Session): Session instance to perform boto calls.
+        _base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a
+            pandas.DataFrame and will be used to merge other FeatureGroups and generate a Dataset.
+        _output_path (str): An S3 URI which stores the output .csv file.
+        _record_identifier_feature_name (str): A string representing the record identifier feature
+            if base is a DataFrame (default: None).
+        _event_time_identifier_feature_name (str): A string representing the event time identifier
+            feature if base is a DataFrame (default: None).
+        _included_feature_names (List[str]): A list of strings representing features to be
+            included in the output (default: None).
+        _kms_key_id (str): A KMS key ID. If set, it will be used to encrypt the result file
+            (default: None).
+        _point_in_time_accurate_join (bool): Whether to use a point-in-time accurate join
+            (default: False).
+        _include_duplicated_records (bool): Whether to include duplicated records in the dataset
+            (default: False).
+        _include_deleted_records (bool): Whether to include deleted records in the dataset
+            (default: False).
+        _number_of_recent_records (int): The number of recent records to return for each
+            record identifier (default: None).
+        _number_of_records (int): The maximum number of records to return (default: None).
+        _write_time_ending_timestamp (datetime.datetime): A datetime; all records in the
+            dataset will have a write time before it (default: None).
+        _event_time_starting_timestamp (datetime.datetime): A datetime; all records in the
+            dataset will have an event time after it (default: None).
+        _event_time_ending_timestamp (datetime.datetime): A datetime; all records in the
+            dataset will have an event time before it (default: None).
+        _feature_groups_to_be_merged (List[FeatureGroupToBeMerged]): A list of
+            FeatureGroupToBeMerged which will be joined to base (default: []).
+        _event_time_identifier_feature_type (FeatureTypeEnum): A FeatureTypeEnum representing the
+            type of the event time identifier feature (default: None).
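+
+    Example:
+        An illustrative sketch of building a dataset from two feature groups, where
+        ``session`` is an existing sagemaker.session.Session and the feature group
+        names and S3 URI are placeholders::
+
+            feature_store = FeatureStore(sagemaker_session=session)
+            builder = feature_store.create_dataset(
+                base=FeatureGroup(name="base-fg", sagemaker_session=session),
+                output_path="s3://my-bucket/dataset",
+            )
+            builder.with_feature_group(FeatureGroup(name="other-fg", sagemaker_session=session))
+            df, query = builder.point_in_time_accurate_join().to_dataframe()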
+    """
+
+    _sagemaker_session: Session = attr.ib()
+    _base: Union[FeatureGroup, pd.DataFrame] = attr.ib()
+    _output_path: str = attr.ib()
+    _record_identifier_feature_name: str = attr.ib(default=None)
+    _event_time_identifier_feature_name: str = attr.ib(default=None)
+    _included_feature_names: List[str] = attr.ib(default=None)
+    _kms_key_id: str = attr.ib(default=None)
+
+    _point_in_time_accurate_join: bool = attr.ib(init=False, default=False)
+    _include_duplicated_records: bool = attr.ib(init=False, default=False)
+    _include_deleted_records: bool = attr.ib(init=False, default=False)
+    _number_of_recent_records: int = attr.ib(init=False, default=None)
+    _number_of_records: int = attr.ib(init=False, default=None)
+    _write_time_ending_timestamp: datetime.datetime = attr.ib(init=False, default=None)
+    _event_time_starting_timestamp: datetime.datetime = attr.ib(init=False, default=None)
+    _event_time_ending_timestamp: datetime.datetime = attr.ib(init=False, default=None)
+    _feature_groups_to_be_merged: List[FeatureGroupToBeMerged] = attr.ib(init=False, factory=list)
+    _event_time_identifier_feature_type: FeatureTypeEnum = attr.ib(default=None)
+
+    _DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP = {
+        "object": "STRING",
+        "int64": "INT",
+        "float64": "DOUBLE",
+        "bool": "BOOLEAN",
+        "datetime64[ns]": "TIMESTAMP",
+    }
+
+    def with_feature_group(
+        self,
+        feature_group: FeatureGroup,
+        target_feature_name_in_base: str = None,
+        included_feature_names: List[str] = None,
+    ):
+        """Join a FeatureGroup with the base.
+
+        Args:
+            feature_group (FeatureGroup): A FeatureGroup which will be joined to the base.
+            target_feature_name_in_base (str): A string representing the feature name in the base
+                which will be used as the target join key (default: None).
+            included_feature_names (List[str]): A list of strings representing features to be
+                included in the output (default: None).
+        Returns:
+            This DatasetBuilder object.
+        """
+        self._feature_groups_to_be_merged.append(
+            construct_feature_group_to_be_merged(
+                feature_group, included_feature_names, target_feature_name_in_base
+            )
+        )
+        return self
+
+    def point_in_time_accurate_join(self):
+        """Set the join type to a point-in-time accurate join.
+
+        Returns:
+            This DatasetBuilder object.
+        """
+        self._point_in_time_accurate_join = True
+        return self
+
+    def include_duplicated_records(self):
+        """Include duplicated records in the dataset.
+
+        Returns:
+            This DatasetBuilder object.
+        """
+        self._include_duplicated_records = True
+        return self
+
+    def include_deleted_records(self):
+        """Include deleted records in the dataset.
+
+        Returns:
+            This DatasetBuilder object.
+        """
+        self._include_deleted_records = True
+        return self
+
+    def with_number_of_recent_records_by_record_identifier(self, number_of_recent_records: int):
+        """Set the number_of_recent_records field with the provided input.
+
+        Args:
+            number_of_recent_records (int): The number of recent records to return for
+                each record identifier.
+        Returns:
+            This DatasetBuilder object.
+        """
+        self._number_of_recent_records = number_of_recent_records
+        return self
+
+    def with_number_of_records_from_query_results(self, number_of_records: int):
+        """Set the number_of_records field with the provided input.
+
+        Args:
+            number_of_records (int): The number of records to return from the query results.
+        Returns:
+            This DatasetBuilder object.
+        """
+        self._number_of_records = number_of_records
+        return self
+
+    def as_of(self, timestamp: datetime.datetime):
+        """Set the write_time_ending_timestamp field with the provided input.
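+
+        Records written to the offline store after this timestamp are excluded from the
+        dataset, so the result reflects the store as of the given point in time. For
+        example, ``builder.as_of(datetime.datetime(2022, 11, 30))`` would exclude any
+        record written after Nov 30, 2022 (an illustrative timestamp).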
+
+        Args:
+            timestamp (datetime.datetime): A datetime; all records in the dataset will
+                have a write time before it.
+        Returns:
+            This DatasetBuilder object.
+        """
+        self._write_time_ending_timestamp = timestamp
+        return self
+
+    def with_event_time_range(
+        self,
+        starting_timestamp: datetime.datetime = None,
+        ending_timestamp: datetime.datetime = None,
+    ):
+        """Set event_time_starting_timestamp and event_time_ending_timestamp with provided inputs.
+
+        Args:
+            starting_timestamp (datetime.datetime): A datetime; all records in the dataset
+                will have an event time after it (default: None).
+            ending_timestamp (datetime.datetime): A datetime; all records in the dataset
+                will have an event time before it (default: None).
+        Returns:
+            This DatasetBuilder object.
+        """
+        self._event_time_starting_timestamp = starting_timestamp
+        self._event_time_ending_timestamp = ending_timestamp
+        return self
+
+    def to_csv_file(self) -> Tuple[str, str]:
+        """Run the dataset query and store the result as a .csv file in S3.
+
+        Returns:
+            The S3 path of the .csv file.
+            The query string executed.
+        """
+        if isinstance(self._base, pd.DataFrame):
+            temp_id = utils.unique_name_from_base("dataframe-base")
+            local_file_name = f"{temp_id}.csv"
+            desired_s3_folder = f"{self._output_path}/{temp_id}"
+            self._base.to_csv(local_file_name, index=False, header=False)
+            s3.S3Uploader.upload(
+                local_path=local_file_name,
+                desired_s3_uri=desired_s3_folder,
+                sagemaker_session=self._sagemaker_session,
+                kms_key=self._kms_key_id,
+            )
+            os.remove(local_file_name)
+            temp_table_name = f'dataframe_{temp_id.replace("-", "_")}'
+            self._create_temp_table(temp_table_name, desired_s3_folder)
+            base_features = list(self._base.columns)
+            event_time_identifier_feature_dtype = self._base[
+                self._event_time_identifier_feature_name
+            ].dtypes
+            self._event_time_identifier_feature_type = (
+                FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
+                    str(event_time_identifier_feature_dtype), None
+                )
+            )
+            query_string = self._construct_query_string(
+                FeatureGroupToBeMerged(
+                    base_features,
+                    self._included_feature_names if self._included_feature_names else base_features,
+                    self._included_feature_names if self._included_feature_names else base_features,
+                    _DEFAULT_CATALOG,
+                    _DEFAULT_DATABASE,
+                    temp_table_name,
+                    self._record_identifier_feature_name,
+                    FeatureDefinition(
+                        self._event_time_identifier_feature_name,
+                        self._event_time_identifier_feature_type,
+                    ),
+                    None,
+                    TableType.DATA_FRAME,
+                )
+            )
+            query_result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE)
+            # TODO: cleanup temp table, need more clarification, keep it for now
+            return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
+                "OutputLocation", None
+            ), query_result.get("QueryExecution", {}).get("Query", None)
+        if isinstance(self._base, FeatureGroup):
+            base_feature_group = construct_feature_group_to_be_merged(
+                self._base, self._included_feature_names
+            )
+            self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name
+            self._event_time_identifier_feature_name = (
+                base_feature_group.event_time_identifier_feature.feature_name
+            )
+            self._event_time_identifier_feature_type = (
+                base_feature_group.event_time_identifier_feature.feature_type
+            )
+            query_string = self._construct_query_string(base_feature_group)
+            query_result = self._run_query(
+                query_string,
+                base_feature_group.catalog,
+                base_feature_group.database,
+            )
+            return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
+                "OutputLocation", None
+            ), query_result.get("QueryExecution", {}).get("Query", None)
+        raise ValueError("Base must be either a FeatureGroup or a DataFrame.")
+
+    def to_dataframe(self) -> Tuple[pd.DataFrame, str]:
+        """Run the dataset query and return the result as a pandas.DataFrame.
+
+        Returns:
+            The pandas.DataFrame object.
+            The query string executed.
+        """
+        csv_file, query_string = self.to_csv_file()
+        s3.S3Downloader.download(
+            s3_uri=csv_file,
+            local_path="./",
+            kms_key=self._kms_key_id,
+            sagemaker_session=self._sagemaker_session,
+        )
+        local_file_name = csv_file.split("/")[-1]
+        df = pd.read_csv(local_file_name)
+        os.remove(local_file_name)
+
+        local_metadata_file_name = local_file_name + ".metadata"
+        if os.path.exists(local_metadata_file_name):
+            os.remove(local_metadata_file_name)
+
+        if "row_recent" in df:
+            df = df.drop("row_recent", axis="columns")
+        return df, query_string
+
+    def _construct_event_time_conditions(
+        self,
+        table_name: str,
+        event_time_identifier_feature: FeatureDefinition,
+    ) -> List[str]:
+        """Internal method for constructing the event time range SQL conditions.
+
+        Args:
+            table_name (str): name of the table.
+            event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the
+                event time identifier feature.
+        Returns:
+            The list of event time condition strings.
+        """
+        event_time_conditions = []
+        timestamp_cast_function_name = "from_unixtime"
+        if event_time_identifier_feature.feature_type == FeatureTypeEnum.STRING:
+            timestamp_cast_function_name = "from_iso8601_timestamp"
+        if self._event_time_starting_timestamp:
+            event_time_conditions.append(
+                f"{timestamp_cast_function_name}({table_name}."
+                + f'"{event_time_identifier_feature.feature_name}") >= '
+                + f"from_unixtime({self._event_time_starting_timestamp.timestamp()})"
+            )
+        if self._event_time_ending_timestamp:
+            event_time_conditions.append(
+                f"{timestamp_cast_function_name}({table_name}."
+                + f'"{event_time_identifier_feature.feature_name}") <= '
+                + f"from_unixtime({self._event_time_ending_timestamp.timestamp()})"
+            )
+        return event_time_conditions
+
+    def _construct_write_time_condition(
+        self,
+        table_name: str,
+    ) -> str:
+        """Internal method for constructing the write time condition.
+
+        Args:
+            table_name (str): name of the table.
+        Returns:
+            The write time condition string.
+        """
+        write_time_condition = (
+            f'{table_name}."write_time" <= '
+            f"to_timestamp('{self._write_time_ending_timestamp.replace(microsecond=0)}', "
+            f"'yyyy-mm-dd hh24:mi:ss')"
+        )
+        return write_time_condition
+
+    def _construct_where_query_string(
+        self,
+        suffix: str,
+        event_time_identifier_feature: FeatureDefinition,
+        where_conditions: List[str],
+    ) -> str:
+        """Internal method for constructing SQL WHERE query string by parameters.
+
+        Args:
+            suffix (str): A temp identifier of the FeatureGroup.
+            event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the
+                event time identifier feature.
+            where_conditions (List[str]): A list of strings representing existing where clauses.
+        Returns:
+            The WHERE query string.
+
+        Raises:
+            ValueError: FeatureGroup not provided while using as_of(). Only found pandas.DataFrame.
+        """
+        if self._number_of_recent_records:
+            if self._number_of_recent_records < 0:
+                raise ValueError(
+                    "Please provide a non-negative integer for number_of_recent_records."
+                )
+        if self._number_of_records:
+            if self._number_of_records < 0:
+                raise ValueError("Please provide a non-negative integer for number_of_records.")
+        if self._include_deleted_records:
+            if isinstance(self._base, pd.DataFrame):
+                if len(self._feature_groups_to_be_merged) == 0:
+                    raise ValueError(
+                        "include_deleted_records() only works for a FeatureGroup base "
+                        "if there is no join operation."
+                    )
+        if self._include_duplicated_records:
+            if isinstance(self._base, pd.DataFrame):
+                if len(self._feature_groups_to_be_merged) == 0:
+                    raise ValueError(
+                        "include_duplicated_records() only works for a FeatureGroup base "
+                        "if there is no join operation."
+                    )
+        if self._point_in_time_accurate_join:
+            if len(self._feature_groups_to_be_merged) == 0:
+                raise ValueError(
+                    "point_in_time_accurate_join() only works when there is "
+                    "more than one feature group to join."
+                )
+        if self._write_time_ending_timestamp:
+            if isinstance(self._base, pd.DataFrame):
+                if len(self._feature_groups_to_be_merged) == 0:
+                    raise ValueError(
+                        "as_of() only works for a FeatureGroup base if there is no join operation."
+                    )
+        if isinstance(self._base, FeatureGroup):
+            if self._write_time_ending_timestamp:
+                where_conditions.append(self._construct_write_time_condition(f"table_{suffix}"))
+
+            event_time_conditions = self._construct_event_time_conditions(
+                f"table_{suffix}", event_time_identifier_feature
+            )
+            where_conditions.extend(event_time_conditions)
+
+        if len(where_conditions) == 0:
+            return ""
+        return "WHERE " + "\nAND ".join(where_conditions)
+
+    def _construct_dedup_query(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str:
+        """Internal method for constructing the SQL query string that removes duplicate records.
+
+        Args:
+            feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the
+                FeatureGroup metadata.
+            suffix (str): A temp identifier of the FeatureGroup.
+        Returns:
+            The SQL query string.
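+
+        Example:
+            A schematic of the generated CTE; the exact clauses depend on the table
+            type and the configured filters, and the angle-bracket names are
+            placeholders::
+
+                table_<suffix> AS (
+                    SELECT *
+                    FROM (
+                        SELECT *, row_number() OVER (
+                            PARTITION BY <record_identifier>, <event_time>
+                            ORDER BY api_invocation_time DESC, write_time DESC
+                        ) AS dedup_row_<suffix>
+                        FROM <database>.<table>
+                    )
+                    WHERE dedup_row_<suffix> = 1
+                )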
+        """
+        record_feature_name = feature_group.record_identifier_feature_name
+        event_time_identifier_feature = feature_group.event_time_identifier_feature
+        event_time_feature_name = feature_group.event_time_identifier_feature.feature_name
+        rank_query_string = ""
+        where_conditions = []
+        where_conditions_str = ""
+        is_dedup_enabled = False
+
+        if feature_group.table_type is TableType.FEATURE_GROUP:
+            is_dedup_enabled = True
+            rank_query_string = (
+                f'ORDER BY origin_{suffix}."api_invocation_time" DESC, '
+                + f'origin_{suffix}."write_time" DESC\n'
+            )
+
+            if self._write_time_ending_timestamp:
+                where_conditions.append(self._construct_write_time_condition(f"origin_{suffix}"))
+
+            event_time_conditions = self._construct_event_time_conditions(
+                f"origin_{suffix}", event_time_identifier_feature
+            )
+            where_conditions.extend(event_time_conditions)
+
+        if len(where_conditions) != 0:
+            where_conditions_str = "WHERE " + "\nAND ".join(where_conditions) + "\n"
+
+        dedup_where_clause = f"WHERE dedup_row_{suffix} = 1\n" if is_dedup_enabled else ""
+        return (
+            f"table_{suffix} AS (\n"
+            + "SELECT *\n"
+            + "FROM (\n"
+            + "SELECT *, row_number() OVER (\n"
+            + f'PARTITION BY origin_{suffix}."{record_feature_name}", '
+            + f'origin_{suffix}."{event_time_feature_name}"\n'
+            + rank_query_string
+            + f") AS dedup_row_{suffix}\n"
+            + f'FROM "{feature_group.database}"."{feature_group.table_name}" origin_{suffix}\n'
+            + where_conditions_str
+            + ")\n"
+            + dedup_where_clause
+            + ")"
+        )
+
+    def _construct_deleted_query(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str:
+        """Internal method for constructing the SQL query string that removes deleted records.
+
+        Args:
+            feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the
+                FeatureGroup metadata.
+            suffix (str): A temp identifier of the FeatureGroup.
+        Returns:
+            The SQL query string.
+        """
+        record_feature_name = feature_group.record_identifier_feature_name
+        event_time_identifier_feature = feature_group.event_time_identifier_feature
+        event_time_feature_name = feature_group.event_time_identifier_feature.feature_name
+        rank_query_string = f'ORDER BY origin_{suffix}."{event_time_feature_name}" DESC'
+        write_time_condition = "\n"
+        event_time_starting_condition = ""
+        event_time_ending_condition = ""
+
+        if feature_group.table_type is TableType.FEATURE_GROUP:
+            rank_query_string += (
+                f', origin_{suffix}."api_invocation_time" DESC, '
+                + f'origin_{suffix}."write_time" DESC\n'
+            )
+
+            if self._write_time_ending_timestamp:
+                write_time_condition += " AND "
+                write_time_condition += self._construct_write_time_condition(f"origin_{suffix}")
+                write_time_condition += "\n"
+
+        if self._event_time_starting_timestamp and self._event_time_ending_timestamp:
+            event_time_conditions = self._construct_event_time_conditions(
+                f"origin_{suffix}", event_time_identifier_feature
+            )
+            event_time_starting_condition = "AND " + event_time_conditions[0] + "\n"
+            event_time_ending_condition = "AND " + event_time_conditions[1] + "\n"
+
+        return (
+            f"deleted_{suffix} AS (\n"
+            + "SELECT *\n"
+            + "FROM (\n"
+            + "SELECT *, row_number() OVER (\n"
+            + f'PARTITION BY origin_{suffix}."{record_feature_name}"\n'
+            + rank_query_string
+            + f") AS deleted_row_{suffix}\n"
+            + f'FROM "{feature_group.database}"."{feature_group.table_name}" origin_{suffix}\n'
+            + "WHERE is_deleted"
+            + write_time_condition
+            + event_time_starting_condition
+            + event_time_ending_condition
+            + ")\n"
+            + f"WHERE deleted_row_{suffix} = 1\n"
+            + ")"
+        )
+
+    def _construct_table_included_features(
+        self, feature_group: FeatureGroupToBeMerged, suffix: str
+    ) -> str:
+        """Internal method for constructing the included-features string of a table.
+
+        Args:
+            feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object
+                which has the metadata.
+            suffix (str): A temp identifier of the table.
+        Returns:
+            The comma-separated string of all features of the table to be included.
+        """
+
+        included_features = ", ".join(
+            [
+                f'table_{suffix}."{include_feature_name}"'
+                for include_feature_name in feature_group.included_feature_names
+            ]
+        )
+        return included_features
+
+    def _construct_table_query(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str:
+        """Internal method for constructing SQL query string by parameters.
+
+        Args:
+            feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the
+                FeatureGroup metadata.
+            suffix (str): A temp identifier of the FeatureGroup.
+        Returns:
+            The query string.
+ """ + included_features = self._construct_table_included_features(feature_group, suffix) + + # If base is a FeatureGroup then included_features_write_time will have a write_time column + # Or included_features_write_time is same as included_features + included_features_write_time = included_features + + if feature_group.table_type is TableType.FEATURE_GROUP: + included_features_write_time += f', table_{suffix}."write_time"' + record_feature_name = feature_group.record_identifier_feature_name + event_time_feature_name = feature_group.event_time_identifier_feature.feature_name + if self._include_duplicated_records and self._include_deleted_records: + return ( + f"SELECT {included_features}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}" table_{suffix}\n' + + self._construct_where_query_string( + suffix, feature_group.event_time_identifier_feature, ["NOT is_deleted"] + ) + ) + if feature_group.table_type is TableType.FEATURE_GROUP and self._include_deleted_records: + rank_query_string = "" + if feature_group.table_type is TableType.FEATURE_GROUP: + rank_query_string = ( + f'ORDER BY origin_{suffix}."api_invocation_time" DESC, ' + + f'origin_{suffix}."write_time" DESC\n' + ) + return ( + f"SELECT {included_features}\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + f'PARTITION BY origin_{suffix}."{record_feature_name}", ' + + f'origin_{suffix}."{event_time_feature_name}"\n' + + rank_query_string + + f") AS row_{suffix}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}" origin_{suffix}\n' + + "WHERE NOT is_deleted" + + f") AS table_{suffix}\n" + + self._construct_where_query_string( + suffix, + feature_group.event_time_identifier_feature, + [f"row_{suffix} = 1"], + ) + ) + rank_query_string = "" + if feature_group.table_type is TableType.FEATURE_GROUP: + rank_query_string = ( + f'OR (table_{suffix}."{event_time_feature_name}" = ' + + f'deleted_{suffix}."{event_time_feature_name}" ' + + f'AND table_{suffix}."api_invocation_time" > ' + + f'deleted_{suffix}."api_invocation_time")\n' + + f'OR (table_{suffix}."{event_time_feature_name}" = ' + + f'deleted_{suffix}."{event_time_feature_name}" ' + + f'AND table_{suffix}."api_invocation_time" = ' + + f'deleted_{suffix}."api_invocation_time" ' + + f'AND table_{suffix}."write_time" > deleted_{suffix}."write_time")\n' + ) + + final_query_string = "" + if feature_group.table_type is TableType.FEATURE_GROUP: + if self._include_duplicated_records: + final_query_string = ( + f"WITH {self._construct_deleted_query(feature_group, suffix)}\n" + + f"SELECT {included_features}\n" + + "FROM (\n" + + f"SELECT {included_features_write_time}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}"' + + f" table_{suffix}\n" + + f"LEFT JOIN deleted_{suffix}\n" + + f'ON table_{suffix}."{record_feature_name}" = ' + + f'deleted_{suffix}."{record_feature_name}"\n' + + f'WHERE deleted_{suffix}."{record_feature_name}" IS NULL\n' + + "UNION ALL\n" + + f"SELECT {included_features_write_time}\n" + + f"FROM deleted_{suffix}\n" + + f'JOIN "{feature_group.database}"."{feature_group.table_name}"' + + f" table_{suffix}\n" + + f'ON table_{suffix}."{record_feature_name}" = ' + + f'deleted_{suffix}."{record_feature_name}"\n' + + "AND (\n" + + f'table_{suffix}."{event_time_feature_name}" > ' + + f'deleted_{suffix}."{event_time_feature_name}"\n' + + rank_query_string + + ")\n" + + f") AS table_{suffix}\n" + + self._construct_where_query_string( + suffix, feature_group.event_time_identifier_feature, [] + ) + ) + else: + 
final_query_string = ( + f"WITH {self._construct_dedup_query(feature_group, suffix)},\n" + + f"{self._construct_deleted_query(feature_group, suffix)}\n" + + f"SELECT {included_features}\n" + + "FROM (\n" + + f"SELECT {included_features_write_time}\n" + + f"FROM table_{suffix}\n" + + f"LEFT JOIN deleted_{suffix}\n" + + f'ON table_{suffix}."{record_feature_name}" = ' + + f'deleted_{suffix}."{record_feature_name}"\n' + + f'WHERE deleted_{suffix}."{record_feature_name}" IS NULL\n' + + "UNION ALL\n" + + f"SELECT {included_features_write_time}\n" + + f"FROM deleted_{suffix}\n" + + f"JOIN table_{suffix}\n" + + f'ON table_{suffix}."{record_feature_name}" = ' + + f'deleted_{suffix}."{record_feature_name}"\n' + + "AND (\n" + + f'table_{suffix}."{event_time_feature_name}" > ' + + f'deleted_{suffix}."{event_time_feature_name}"\n' + + rank_query_string + + ")\n" + + f") AS table_{suffix}\n" + + self._construct_where_query_string( + suffix, feature_group.event_time_identifier_feature, [] + ) + ) + else: + final_query_string = ( + f"WITH {self._construct_dedup_query(feature_group, suffix)}\n" + + f"SELECT {included_features}\n" + + "FROM (\n" + + f"SELECT {included_features_write_time}\n" + + f"FROM table_{suffix}\n" + + f") AS table_{suffix}\n" + + self._construct_where_query_string( + suffix, feature_group.event_time_identifier_feature, [] + ) + ) + return final_query_string + + def _construct_query_string(self, base: FeatureGroupToBeMerged) -> str: + """Internal method for constructing SQL query string by parameters. + + Args: + base (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the metadata. + Returns: + The query string. + + Raises: + ValueError: target_feature_name_in_base is an invalid feature name. + """ + base_table_query_string = self._construct_table_query(base, "base") + query_string = f"WITH fg_base AS ({base_table_query_string})" + if len(self._feature_groups_to_be_merged) > 0: + with_subquery_string = "".join( + [ + f",\nfg_{i} AS ({self._construct_table_query(feature_group, str(i))})" + for i, feature_group in enumerate(self._feature_groups_to_be_merged) + ] + ) + query_string += with_subquery_string + + selected_features = "" + selected_features += ", ".join(map("fg_base.{0}".format, base.projected_feature_names)) + if len(self._feature_groups_to_be_merged) > 0: + for i, feature_group in enumerate(self._feature_groups_to_be_merged): + selected_features += ", " + selected_features += ", ".join( + [ + f'fg_{i}."{feature_name}" as "{feature_name}.{(i+1)}"' + for feature_name in feature_group.projected_feature_names + ] + ) + + selected_features_final = "" + selected_features_final += ", ".join(base.projected_feature_names) + if len(self._feature_groups_to_be_merged) > 0: + for i, feature_group in enumerate(self._feature_groups_to_be_merged): + selected_features_final += ", " + selected_features_final += ", ".join( + [ + '"{0}.{1}"'.format(feature_name, (i + 1)) + for feature_name in feature_group.projected_feature_names + ] + ) + + query_string += ( + f"\nSELECT {selected_features_final}\n" + + "FROM (\n" + + f"SELECT {selected_features}, row_number() OVER (\n" + + f'PARTITION BY fg_base."{base.record_identifier_feature_name}"\n' + + f'ORDER BY fg_base."{base.event_time_identifier_feature.feature_name}" DESC' + ) + + recent_record_where_clause = "" + if self._number_of_recent_records is not None and self._number_of_recent_records >= 0: + recent_record_where_clause = f"WHERE row_recent <= {self._number_of_recent_records}" + + join_subquery_strings = [] + if 
len(self._feature_groups_to_be_merged) > 0:
+            for i, feature_group in enumerate(self._feature_groups_to_be_merged):
+                if not feature_group.target_feature_name_in_base:
+                    feature_group.target_feature_name_in_base = self._record_identifier_feature_name
+                else:
+                    if feature_group.target_feature_name_in_base not in base.features:
+                        raise ValueError(
+                            f"Feature {feature_group.target_feature_name_in_base} not found in base"
+                        )
+                query_string += (
+                    f', fg_{i}."{feature_group.event_time_identifier_feature.feature_name}" DESC'
+                )
+                join_subquery_strings.append(self._construct_join_condition(feature_group, str(i)))
+
+        query_string += (
+            "\n) AS row_recent\n"
+            + "FROM fg_base"
+            + "".join(join_subquery_strings)
+            + "\n)\n"
+            + f"{recent_record_where_clause}"
+        )
+
+        if self._number_of_records is not None and self._number_of_records >= 0:
+            query_string += f"\nLIMIT {self._number_of_records}"
+        return query_string
+
+    def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str:
+        """Internal method for constructing SQL JOIN query string by parameters.
+
+        Args:
+            feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the
+                FeatureGroup metadata.
+            suffix (str): A temp identifier of the FeatureGroup.
+        Returns:
+            The JOIN query string.
+        """
+        join_condition_string = (
+            f"\nJOIN fg_{suffix}\n"
+            + f'ON fg_base."{feature_group.target_feature_name_in_base}" = '
+            + f'fg_{suffix}."{feature_group.record_identifier_feature_name}"'
+        )
+        base_timestamp_cast_function_name = "from_unixtime"
+        if self._event_time_identifier_feature_type == FeatureTypeEnum.STRING:
+            base_timestamp_cast_function_name = "from_iso8601_timestamp"
+        timestamp_cast_function_name = "from_unixtime"
+        if feature_group.event_time_identifier_feature.feature_type == FeatureTypeEnum.STRING:
+            timestamp_cast_function_name = "from_iso8601_timestamp"
+        if self._point_in_time_accurate_join:
+            join_condition_string += (
+                f"\nAND {base_timestamp_cast_function_name}(fg_base."
+                + f'"{self._event_time_identifier_feature_name}") >= '
+                + f"{timestamp_cast_function_name}(fg_{suffix}."
+                + f'"{feature_group.event_time_identifier_feature.feature_name}")'
+            )
+        return join_condition_string
+
+    def _create_temp_table(self, temp_table_name: str, desired_s3_folder: str):
+        """Internal method for creating a temp Athena table for the base pandas.DataFrame.
+
+        Args:
+            temp_table_name (str): The Athena table name of the base pandas.DataFrame.
+            desired_s3_folder (str): The S3 URI of the folder of the data.
+        """
+        columns_string = ", ".join(
+            [self._construct_athena_table_column_string(column) for column in self._base.columns]
+        )
+        serde_properties = '"separatorChar" = ",", "quoteChar" = "`", "escapeChar" = "\\\\"'
+        query_string = (
+            f"CREATE EXTERNAL TABLE {temp_table_name} ({columns_string}) "
+            + "ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' "
+            + f"WITH SERDEPROPERTIES ({serde_properties}) "
+            + f"LOCATION '{desired_s3_folder}';"
+        )
+        self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE)
+
+    def _construct_athena_table_column_string(self, column: str) -> str:
+        """Internal method for constructing an Athena column definition string.
+
+        Args:
+            column (str): The column name from the pandas.DataFrame.
+        Returns:
+            The Athena column definition string.
+
+        Raises:
+            RuntimeError: The pandas.DataFrame column type is not supported yet.
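+
+        Example:
+            An illustrative mapping based on _DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP
+            (the column names are placeholders)::
+
+                "age" (dtype int64)      -> "age INT"
+                "name" (dtype object)    -> "name STRING"
+                "score" (dtype float64)  -> "score DOUBLE"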
+        """
+        dataframe_type = self._base[column].dtypes
+        if str(dataframe_type) not in self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.keys():
+            raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.")
+        return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.get(str(dataframe_type), None)}"
+
+    def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]:
+        """Internal method to execute an Athena query, wait for it to finish, and get the result.
+
+        Args:
+            query_string (str): The SQL query statements to be executed.
+            catalog (str): The name of the data catalog used in the query execution.
+            database (str): The name of the database used in the query execution.
+        Returns:
+            The query result.
+
+        Raises:
+            RuntimeError: Athena query failed.
+        """
+        query = self._sagemaker_session.start_query_execution(
+            catalog=catalog,
+            database=database,
+            query_string=query_string,
+            output_location=self._output_path,
+            kms_key=self._kms_key_id,
+        )
+        query_id = query.get("QueryExecutionId", None)
+        self._sagemaker_session.wait_for_athena_query(query_execution_id=query_id)
+        query_result = self._sagemaker_session.get_query_execution(query_execution_id=query_id)
+        query_state = query_result.get("QueryExecution", {}).get("Status", {}).get("State", None)
+
+        if query_state != "SUCCEEDED":
+            raise RuntimeError(f"Failed to execute query {query_id}.")
+        return query_result
diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py
index d486ab8a01..855e11488f 100644
--- a/src/sagemaker/feature_store/feature_group.py
+++ b/src/sagemaker/feature_store/feature_group.py
@@ -435,13 +435,14 @@ class FeatureGroup:
         "uint64",
     ]
     _FLOAT_TYPES = ["float_", "float16", "float32", "float64"]
-    _DTYPE_TO_FEATURE_DEFINITION_CLS_MAP: Dict[str, FeatureTypeEnum] = {
+    DTYPE_TO_FEATURE_DEFINITION_CLS_MAP: Dict[str, FeatureTypeEnum] = {
         type: FeatureTypeEnum.INTEGRAL for type in _INTEGER_TYPES
     }
-    _DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.update(
+    DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.update(
        {type: FeatureTypeEnum.FRACTIONAL for type in _FLOAT_TYPES}
     )
-    _DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["string"] = FeatureTypeEnum.STRING
+    DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["string"] = FeatureTypeEnum.STRING
+    DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["object"] = FeatureTypeEnum.STRING
 
     _FEATURE_TYPE_TO_DDL_DATA_TYPE_MAP = {
         FeatureTypeEnum.INTEGRAL.value: "INT",
@@ -629,7 +630,7 @@ def load_feature_definitions(
         """
         feature_definitions = []
         for column in data_frame:
-            feature_type = self._DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
+            feature_type = self.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
                 str(data_frame[column].dtype), None
             )
             if feature_type:
@@ -644,6 +645,23 @@ def load_feature_definitions(
         self.feature_definitions = feature_definitions
         return self.feature_definitions
 
+    def get_record(
+        self, record_identifier_value_as_string: str, feature_names: Sequence[str] = None
+    ) -> Sequence[Dict[str, str]]:
+        """Get a single record in a FeatureGroup.
+
+        Args:
+            record_identifier_value_as_string (String):
+                a String representing the value of the record identifier.
+            feature_names (Sequence[String]):
+                a list of Strings representing feature names.
+        """
+        return self.sagemaker_session.get_record(
+            record_identifier_value_as_string=record_identifier_value_as_string,
+            feature_group_name=self.name,
+            feature_names=feature_names,
+        ).get("Record")
+
     def put_record(self, record: Sequence[FeatureValue]):
         """Put a single record in the FeatureGroup.
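A minimal usage sketch of the new ``get_record`` helper added above, which wraps
``sagemaker_session.get_record`` and returns only the ``Record`` field of the response
(the feature group name and record identifier below are placeholders)::

    from sagemaker import Session
    from sagemaker.feature_store.feature_group import FeatureGroup

    fg = FeatureGroup(name="my-feature-group", sagemaker_session=Session())
    record = fg.get_record(record_identifier_value_as_string="customer-123")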
@@ -654,6 +672,25 @@ def put_record(self, record: Sequence[FeatureValue]):
             feature_group_name=self.name, record=[value.to_dict() for value in record]
         )
 
+    def delete_record(
+        self,
+        record_identifier_value_as_string: str,
+        event_time: str,
+    ):
+        """Delete a single record from a FeatureGroup.
+
+        Args:
+            record_identifier_value_as_string (String):
+                a String representing the value of the record identifier.
+            event_time (String):
+                a timestamp-format String indicating when the deletion event occurred.
+        """
+        return self.sagemaker_session.delete_record(
+            feature_group_name=self.name,
+            record_identifier_value_as_string=record_identifier_value_as_string,
+            event_time=event_time,
+        )
+
     def ingest(
         self,
         data_frame: DataFrame,
diff --git a/src/sagemaker/feature_store/feature_store.py b/src/sagemaker/feature_store/feature_store.py
new file mode 100644
index 0000000000..def8b2b2da
--- /dev/null
+++ b/src/sagemaker/feature_store/feature_store.py
@@ -0,0 +1,130 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Feature Store.
+
+Amazon SageMaker Feature Store is a fully managed, purpose-built repository to store, share, and
+manage features for machine learning (ML) models.
+"""
+from __future__ import absolute_import
+
+import datetime
+from typing import Any, Dict, Sequence, Union
+
+import attr
+import pandas as pd
+
+from sagemaker import Session
+from sagemaker.feature_store.dataset_builder import DatasetBuilder
+from sagemaker.feature_store.feature_group import FeatureGroup
+
+
+@attr.s
+class FeatureStore:
+    """FeatureStore definition.
+
+    This class instantiates a FeatureStore object that comprises a SageMaker session instance.
+
+    Attributes:
+        sagemaker_session (Session): session instance to perform boto calls.
+    """
+
+    sagemaker_session: Session = attr.ib(default=Session)
+
+    def create_dataset(
+        self,
+        base: Union[FeatureGroup, pd.DataFrame],
+        output_path: str,
+        record_identifier_feature_name: str = None,
+        event_time_identifier_feature_name: str = None,
+        included_feature_names: Sequence[str] = None,
+        kms_key_id: str = None,
+    ) -> DatasetBuilder:
+        """Create a Dataset Builder for generating a Dataset.
+
+        Args:
+            base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a
+                pandas.DataFrame and will be used to merge other FeatureGroups and generate a
+                Dataset.
+            output_path (str): An S3 URI which stores the output .csv file.
+            record_identifier_feature_name (str): A string representing the record identifier
+                feature if base is a DataFrame (default: None).
+            event_time_identifier_feature_name (str): A string representing the event time
+                identifier feature if base is a DataFrame (default: None).
+            included_feature_names (List[str]): A list of features to be included in the output
+                (default: None).
+            kms_key_id (str): A KMS key ID. If set, it will be used to encrypt the result file
+                (default: None).
+
+        Raises:
+            ValueError: Base is a pandas.DataFrame but no record identifier feature name or event
+                time identifier feature name is provided.
+        """
+        if isinstance(base, pd.DataFrame):
+            if record_identifier_feature_name is None or event_time_identifier_feature_name is None:
+                raise ValueError(
+                    "You must provide a record identifier feature name and an event time "
+                    + "identifier feature name if a DataFrame is specified as the base."
+                )
+        return DatasetBuilder(
+            self.sagemaker_session,
+            base,
+            output_path,
+            record_identifier_feature_name,
+            event_time_identifier_feature_name,
+            included_feature_names,
+            kms_key_id,
+        )
+
+    def list_feature_groups(
+        self,
+        name_contains: str = None,
+        feature_group_status_equals: str = None,
+        offline_store_status_equals: str = None,
+        creation_time_after: datetime.datetime = None,
+        creation_time_before: datetime.datetime = None,
+        sort_order: str = None,
+        sort_by: str = None,
+        max_results: int = None,
+        next_token: str = None,
+    ) -> Dict[str, Any]:
+        """List all FeatureGroups satisfying the given filters.
+
+        Args:
+            name_contains (str): A string that partially matches one or more FeatureGroups' names.
+                Filters FeatureGroups by name.
+            feature_group_status_equals (str): A FeatureGroup status.
+                Filters FeatureGroups by FeatureGroup status.
+            offline_store_status_equals (str): An OfflineStore status.
+                Filters FeatureGroups by OfflineStore status.
+            creation_time_after (datetime.datetime): Use this parameter to search for FeatureGroups
+                created after a specific date and time.
+            creation_time_before (datetime.datetime): Use this parameter to search for
+                FeatureGroups created before a specific date and time.
+            sort_order (str): The order in which FeatureGroups are listed.
+            sort_by (str): The value on which the FeatureGroup list is sorted.
+            max_results (int): The maximum number of results returned by ListFeatureGroups.
+            next_token (str): A token to resume pagination of ListFeatureGroups results.
+        Returns:
+            Response dict from service.
+        """
+        return self.sagemaker_session.list_feature_groups(
+            name_contains=name_contains,
+            feature_group_status_equals=feature_group_status_equals,
+            offline_store_status_equals=offline_store_status_equals,
+            creation_time_after=creation_time_after,
+            creation_time_before=creation_time_before,
+            sort_order=sort_order,
+            sort_by=sort_by,
+            max_results=max_results,
+            next_token=next_token,
+        )
diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py
index d82d3596ac..a91aff1761 100644
--- a/src/sagemaker/fw_utils.py
+++ b/src/sagemaker/fw_utils.py
@@ -80,6 +80,7 @@
     "ml.p3.16xlarge",
     "ml.p3dn.24xlarge",
     "ml.p4d.24xlarge",
+    "ml.p4de.24xlarge",
     "local_gpu",
 )
 SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = {
@@ -103,6 +104,7 @@
         "2.8.0",
         "2.9",
         "2.9.1",
+        "2.9.2",
         "2.10",
         "2.10.0",
     ],
@@ -493,7 +495,7 @@ def framework_name_from_image(image_uri):
     # We must support both the legacy and current image name format.
     name_pattern = re.compile(
         r"""^(?:sagemaker(?:-rl)?-)?
-        (tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost
+        (tensorflow|mxnet|chainer|pytorch|pytorch-trcomp|scikit-learn|xgboost
         |huggingface-tensorflow|huggingface-pytorch
         |huggingface-tensorflow-trcomp|huggingface-pytorch-trcomp)(?:-)?
         (scriptmode|training)?
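The framework_name_from_image change above adds pytorch-trcomp to the repository-name alternation so Training Compiler image URIs parse to the right framework. A simplified, self-contained sketch of the matching behavior (the full pattern in fw_utils.py also captures the Python version and tag; the image names below are illustrative):

    import re

    # Simplified from the updated name_pattern: "pytorch-trcomp" now appears in
    # the alternation, and the ":" before the tag forces the regex to backtrack
    # until the whole repository name is consumed, so trcomp images match the
    # new alternative instead of stopping at plain "pytorch".
    name_pattern = re.compile(
        r"^(?:sagemaker(?:-rl)?-)?"
        r"(tensorflow|mxnet|chainer|pytorch|pytorch-trcomp|scikit-learn|xgboost)"
        r"(?:-)?(scriptmode|training)?"
        r":(.*)$"
    )

    for image in ("pytorch-training:1.12.0-gpu-py38", "pytorch-trcomp-training:1.12.0-gpu-py38"):
        print(image, "->", name_pattern.match(image).group(1))
    # pytorch-training:1.12.0-gpu-py38 -> pytorch
    # pytorch-trcomp-training:1.12.0-gpu-py38 -> pytorch-trcomp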
diff --git a/src/sagemaker/git_utils.py b/src/sagemaker/git_utils.py index 80bd62d5be..c424753286 100644 --- a/src/sagemaker/git_utils.py +++ b/src/sagemaker/git_utils.py @@ -279,9 +279,8 @@ def _run_clone_command(repo_url, dest_dir): subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env) elif repo_url.startswith("git@"): with tempfile.NamedTemporaryFile() as sshnoprompt: - write_pipe = open(sshnoprompt.name, "w") - write_pipe.write("ssh -oBatchMode=yes $@") - write_pipe.close() + with open(sshnoprompt.name, "w") as write_pipe: + write_pipe.write("ssh -oBatchMode=yes $@") os.chmod(sshnoprompt.name, 0o511) my_env["GIT_SSH"] = sshnoprompt.name subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env) diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json index 3cc488c55d..590b6e5f82 100644 --- a/src/sagemaker/image_uri_config/autogluon.json +++ b/src/sagemaker/image_uri_config/autogluon.json @@ -26,9 +26,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -56,9 +58,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -86,9 +90,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -116,9 +122,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -146,9 +154,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -176,9 +186,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -204,6 +216,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -211,15 +224,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", 
"eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -237,6 +254,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -244,15 +262,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -270,6 +292,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -277,15 +300,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -303,6 +330,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -310,15 +338,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -336,6 +368,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -343,15 +376,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", 
"eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -369,6 +406,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -376,15 +414,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/blazingtext.json b/src/sagemaker/image_uri_config/blazingtext.json index c588d65c73..ae4295c59a 100644 --- a/src/sagemaker/image_uri_config/blazingtext.json +++ b/src/sagemaker/image_uri_config/blazingtext.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/factorization-machines.json b/src/sagemaker/image_uri_config/factorization-machines.json index 0f9930357f..8fb1895707 100644 --- a/src/sagemaker/image_uri_config/factorization-machines.json +++ b/src/sagemaker/image_uri_config/factorization-machines.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/forecasting-deepar.json b/src/sagemaker/image_uri_config/forecasting-deepar.json index 1acc96ed3e..e9beb7acb6 100644 --- a/src/sagemaker/image_uri_config/forecasting-deepar.json +++ b/src/sagemaker/image_uri_config/forecasting-deepar.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "522234722520", "us-east-2": "566113047672", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "156387875391" diff --git a/src/sagemaker/image_uri_config/huggingface-neuron.json b/src/sagemaker/image_uri_config/huggingface-neuron.json index 1e2246cb11..980dceed17 100644 --- a/src/sagemaker/image_uri_config/huggingface-neuron.json +++ b/src/sagemaker/image_uri_config/huggingface-neuron.json @@ -15,21 +15,25 @@ 
"ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface-training-compiler.json b/src/sagemaker/image_uri_config/huggingface-training-compiler.json index e771e2a548..482264b773 100644 --- a/src/sagemaker/image_uri_config/huggingface-training-compiler.json +++ b/src/sagemaker/image_uri_config/huggingface-training-compiler.json @@ -60,6 +60,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -89,6 +90,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -123,6 +125,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface.json b/src/sagemaker/image_uri_config/huggingface.json index 317c17030a..a0caa59a55 100644 --- a/src/sagemaker/image_uri_config/huggingface.json +++ b/src/sagemaker/image_uri_config/huggingface.json @@ -38,9 +38,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -70,9 +72,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -108,9 +112,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -140,9 +146,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -180,9 +188,11 @@ "eu-west-3": "763104351884", "eu-south-1": 
"692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -213,9 +223,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -246,9 +258,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -279,9 +293,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -320,9 +336,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -353,9 +371,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -386,9 +406,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -419,9 +441,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -458,9 +482,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -491,9 +517,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -530,9 
+558,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -563,9 +593,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -602,9 +634,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -635,9 +669,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -674,6 +710,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -681,15 +718,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -707,6 +748,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -714,15 +756,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -740,6 +786,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -747,15 +794,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": 
"380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -781,6 +832,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -788,15 +840,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -814,6 +870,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -821,15 +878,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -847,6 +908,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -854,15 +916,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -880,6 +946,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -887,15 +954,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": 
"503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -919,6 +990,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -926,15 +998,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -952,6 +1028,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -959,15 +1036,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -991,6 +1072,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -998,15 +1080,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1024,6 +1110,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1031,15 +1118,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": 
"763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1063,6 +1154,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1070,15 +1162,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1096,6 +1192,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1103,15 +1200,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/image-classification.json b/src/sagemaker/image_uri_config/image-classification.json index 44ccb3f08d..61e037da08 100644 --- a/src/sagemaker/image_uri_config/image-classification.json +++ b/src/sagemaker/image_uri_config/image-classification.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/ipinsights.json b/src/sagemaker/image_uri_config/ipinsights.json index 4e56c149dc..cf3c70194f 100644 --- a/src/sagemaker/image_uri_config/ipinsights.json +++ b/src/sagemaker/image_uri_config/ipinsights.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/kmeans.json b/src/sagemaker/image_uri_config/kmeans.json index 952724ce11..e8e947f094 100644 --- a/src/sagemaker/image_uri_config/kmeans.json +++ 
b/src/sagemaker/image_uri_config/kmeans.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/knn.json b/src/sagemaker/image_uri_config/knn.json index 79b239966d..89e8ef6224 100644 --- a/src/sagemaker/image_uri_config/knn.json +++ b/src/sagemaker/image_uri_config/knn.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/linear-learner.json b/src/sagemaker/image_uri_config/linear-learner.json index bb027284ab..606edd3791 100644 --- a/src/sagemaker/image_uri_config/linear-learner.json +++ b/src/sagemaker/image_uri_config/linear-learner.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/mxnet.json b/src/sagemaker/image_uri_config/mxnet.json index 12bc40fccf..588a03a76e 100644 --- a/src/sagemaker/image_uri_config/mxnet.json +++ b/src/sagemaker/image_uri_config/mxnet.json @@ -245,9 +245,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -277,9 +279,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -309,9 +313,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -341,9 +347,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -373,9 +381,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", 
"us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -619,6 +629,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -626,15 +637,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -651,6 +666,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -658,15 +674,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -683,6 +703,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -690,15 +711,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -715,6 +740,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -722,15 +748,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", 
"us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -747,6 +777,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -754,15 +785,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -852,6 +887,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -859,15 +895,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -884,6 +924,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -891,15 +932,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -916,6 +961,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -923,15 +969,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", 
diff --git a/src/sagemaker/image_uri_config/neo-pytorch.json b/src/sagemaker/image_uri_config/neo-pytorch.json index bd15a6450e..c46dd3de5d 100644 --- a/src/sagemaker/image_uri_config/neo-pytorch.json +++ b/src/sagemaker/image_uri_config/neo-pytorch.json @@ -11,7 +11,9 @@ "1.7.0": "1.7", "1.7.1": "1.7", "1.8.0": "1.8", - "1.8.1": "1.8" + "1.8.1": "1.8", + "1.12.0": "1.12", + "1.12.1": "1.12" }, "versions": { "1.4": { @@ -173,6 +175,38 @@ "us-west-2": "301217895009" }, "repository": "sagemaker-inference-pytorch" + }, + "1.12": { + "py_versions": ["py3"], + "registries": { + "af-south-1": "774647643957", + "ap-east-1": "110948597952", + "ap-northeast-1": "941853720454", + "ap-northeast-2": "151534178276", + "ap-northeast-3": "925152966179", + "ap-south-1": "763008648453", + "ap-southeast-1": "324986816169", + "ap-southeast-2": "355873309152", + "ca-central-1": "464438896020", + "cn-north-1": "472730292857", + "cn-northwest-1": "474822919863", + "eu-central-1": "746233611703", + "eu-north-1": "601324751636", + "eu-south-1": "966458181534", + "eu-west-1": "802834080501", + "eu-west-2": "205493899709", + "eu-west-3": "254080097072", + "me-south-1": "836785723513", + "sa-east-1": "756306329178", + "us-east-1": "785573368785", + "us-east-2": "007439368137", + "us-gov-west-1": "263933020539", + "us-iso-east-1": "167761179201", + "us-isob-east-1": "406031935815", + "us-west-1": "710691900526", + "us-west-2": "301217895009" + }, + "repository": "sagemaker-inference-pytorch" } } } diff --git a/src/sagemaker/image_uri_config/ntm.json b/src/sagemaker/image_uri_config/ntm.json index 115264b346..16f9565405 100644 --- a/src/sagemaker/image_uri_config/ntm.json +++ b/src/sagemaker/image_uri_config/ntm.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/object-detection.json b/src/sagemaker/image_uri_config/object-detection.json index 6a7ba03695..67b60fe587 100644 --- a/src/sagemaker/image_uri_config/object-detection.json +++ b/src/sagemaker/image_uri_config/object-detection.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/object2vec.json b/src/sagemaker/image_uri_config/object2vec.json index 39614d1273..b166cc96ff 100644 --- a/src/sagemaker/image_uri_config/object2vec.json +++ b/src/sagemaker/image_uri_config/object2vec.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/pca.json b/src/sagemaker/image_uri_config/pca.json index 5f87d8528c..11982e2197 100644 --- 
a/src/sagemaker/image_uri_config/pca.json +++ b/src/sagemaker/image_uri_config/pca.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/pytorch-neuron.json b/src/sagemaker/image_uri_config/pytorch-neuron.json index b116a8a36b..5b29406955 100644 --- a/src/sagemaker/image_uri_config/pytorch-neuron.json +++ b/src/sagemaker/image_uri_config/pytorch-neuron.json @@ -28,6 +28,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/pytorch-training-compiler.json b/src/sagemaker/image_uri_config/pytorch-training-compiler.json new file mode 100644 index 0000000000..fd7df875a3 --- /dev/null +++ b/src/sagemaker/image_uri_config/pytorch-training-compiler.json @@ -0,0 +1,41 @@ +{ + "training": { + "processors": [ + "gpu" + ], + "version_aliases": { + "1.12": "1.12.0" + }, + "versions": { + "1.12.0": { + "py_versions": [ + "py38" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-trcomp-training" + } + } + } +} diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index 3bf8016ba8..85681a3423 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -17,6 +17,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -25,7 +26,9 @@ "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-north-1": "763104351884", + "eu-central-2": "380420809688", "eu-west-1": "763104351884", + "eu-south-2": "503227376785", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-west-2": "763104351884" @@ -39,8 +42,11 @@ "registries": { "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-3": "907027046896", + "eu-central-2": "380420809688", "eu-west-1": "763104351884", + "eu-south-2": "503227376785", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-west-2": "763104351884" @@ -182,6 +188,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", 
"ap-southeast-3": "907027046896", @@ -189,15 +196,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -217,6 +228,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -224,15 +236,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -251,6 +267,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -258,15 +275,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -285,6 +306,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -292,15 +314,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -320,6 +346,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -327,15 +354,19 @@ "cn-north-1": "727897471807", 
"cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -355,6 +386,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -362,15 +394,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -390,6 +426,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -397,15 +434,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -425,6 +466,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -432,15 +474,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -459,6 +505,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -466,15 +513,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": 
"380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -493,6 +544,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -500,15 +552,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -527,6 +583,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -534,15 +591,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -561,6 +622,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -568,15 +630,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -595,6 +661,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -602,15 +669,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": 
"763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -629,6 +700,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -636,15 +708,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -663,6 +739,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -670,15 +747,18 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -707,6 +787,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -714,15 +795,18 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -879,9 +963,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -914,9 +1000,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", 
"us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -949,9 +1037,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -983,9 +1073,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1018,9 +1110,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1053,9 +1147,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1088,9 +1184,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1123,9 +1221,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1157,9 +1257,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1191,9 +1293,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1225,9 +1329,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1259,9 +1365,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": 
"763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1293,9 +1401,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1327,9 +1437,11 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1364,6 +1476,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/randomcutforest.json b/src/sagemaker/image_uri_config/randomcutforest.json index ae7a3574be..15dc84dfc5 100644 --- a/src/sagemaker/image_uri_config/randomcutforest.json +++ b/src/sagemaker/image_uri_config/randomcutforest.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/semantic-segmentation.json b/src/sagemaker/image_uri_config/semantic-segmentation.json index 866dd606b4..f49bc43109 100644 --- a/src/sagemaker/image_uri_config/semantic-segmentation.json +++ b/src/sagemaker/image_uri_config/semantic-segmentation.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/seq2seq.json b/src/sagemaker/image_uri_config/seq2seq.json index bb3daf93b6..87810ad09d 100644 --- a/src/sagemaker/image_uri_config/seq2seq.json +++ b/src/sagemaker/image_uri_config/seq2seq.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/sklearn.json b/src/sagemaker/image_uri_config/sklearn.json index 7961fde282..4d093f5f62 100644 --- a/src/sagemaker/image_uri_config/sklearn.json +++ b/src/sagemaker/image_uri_config/sklearn.json @@ -24,10 +24,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", 
"us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -57,10 +59,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -90,10 +94,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -127,10 +133,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -160,10 +168,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -193,10 +203,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -230,10 +242,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 6a01c3e3e6..6bb36057fa 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -141,6 +141,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -148,15 +149,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "eu-south-2": "503227376785", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", 
"us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -172,6 +177,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -179,15 +185,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -203,6 +213,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -210,15 +221,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -234,6 +249,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -241,15 +257,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -285,7 +305,10 @@ "2.5": "2.5.1", "2.6": "2.6.3", "2.7": "2.7.0", - "2.8": "2.8.0" + "2.8": "2.8.0", + "2.9": "2.9.2", + "2.10": "2.10.0", + "2.11": "2.11.0" }, "versions": { "1.10.0": { @@ -386,6 +409,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -393,15 +417,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", 
"us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -417,6 +445,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -424,15 +453,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -448,6 +481,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -455,15 +489,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -479,6 +517,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -486,15 +525,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -510,6 +553,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -517,15 +561,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", 
"us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -541,6 +589,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -548,15 +597,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -572,6 +625,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -579,15 +633,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -795,6 +853,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -802,15 +861,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -826,6 +889,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -833,15 +897,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ 
-857,6 +925,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -864,15 +933,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -888,6 +961,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -895,15 +969,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -919,6 +997,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -926,15 +1005,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -950,6 +1033,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -957,15 +1041,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -981,6 +1069,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", 
"ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -988,15 +1077,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1012,6 +1105,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1019,15 +1113,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1043,6 +1141,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1050,15 +1149,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1074,6 +1177,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1081,15 +1185,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1105,6 +1213,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": 
"763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1112,15 +1221,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1136,6 +1249,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1143,15 +1257,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1167,6 +1285,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1174,15 +1293,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1198,6 +1321,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1205,15 +1329,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1229,6 +1357,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": 
"907027046896", @@ -1236,15 +1365,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1260,6 +1393,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1267,15 +1401,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1291,6 +1429,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1298,15 +1437,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1322,6 +1465,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1329,15 +1473,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1353,6 +1501,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1360,15 +1509,19 @@ "cn-north-1": "727897471807", 
"cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1384,6 +1537,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1391,15 +1545,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1415,6 +1573,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1422,15 +1581,19 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1446,6 +1609,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1453,15 +1617,124 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" + }, + "2.9.2": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + 
"ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" + }, + "2.10.0": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" + }, + "2.11.0": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1490,6 +1763,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1497,15 +1771,18 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": 
"446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1543,7 +1820,7 @@ "2.6": "2.6.3", "2.7": "2.7.1", "2.8": "2.8.0", - "2.9": "2.9.1", + "2.9": "2.9.2", "2.10": "2.10.0" }, "versions": { @@ -1696,9 +1973,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1732,9 +2011,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1767,9 +2048,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1803,9 +2086,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1839,9 +2124,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1875,9 +2162,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1911,9 +2200,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2138,9 +2429,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2173,9 +2466,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2208,9 +2503,11 @@ "eu-west-2": 
"763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2242,9 +2539,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2276,9 +2575,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2311,9 +2612,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2346,9 +2649,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2380,9 +2685,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2414,9 +2721,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2448,9 +2757,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2482,9 +2793,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2516,9 +2829,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", 
"us-west-1": "763104351884", @@ -2550,9 +2865,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2584,9 +2901,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2618,9 +2937,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2652,9 +2973,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2686,9 +3009,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2720,9 +3045,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2754,9 +3081,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2788,9 +3117,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2822,9 +3153,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2856,9 +3189,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", 
"us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2890,9 +3225,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2924,9 +3261,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2934,7 +3273,7 @@ }, "repository": "tensorflow-training" }, - "2.9.1": { + "2.9.2": { "py_versions": [ "py39" ], @@ -2958,9 +3297,11 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2995,6 +3336,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/xgboost.json b/src/sagemaker/image_uri_config/xgboost.json index a809083c4a..946e78ecc4 100644 --- a/src/sagemaker/image_uri_config/xgboost.json +++ b/src/sagemaker/image_uri_config/xgboost.json @@ -25,10 +25,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" @@ -58,10 +60,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -91,10 +95,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -124,10 +130,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -155,10 +163,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": 
"257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -186,10 +196,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -217,10 +229,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -248,10 +262,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -286,10 +302,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" @@ -319,10 +337,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -352,10 +372,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -385,10 +407,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -416,10 +440,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -447,10 +473,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": 
"414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -478,10 +506,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -509,10 +539,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -544,10 +576,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -575,10 +609,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 7d1d3bd835..c42ce02188 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -146,7 +146,7 @@ def retrieve( tolerate_deprecated_model, ) - if training_compiler_config and (framework == HUGGING_FACE_FRAMEWORK): + if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]): final_image_scope = image_scope config = _config_for_framework_and_scope( framework + "-training-compiler", final_image_scope, accelerator_type diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 202edff9ad..db607770a7 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -20,7 +20,7 @@ import boto3 import botocore from packaging.version import Version -from packaging.specifiers import SpecifierSet +from packaging.specifiers import SpecifierSet, InvalidSpecifier from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, @@ -371,7 +371,10 @@ def _select_version( return None return str(max(available_versions)) - spec = SpecifierSet(f"=={semantic_version_str}") + try: + spec = SpecifierSet(f"=={semantic_version_str}") + except InvalidSpecifier: + raise KeyError(f"Bad semantic version: {semantic_version_str}") available_versions_filtered = list(spec.filter(available_versions)) return ( str(max(available_versions_filtered)) if available_versions_filtered != [] else None diff --git a/src/sagemaker/lineage/_utils.py b/src/sagemaker/lineage/_utils.py index 28732b0174..7c833a468e 100644 --- a/src/sagemaker/lineage/_utils.py +++ b/src/sagemaker/lineage/_utils.py @@ -12,7 +12,6 @@ # language governing permissions and limitations under the License. 
"""SageMaker lineage utility methods.""" from __future__ import absolute_import -from importlib import import_module from sagemaker.lineage import association @@ -38,22 +37,6 @@ def _disassociate(source_arn=None, destination_arn=None, sagemaker_session=None) curr_association.delete() -def get_module(module_name): - """Import a module. - - Args: - module_name (str): name of the module to import. - - Returns: - [obj]: The imported module. - Raises exceptions when the module name is not found - """ - try: - return import_module(module_name) - except ImportError: - raise Exception("Cannot import module {}, please try again.".format(module_name)) - - def get_resource_name_from_arn(arn): """Extract the resource name from an ARN string. diff --git a/src/sagemaker/lineage/artifact.py b/src/sagemaker/lineage/artifact.py index 3921562beb..718344095a 100644 --- a/src/sagemaker/lineage/artifact.py +++ b/src/sagemaker/lineage/artifact.py @@ -29,8 +29,9 @@ LineageEntityEnum, LineageQueryDirectionEnum, ) -from sagemaker.lineage._utils import get_module, _disassociate, get_resource_name_from_arn +from sagemaker.lineage._utils import _disassociate, get_resource_name_from_arn from sagemaker.lineage.association import Association +from sagemaker.utils import get_module LOGGER = logging.getLogger("sagemaker") diff --git a/src/sagemaker/model_monitor/clarify_model_monitoring.py b/src/sagemaker/model_monitor/clarify_model_monitoring.py index 1a788a0d53..030de7c6db 100644 --- a/src/sagemaker/model_monitor/clarify_model_monitoring.py +++ b/src/sagemaker/model_monitor/clarify_model_monitoring.py @@ -842,8 +842,8 @@ def __init__(self, bias_config, headers=None, label=None): bias_config (sagemaker.clarify.BiasConfig): Config object related to bias configurations. headers (list[str]): A list of column names in the input dataset. - label (str): Target attribute for the model required by bias metrics. - Specified as column name or index for CSV dataset, or as JSONPath for JSONLines. + label (str): Target attribute for the model required by bias metrics. Specified as + column name or index for CSV dataset, or as JMESPath expression for JSONLines. """ self.analysis_config = bias_config.get_config() if headers is not None: @@ -889,9 +889,10 @@ def suggest_baseline( model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its endpoint to be created. model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`): - Index or JSONPath to locate the predicted scores in the model output. This is not - required if the model output is a single score. Alternatively, it can be an instance - of ModelPredictedLabelConfig to provide more parameters like label_headers. + Index or JMESPath expression to locate the predicted scores in the model output. + This is not required if the model output is a single score. Alternatively, + it can be an instance of ModelPredictedLabelConfig to provide more parameters + like label_headers. wait (bool): Whether the call should wait until the job completes (default: False). logs (bool): Whether to show the logs produced by the job. Only meaningful when wait is True (default: False). @@ -1302,12 +1303,12 @@ def __init__( Args: analysis_config (BiasAnalysisConfig or ExplainabilityAnalysisConfig): analysis config from configurations of the baselining job. - features_attribute (str): JSONpath to locate features in predictor request payload. - Only required when predictor content type is JSONlines. 
- inference_attribute (str): Index, header or JSONpath to locate predicted label in - predictor response payload. - probability_attribute (str): Index or JSONpath location in the model output for - probabilities or scores to be used for explainability. + features_attribute (str): JMESPath expression to locate features in predictor request + payload. Only required when predictor content type is JSONlines. + inference_attribute (str): Index, header or JMESPath expression to locate predicted + label in predictor response payload. + probability_attribute (str): Index or JMESPath expression to locate probabilities or + scores in the model output for computing feature attribution. probability_threshold_attribute (float): Value to indicate the threshold to select the binary label in the case of binary classification. Default is 0.5. """ diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index 817d951255..2f8266a43a 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -1061,12 +1061,13 @@ def _generate_env_map( dataset_format (dict): The format of the baseline_dataset. dataset_source_container_path (str): The path to the dataset source. inference_attribute (str): Index or JSONpath to locate predicted label(s). - Only used for ModelQualityMonitor, ModelBiasMonitor, and ModelExplainabilityMonitor + Only used for ModelQualityMonitor. probability_attribute (str or int): Index or JSONpath to locate probabilities. - Only used for ModelQualityMonitor, ModelBiasMonitor and ModelExplainabilityMonitor - ground_truth_attribute (str): Index or JSONpath to locate actual label(s). + Only used for ModelQualityMonitor. + ground_truth_attribute (str): Index to locate actual label(s). + Only used for ModelQualityMonitor. probability_threshold_attribute (float): threshold to convert probabilities to binaries - Only used for ModelQualityMonitor, ModelBiasMonitor and ModelExplainabilityMonitor + Only used for ModelQualityMonitor. Returns: dict: Dictionary of environment keys and values. @@ -2600,10 +2601,13 @@ def suggest_baseline( problem_type (str): The type of problem of this model quality monitoring. Valid values are "Regression", "BinaryClassification", "MulticlassClassification". inference_attribute (str): Index or JSONpath to locate predicted label(s). + Only used for ModelQualityMonitor. probability_attribute (str or int): Index or JSONpath to locate probabilities. - ground_truth_attribute (str): Index or JSONpath to locate actual label(s). + Only used for ModelQualityMonitor. + ground_truth_attribute (str): Index to locate actual label(s). + Only used for ModelQualityMonitor. probability_threshold_attribute (float): threshold to convert probabilities to binaries - Only used for ModelQualityMonitor, ModelBiasMonitor and ModelExplainabilityMonitor + Only used for ModelQualityMonitor. post_analytics_processor_script (str): The path to the record post-analytics processor script. This can be a local path or an S3 uri. 
output_s3_uri (str): Desired S3 destination of the constraint_violations diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index db6ce2badd..af52da6288 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -23,6 +23,7 @@ import logging from textwrap import dedent from typing import Dict, List, Optional, Union +from copy import copy import attr @@ -32,7 +33,12 @@ from sagemaker.job import _Job from sagemaker.local import LocalSession from sagemaker.network import NetworkConfig -from sagemaker.utils import base_name_from_image, get_config_value, name_from_base +from sagemaker.utils import ( + base_name_from_image, + get_config_value, + name_from_base, + check_and_get_run_experiment_config, +) from sagemaker.session import Session from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.functions import Join @@ -202,6 +208,7 @@ def run( outputs=outputs, ) + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_job = ProcessingJob.start_new( processor=self, inputs=normalized_inputs, @@ -604,6 +611,7 @@ def run( kms_key=kms_key, ) + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_job = ProcessingJob.start_new( processor=self, inputs=normalized_inputs, @@ -1587,13 +1595,13 @@ def run( # type: ignore[override] framework script to run. Path (absolute or relative) to the local Python source file which should be executed as the entry point to training. When `code` is an S3 URI, ignore `source_dir`, - `dependencies, and `git_config`. If ``source_dir`` is specified, + `dependencies`, and `git_config`. If ``source_dir`` is specified, then ``code`` must point to a file located at the root of ``source_dir``. source_dir (str): Path (absolute, relative or an S3 URI) to a directory with any other processing source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when processing on Amazon SageMaker (default: None). + point to a file named `sourcedir.tar.gz`. The structure within this directory + is preserved when processing on Amazon SageMaker (default: None). dependencies (list[str]): A list of paths to directories (absolute or relative) with any additional libraries that will be exported to the container (default: []). The library folders will be @@ -1730,20 +1738,17 @@ def _pack_and_upload_code( "sagemaker_session unspecified when creating your Processor to have one set up " "automatically." ) + if "/sourcedir.tar.gz" in estimator.uploaded_code.s3_prefix: + # Upload the bootstrapping code as s3://.../jobname/source/runproc.sh. + entrypoint_s3_uri = estimator.uploaded_code.s3_prefix.replace( + "sourcedir.tar.gz", + "runproc.sh", + ) + else: + raise RuntimeError("S3 source_dir file must be named `sourcedir.tar.gz`.") - # Upload the bootstrapping code as s3://.../jobname/source/runproc.sh.
- entrypoint_s3_uri = estimator.uploaded_code.s3_prefix.replace( - "sourcedir.tar.gz", - "runproc.sh", - ) script = estimator.uploaded_code.script_name - s3_runproc_sh = S3Uploader.upload_string_as_file_body( - self._generate_framework_script(script), - desired_s3_uri=entrypoint_s3_uri, - kms_key=kms_key, - sagemaker_session=self.sagemaker_session, - ) - logger.info("runproc.sh uploaded to %s", s3_runproc_sh) + s3_runproc_sh = self._create_and_upload_runproc(script, kms_key, entrypoint_s3_uri) return s3_runproc_sh, inputs, job_name @@ -1827,14 +1832,17 @@ def _patch_inputs_with_payload(self, inputs, s3_payload) -> List[ProcessingInput # a7399455f5386d83ddc5cb15c0db00c04bd518ec/src/sagemaker/processing.py#L425-L426 if inputs is None: inputs = [] - inputs.append( + + # make a shallow copy of user inputs + patched_inputs = copy(inputs) + patched_inputs.append( ProcessingInput( input_name="code", source=s3_payload, destination="/opt/ml/processing/input/code/", ) ) - return inputs + return patched_inputs def _set_entrypoint(self, command, user_script_name): """Framework processor override for setting processing job entrypoint. @@ -1850,3 +1858,42 @@ def _set_entrypoint(self, command, user_script_name): ) ) self.entrypoint = self.framework_entrypoint_command + [user_script_location] + + def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): + """Create the runproc shell script and upload it to S3. + + If leveraging a pipeline session with optimized S3 artifact paths, + the runproc.sh file is hashed and uploaded to a separate S3 location. + + Args: + user_script (str): Relative path to ``code`` in the source bundle + - e.g. 'process.py'. + kms_key (str): The KMS key used for encryption. + entrypoint_s3_uri (str): The S3 upload path for the runproc script. + """
+ from sagemaker.workflow.utilities import _pipeline_config, hash_object + + if _pipeline_config and _pipeline_config.pipeline_name: + runproc_file_str = self._generate_framework_script(user_script) + runproc_file_hash = hash_object(runproc_file_str) + s3_uri = ( + f"s3://{self.sagemaker_session.default_bucket()}/{_pipeline_config.pipeline_name}/" + f"code/{runproc_file_hash}/runproc.sh" + ) + s3_runproc_sh = S3Uploader.upload_string_as_file_body( + runproc_file_str, + desired_s3_uri=s3_uri, + kms_key=kms_key, + sagemaker_session=self.sagemaker_session, + ) + else: + s3_runproc_sh = S3Uploader.upload_string_as_file_body( + self._generate_framework_script(user_script), + desired_s3_uri=entrypoint_s3_uri, + kms_key=kms_key, + sagemaker_session=self.sagemaker_session, + ) + logger.info("runproc.sh uploaded to %s", s3_runproc_sh) + + return s3_runproc_sh diff --git a/src/sagemaker/pytorch/__init__.py b/src/sagemaker/pytorch/__init__.py index cac5f94b9a..e2d14f4163 100644 --- a/src/sagemaker/pytorch/__init__.py +++ b/src/sagemaker/pytorch/__init__.py @@ -16,3 +16,5 @@ from sagemaker.pytorch.estimator import PyTorch  # noqa: F401 from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor  # noqa: F401 from sagemaker.pytorch.processing import PyTorchProcessor  # noqa: F401 + +from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig  # noqa: F401 diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 686de4a78c..29e254662f 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -28,6 +28,7 @@ ) from sagemaker.pytorch import defaults from sagemaker.pytorch.model import PyTorchModel +from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable @@ -51,7 +52,8 @@ def __init__( hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, distribution: Optional[Dict] = None, - **kwargs + compiler_config: Optional[TrainingCompilerConfig] = None, + **kwargs, ): """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment. @@ -208,6 +210,31 @@ def __init__( To learn more, see `Training with parameter servers `_. + **To enable distributed training with + `SageMaker Training Compiler `_ + for PyTorch:** + + .. code:: python + + { + "pytorchxla": { + "enabled": True + } + } + + To learn more, see `SageMaker Training Compiler + `_ + in the *Amazon SageMaker Developer Guide*. + + .. note:: + + When you use this PyTorch XLA option as the distributed training strategy, + you must add the ``compiler_config`` parameter and activate SageMaker + Training Compiler. + + compiler_config (:class:`~sagemaker.pytorch.TrainingCompilerConfig`): + Configures SageMaker Training Compiler to accelerate training. + **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor. @@ -250,6 +277,25 @@ def __init__( self.distribution = distribution or {} + if compiler_config is not None: + if not isinstance(compiler_config, TrainingCompilerConfig): + error_string = ( + f"Expected instance of type {TrainingCompilerConfig} " + f"for argument compiler_config. 
" + f"Instead got {type(compiler_config)}" + ) + raise ValueError(error_string) + if compiler_config: + compiler_config.validate(self) + elif distribution is not None and "pytorchxla" in distribution: + raise ValueError( + "Distributed training through PyTorch XLA is currently only supported " + "when SageMaker Training Compiler is enabled. To learn more, " + "see Enable SageMaker Training Compiler at " + "https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html." + ) + self.compiler_config = compiler_config + def _pytorch_distribution_configuration(self, distribution): """Returns a dict of distribution config for PyTorch training @@ -289,6 +335,12 @@ def hyperparameters(self): hyperparameters.update( EstimatorBase._json_encode_hyperparameters(additional_hyperparameters) ) + if self.compiler_config: + training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict() + hyperparameters.update( + EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters) + ) + return hyperparameters def create_model( @@ -299,7 +351,7 @@ def create_model( entry_point=None, source_dir=None, dependencies=None, - **kwargs + **kwargs, ): """Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``. @@ -350,7 +402,7 @@ def create_model( sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), dependencies=(dependencies or self.dependencies), - **kwargs + **kwargs, ) @classmethod @@ -371,6 +423,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na ) image_uri = init_params.pop("image_uri") framework, py_version, tag, _ = framework_name_from_image(image_uri) + if framework: + framework = framework.split("-")[0] if tag is None: framework_version = None diff --git a/src/sagemaker/pytorch/training_compiler/__init__.py b/src/sagemaker/pytorch/training_compiler/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/pytorch/training_compiler/config.py b/src/sagemaker/pytorch/training_compiler/config.py new file mode 100644 index 0000000000..7faf8acbbd --- /dev/null +++ b/src/sagemaker/pytorch/training_compiler/config.py @@ -0,0 +1,151 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Configuration for the SageMaker Training Compiler.""" +from __future__ import absolute_import +import logging +from typing import Union +from packaging.specifiers import SpecifierSet +from packaging.version import Version + +from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig +from sagemaker.workflow.entities import PipelineVariable + +logger = logging.getLogger(__name__) + + +class TrainingCompilerConfig(BaseConfig): + """The SageMaker Training Compiler configuration class.""" + + SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"] + SUPPORTED_INSTANCE_TYPES_WITH_EFA = [ + "ml.g4dn.8xlarge", + "ml.g4dn.12xlarge", + "ml.g5.48xlarge", + "ml.p3dn.24xlarge", + "ml.p4d.24xlarge", + ] + + def __init__( + self, + enabled: Union[bool, PipelineVariable] = True, + debug: Union[bool, PipelineVariable] = False, + ): + """This class initializes a ``TrainingCompilerConfig`` instance. + + `Amazon SageMaker Training Compiler + `_ + is a feature of SageMaker Training + and speeds up training jobs by optimizing model execution graphs. + + You can compile PyTorch models + by passing the object of this configuration class to the ``compiler_config`` + parameter of the :class:`~sagemaker.pytorch.PyTorch` + estimator. + + Args: + enabled (bool or PipelineVariable): Optional. Switch to enable SageMaker + Training Compiler. The default is ``True``. + debug (bool or PipelineVariable): Optional. Whether to dump detailed logs + for debugging. This comes with a potential performance slowdown. + The default is ``False``. + + **Example**: The following code shows the basic usage of the + :class:`sagemaker.pytorch.TrainingCompilerConfig()` class + to run a PyTorch training job with the compiler. + + .. code-block:: python + + from sagemaker.pytorch import PyTorch, TrainingCompilerConfig + + pytorch_estimator=PyTorch( + ... + compiler_config=TrainingCompilerConfig() + ) + + .. seealso:: + + For more information about how to enable SageMaker Training Compiler + for various training settings such as distributed training, + see `Enable SageMaker Training Compiler + `_ + in the `Amazon SageMaker Training Compiler developer guide + `_. + + """ + + super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug) + + @classmethod + def validate( + cls, + estimator, + ): + """Checks if SageMaker Training Compiler is configured correctly. + + Args: + estimator (:class:`sagemaker.pytorch.PyTorch`): An estimator object. + If SageMaker Training Compiler is enabled, it will validate whether + the estimator is configured to be compatible with Training Compiler. + + Raises: + ValueError: Raised if the requested configuration is not compatible + with SageMaker Training Compiler. + """ + + super(TrainingCompilerConfig, cls).validate(estimator) + + if estimator.image_uri: + error_helper_string = ( + "Overriding the image URI is currently not supported " + "for SageMaker Training Compiler." + "Specify the following parameters to run the PyTorch training job " + "with SageMaker Training Compiler enabled: " + "framework_version, and compiler_config." 
+ ) + raise ValueError(error_helper_string) + + if estimator.distribution: + pt_xla_present = "pytorchxla" in estimator.distribution + pt_xla_enabled = estimator.distribution.get("pytorchxla", {}).get("enabled", False) + if pt_xla_enabled: + if estimator.framework_version: + if Version(estimator.framework_version) in SpecifierSet("< 1.12"): + error_helper_string = ( + "Distribution mechanism 'pytorchxla' is currently only supported for " + "PyTorch >= 1.12 when SageMaker Training Compiler is enabled." + " Received framework_version={} which is unsupported." + ) + raise ValueError(error_helper_string.format(estimator.framework_version)) + if estimator.instance_type not in cls.SUPPORTED_INSTANCE_TYPES_WITH_EFA: + logger.warning( + "Consider using instances with EFA support when " + "training with PyTorch >= 1.12 and SageMaker Training Compiler " + "enabled. SageMaker Training Compiler leverages EFA to provide better " + "performance for distributed training." + ) + if not pt_xla_present: + if estimator.framework_version: + if Version(estimator.framework_version) in SpecifierSet(">= 1.12"): + error_helper_string = ( + "'pytorchxla' is the only distribution mechanism currently supported " + "for PyTorch >= 1.12 when SageMaker Training Compiler is enabled." + " Received distribution={} which is unsupported." + ) + raise ValueError(error_helper_string.format(estimator.distribution)) + elif estimator.instance_count and estimator.instance_count > 1: + if estimator.framework_version: + if Version(estimator.framework_version) in SpecifierSet(">= 1.12"): + logger.warning( + "Consider setting 'distribution' to 'pytorchxla' for distributed " + "training with PyTorch >= 1.12 and SageMaker Training Compiler enabled." + ) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 00797c9ea0..5404978200 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -89,6 +89,7 @@ def __init__( sagemaker_featurestore_runtime_client=None, default_bucket=None, settings=SessionSettings(), + sagemaker_metrics_client=None, ): """Initialize a SageMaker ``Session``. @@ -116,6 +117,10 @@ def __init__( Example: "sagemaker-my-custom-bucket". settings (sagemaker.session_settings.SessionSettings): Optional. Set of optional parameters to apply to the session. + sagemaker_metrics_client (boto3.SageMakerMetrics.Client): + Client which makes SageMaker Metrics related calls to Amazon SageMaker + (default: None). If not provided, one will be created using + this instance's ``boto_session``. """ self._default_bucket = None self._default_bucket_name_override = default_bucket @@ -130,6 +135,7 @@ def __init__( sagemaker_client=sagemaker_client, sagemaker_runtime_client=sagemaker_runtime_client, sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client, + sagemaker_metrics_client=sagemaker_metrics_client, ) def _initialize( @@ -138,6 +144,7 @@ def _initialize( sagemaker_client, sagemaker_runtime_client, sagemaker_featurestore_runtime_client, + sagemaker_metrics_client, ): """Initialize this SageMaker Session. 
@@ -172,6 +179,12 @@ def _initialize( "sagemaker-featurestore-runtime" ) + if sagemaker_metrics_client: + self.sagemaker_metrics_client = sagemaker_metrics_client + else: + self.sagemaker_metrics_client = self.boto_session.client("sagemaker-metrics") + prepend_user_agent(self.sagemaker_metrics_client) + self.local_mode = False @property @@ -312,7 +325,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None): # For each object key, create the directory on the local machine if needed, and then # download the file. for key in keys: - tail_s3_uri_path = os.path.basename(key_prefix) + tail_s3_uri_path = os.path.basename(key) if not os.path.splitext(key_prefix)[1]: tail_s3_uri_path = os.path.relpath(key, key_prefix) destination_path = os.path.join(path, tail_s3_uri_path) @@ -548,8 +561,8 @@ def train( # noqa: C901 checkpoints will be provided under `/opt/ml/checkpoints/`. (default: ``None``). experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. @@ -558,6 +571,7 @@ def train( # noqa: C901 * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. enable_sagemaker_metrics (bool): enable SageMaker Metrics Time Series. For more information see: https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries @@ -703,8 +717,8 @@ def _get_train_request( # noqa: C901 checkpoints will be provided under `/opt/ml/checkpoints/`. (default: ``None``). experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. @@ -713,6 +727,7 @@ def _get_train_request( # noqa: C901 * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. enable_sagemaker_metrics (bool): enable SageMaker Metrics Time Series. For more information see: https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries @@ -2121,6 +2136,7 @@ def tune( # noqa: C901 stop_condition, tags, warm_start_config, + strategy_config=None, enable_network_isolation=False, image_uri=None, algorithm_arn=None, @@ -2130,12 +2146,15 @@ def tune( # noqa: C901 use_spot_instances=False, checkpoint_s3_uri=None, checkpoint_local_path=None, + random_seed=None, ): """Create an Amazon SageMaker hyperparameter tuning job. Args: job_name (str): Name of the tuning job being created. 
strategy (str): Strategy to be used for hyperparameter estimations. + strategy_config (dict): A configuration for the hyperparameter tuning + job optimisation strategy. objective_type (str): The type of the objective metric for evaluating training jobs. This value can be either 'Minimize' or 'Maximize'. objective_metric_name (str): Name of the metric for evaluating training jobs. @@ -2208,6 +2227,9 @@ started. If the path is unset then SageMaker assumes the checkpoints will be provided under `/opt/ml/checkpoints/`. (default: ``None``). + random_seed (int): An initial value used to initialize a pseudo-random number generator. + Setting a random seed will make the hyperparameter tuning search strategies + produce more consistent configurations for the same tuning job. (default: ``None``). """ tune_request = { @@ -2220,6 +2242,8 @@ objective_metric_name=objective_metric_name, parameter_ranges=parameter_ranges, early_stopping_type=early_stopping_type, + random_seed=random_seed, + strategy_config=strategy_config, ), "TrainingJobDefinition": self._map_training_config( static_hyperparameters=static_hyperparameters, @@ -2375,6 +2399,8 @@ def _map_tuning_config( objective_type=None, objective_metric_name=None, parameter_ranges=None, + random_seed=None, + strategy_config=None, ): """Construct tuning job configuration dictionary. @@ -2392,6 +2418,11 @@ objective_metric_name (str): Name of the metric for evaluating training jobs. parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can be one of three types: Continuous, Integer, or Categorical. + random_seed (int): An initial value used to initialize a pseudo-random number generator. + Setting a random seed will make the hyperparameter tuning search strategies + produce more consistent configurations for the same tuning job. + strategy_config (dict): A configuration for the hyperparameter tuning job optimisation + strategy. Returns: A dictionary of tuning job configuration. 
For format details, please refer to @@ -2408,6 +2439,9 @@ def _map_tuning_config( "TrainingJobEarlyStoppingType": early_stopping_type, } + if random_seed is not None: + tuning_config["RandomSeed"] = random_seed + tuning_objective = cls._map_tuning_objective(objective_type, objective_metric_name) if tuning_objective is not None: tuning_config["HyperParameterTuningJobObjective"] = tuning_objective @@ -2415,6 +2449,8 @@ def _map_tuning_config( if parameter_ranges is not None: tuning_config["ParameterRanges"] = parameter_ranges + if strategy_config is not None: + tuning_config["StrategyConfig"] = strategy_config return tuning_config @classmethod @@ -3300,6 +3336,11 @@ def create_endpoint_config_from_existing( if request_data_capture_config_dict is not None: request["DataCaptureConfig"] = request_data_capture_config_dict + if existing_endpoint_config_desc.get("AsyncInferenceConfig") is not None: + request["AsyncInferenceConfig"] = existing_endpoint_config_desc.get( + "AsyncInferenceConfig", None + ) + self.sagemaker_client.create_endpoint_config(**request) def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True): @@ -4332,6 +4373,56 @@ def update_feature_group( FeatureGroupName=feature_group_name, FeatureAdditions=feature_additions ) + def list_feature_groups( + self, + name_contains, + feature_group_status_equals, + offline_store_status_equals, + creation_time_after, + creation_time_before, + sort_order, + sort_by, + max_results, + next_token, + ) -> Dict[str, Any]: + """List all FeatureGroups satisfying given filters. + + Args: + name_contains (str): A string that partially matches one or more FeatureGroups' names. + Filters FeatureGroups by name. + feature_group_status_equals (str): A FeatureGroup status. + Filters FeatureGroups by FeatureGroup status. + offline_store_status_equals (str): An OfflineStore status. + Filters FeatureGroups by OfflineStore status. + creation_time_after (datetime.datetime): Use this parameter to search for FeatureGroups + created after a specific date and time. + creation_time_before (datetime.datetime): Use this parameter to search for FeatureGroups + created before a specific date and time. + sort_order (str): The order in which FeatureGroups are listed. + sort_by (str): The value on which the FeatureGroup list is sorted. + max_results (int): The maximum number of results returned by ListFeatureGroups. + next_token (str): A token to resume pagination of ListFeatureGroups results. + Returns: + Response dict from service. + """ + list_feature_groups_args = {} + + def check_object(key, value): + if value is not None: + list_feature_groups_args[key] = value + + check_object("NameContains", name_contains) + check_object("FeatureGroupStatusEquals", feature_group_status_equals) + check_object("OfflineStoreStatusEquals", offline_store_status_equals) + check_object("CreationTimeAfter", creation_time_after) + check_object("CreationTimeBefore", creation_time_before) + check_object("SortOrder", sort_order) + check_object("SortBy", sort_by) + check_object("MaxResults", max_results) + check_object("NextToken", next_token) + + return self.sagemaker_client.list_feature_groups(**list_feature_groups_args) + def update_feature_metadata( self, feature_group_name: str, @@ -4399,6 +4490,48 @@ def put_record( Record=record, ) + def delete_record( + self, + feature_group_name: str, + record_identifier_value_as_string: str, + event_time: str, + ): + """Deletes a single record from the FeatureGroup. + + Args: + feature_group_name (str): name of the FeatureGroup. 
+ record_identifier_value_as_string (str): value of the record identifier. + event_time (str): a timestamp indicating when the deletion event occurred. + """ + return self.sagemaker_featurestore_runtime_client.delete_record( + FeatureGroupName=feature_group_name, + RecordIdentifierValueAsString=record_identifier_value_as_string, + EventTime=event_time, + ) + + def get_record( + self, + record_identifier_value_as_string: str, + feature_group_name: str, + feature_names: Sequence[str], + ) -> Dict[str, Sequence[Dict[str, str]]]: + """Gets a single record in the FeatureGroup. + + Args: + record_identifier_value_as_string (str): value of the record identifier. + feature_group_name (str): name of the FeatureGroup. + feature_names (Sequence[str]): list of feature names. + """ + get_record_args = { + "FeatureGroupName": feature_group_name, + "RecordIdentifierValueAsString": record_identifier_value_as_string, + } + + if feature_names: + get_record_args["FeatureNames"] = feature_names + + return self.sagemaker_featurestore_runtime_client.get_record(**get_record_args) + def start_query_execution( self, catalog: str, diff --git a/src/sagemaker/spark/processing.py b/src/sagemaker/spark/processing.py index dc3d26a355..912bc90d80 100644 --- a/src/sagemaker/spark/processing.py +++ b/src/sagemaker/spark/processing.py @@ -30,6 +30,7 @@ from enum import Enum from io import BytesIO from urllib.parse import urlparse +from copy import copy from typing import Union, List, Dict, Optional @@ -279,6 +280,10 @@ def run( def _extend_processing_args(self, inputs, outputs, **kwargs): """Extends processing job args such as inputs.""" + # make a shallow copy of user outputs + outputs = outputs or [] + extended_outputs = copy(outputs) + if kwargs.get("spark_event_logs_s3_uri"): spark_event_logs_s3_uri = kwargs.get("spark_event_logs_s3_uri") self._validate_s3_uri(spark_event_logs_s3_uri) @@ -297,16 +302,21 @@ s3_upload_mode="Continuous", ) - outputs = outputs or [] - outputs.append(output) + extended_outputs.append(output) + + # make a shallow copy of user inputs + inputs = inputs or [] + extended_inputs = copy(inputs) if kwargs.get("configuration"): configuration = kwargs.get("configuration") self._validate_configuration(configuration) - inputs = inputs or [] - inputs.append(self._stage_configuration(configuration)) + extended_inputs.append(self._stage_configuration(configuration)) - return inputs, outputs + return ( + extended_inputs if extended_inputs else None, + extended_outputs if extended_outputs else None, + ) def start_history_server(self, spark_event_logs_s3_uri=None): """Starts a Spark history server. @@ -940,9 +950,16 @@ def _extend_processing_args(self, inputs, outputs, **kwargs): outputs: Processing outputs. kwargs: Additional keyword arguments passed to `super()`. 
""" + + if inputs is None: + inputs = [] + + # make a shallow copy of user inputs + extended_inputs = copy(inputs) + self.command = [_SparkProcessorBase._default_command] extended_inputs = self._handle_script_dependencies( - inputs, kwargs.get("submit_py_files"), FileType.PYTHON + extended_inputs, kwargs.get("submit_py_files"), FileType.PYTHON ) extended_inputs = self._handle_script_dependencies( extended_inputs, kwargs.get("submit_jars"), FileType.JAR @@ -1199,8 +1216,14 @@ def _extend_processing_args(self, inputs, outputs, **kwargs): else: raise ValueError("submit_class is required") + if inputs is None: + inputs = [] + + # make a shallow copy of user inputs + extended_inputs = copy(inputs) + extended_inputs = self._handle_script_dependencies( - inputs, kwargs.get("submit_jars"), FileType.JAR + extended_inputs, kwargs.get("submit_jars"), FileType.JAR ) extended_inputs = self._handle_script_dependencies( extended_inputs, kwargs.get("submit_files"), FileType.FILE diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index cfcc637b99..40ed143ebc 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -14,17 +14,24 @@ from __future__ import absolute_import from typing import Union, Optional, List, Dict -from botocore import exceptions +import logging +import copy +import time +from botocore import exceptions from sagemaker.job import _Job -from sagemaker.session import Session +from sagemaker.session import Session, get_execution_role from sagemaker.inputs import BatchDataCaptureConfig from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.functions import Join -from sagemaker.workflow.pipeline_context import runnable_by_pipeline +from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.execution_variables import ExecutionVariables -from sagemaker.utils import base_name_from_image, name_from_base +from sagemaker.utils import ( + base_name_from_image, + name_from_base, + check_and_get_run_experiment_config, +) class Transformer(object): @@ -248,6 +255,7 @@ def transform( ) self._reset_output_path = True + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_transform_job = _TransformJob.start_new( self, data, @@ -266,6 +274,155 @@ def transform( if wait: self.latest_transform_job.wait(logs=logs) + def transform_with_monitoring( + self, + monitoring_config, + monitoring_resource_config, + data: str, + data_type: str = "S3Prefix", + content_type: str = None, + compression_type: str = None, + split_type: str = None, + input_filter: str = None, + output_filter: str = None, + join_source: str = None, + model_client_config: Dict[str, str] = None, + batch_data_capture_config: BatchDataCaptureConfig = None, + monitor_before_transform: bool = False, + supplied_baseline_statistics: str = None, + supplied_baseline_constraints: str = None, + wait: bool = True, + pipeline_name: str = None, + role: str = None, + ): + """Runs a transform job with monitoring job. + + Note that this function will not start a transform job immediately, + instead, it will create a SageMaker Pipeline and execute it. + If you provide an existing pipeline_name, no new pipeline will be created, otherwise, + each transform_with_monitoring call will create a new pipeline and execute. 
+ + Args: + monitoring_config (Union[ + `sagemaker.workflow.quality_check_step.QualityCheckConfig`, + `sagemaker.workflow.quality_check_step.ClarifyCheckConfig` + ]): the monitoring configuration used to run model monitoring. + monitoring_resource_config (`sagemaker.workflow.check_job_config.CheckJobConfig`): + the check job (processing job) cluster resource configuration. + data (str): Input data location in S3 for the transform job + data_type (str): What the S3 location defines (default: 'S3Prefix'). + Valid values: + * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix + will be used as inputs for the transform job. + * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3 + object to use as an input for the transform job. + content_type (str): MIME type of the input data (default: None). + compression_type (str): Compression type of the input data, if + compressed (default: None). Valid values: 'Gzip', None. + split_type (str): The record delimiter for the input object + (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and + 'TFRecord'. + input_filter (str): A JSONPath to select a portion of the input to + pass to the algorithm container for inference. If you omit the + field, it gets the value '$', representing the entire input. + For CSV data, each row is taken as a JSON array, + so only index-based JSONPaths can be applied, e.g. $[0], $[1:]. + CSV data should follow the `RFC format `_. + See `Supported JSONPath Operators + `_ + for a table of supported JSONPath operators. + For more information, see the SageMaker API documentation for + `CreateTransformJob + `_. + Some examples: "$[1:]", "$.features" (default: None). + output_filter (str): A JSONPath to select a portion of the + joined/original output to return as the output. + For more information, see the SageMaker API documentation for + `CreateTransformJob + `_. + Some examples: "$[1:]", "$.prediction" (default: None). + join_source (str): The source of data to be joined to the transform + output. It can be set to 'Input' meaning the entire input record + will be joined to the inference result. You can use OutputFilter + to select the useful portion before uploading to S3. (default: + None). Valid values: Input, None. + model_client_config (dict[str, str]): Model configuration. + Dictionary contains two optional keys, + 'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'. + (default: ``None``). + batch_data_capture_config (BatchDataCaptureConfig): Configuration object which + specifies the configurations related to the batch data capture for the transform job + (default: ``None``). + monitor_before_transform (bool): For the data quality + or model explainability monitoring types, + setting this flag to True runs the check step before the transform job. + fail_on_violation (Union[bool, PipelineVariable]): An opt-out flag; set it to False + to avoid failing the check step when a violation is detected. + supplied_baseline_statistics (Union[str, PipelineVariable]): The S3 path + to the supplied statistics object representing the statistics JSON file + which will be used for the drift check (default: None). + supplied_baseline_constraints (Union[str, PipelineVariable]): The S3 path + to the supplied constraints object representing the constraints JSON file + which will be used for the drift check (default: None). 
+ wait (bool): Whether to wait for the pipeline execution to complete + pipeline_name (str): The name of the Pipeline for the monitoring and transform steps + role (str): The execution role (default: the role returned by ``get_execution_role()``) + """ + + transformer = self + if not isinstance(self.sagemaker_session, PipelineSession): + sagemaker_session = self.sagemaker_session + self.sagemaker_session = None + transformer = copy.deepcopy(self) + transformer.sagemaker_session = PipelineSession() + self.sagemaker_session = sagemaker_session + + transform_step_args = transformer.transform( + data=data, + data_type=data_type, + content_type=content_type, + compression_type=compression_type, + split_type=split_type, + input_filter=input_filter, + output_filter=output_filter, + batch_data_capture_config=batch_data_capture_config, + join_source=join_source, + model_client_config=model_client_config, + ) + + from sagemaker.workflow.monitor_batch_transform_step import MonitorBatchTransformStep + + monitoring_batch_step = MonitorBatchTransformStep( + name="MonitorBatchTransformStep", + display_name="MonitorBatchTransformStep", + description="", + transform_step_args=transform_step_args, + monitor_configuration=monitoring_config, + check_job_configuration=monitoring_resource_config, + monitor_before_transform=monitor_before_transform, + supplied_baseline_constraints=supplied_baseline_constraints, + supplied_baseline_statistics=supplied_baseline_statistics, + ) + + pipeline_name = ( + pipeline_name if pipeline_name else f"TransformWithMonitoring{int(time.time())}" + ) + # if pipeline exists, just start the execution + from sagemaker.workflow.pipeline import Pipeline + + pipeline = Pipeline( + name=pipeline_name, + steps=[monitoring_batch_step], + sagemaker_session=transformer.sagemaker_session, + ) + pipeline.upsert(role_arn=role if role else get_execution_role()) + execution = pipeline.start() + if wait: + logging.info("Waiting for transform with monitoring to execute ...") + execution.wait() + return execution + def delete_model(self): """Delete the corresponding SageMaker model for this Transformer.""" self.sagemaker_session.delete_model(self.model_name) diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 52b9d81d0d..45a6467c1f 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -282,8 +282,8 @@ def from_job_desc(cls, hyperband_strategy_config): Returns: sagemaker.tuner.HyperbandStrategyConfig: De-serialized instance of - HyperbandStrategyConfig containing the max_resource and min_resource provided as part of - ``hyperband_strategy_config``. + ``HyperbandStrategyConfig`` containing the max_resource + and min_resource provided as part of ``hyperband_strategy_config``. """ return cls( min_resource=hyperband_strategy_config[HYPERBAND_MIN_RESOURCE], @@ -306,7 +306,7 @@ def to_input_req(self): Returns: dict: Containing the "MaxResource" and - "MinResource" as the first class fields. + "MinResource" as the first class fields. """ return { HYPERBAND_MIN_RESOURCE: self.min_resource, @@ -330,7 +330,7 @@ def __init__( Args: hyperband_strategy_config (sagemaker.tuner.HyperbandStrategyConfig): The configuration - for the object that specifies the Hyperband strategy. + for the object that specifies the Hyperband strategy. This parameter is only supported for the Hyperband selection for Strategy within the HyperParameterTuningJobConfig. 
""" @@ -413,6 +413,7 @@ def __init__( strategy_config: Optional[StrategyConfig] = None, early_stopping_type: Union[str, PipelineVariable] = "Off", estimator_name: Optional[str] = None, + random_seed: Optional[int] = None, ): """Creates a ``HyperparameterTuner`` instance. @@ -461,7 +462,7 @@ def __init__( ``WarmStartConfig`` object that has been initialized with the configuration defining the nature of warm start tuning job. strategy_config (sagemaker.tuner.StrategyConfig): A configuration for "Hyperparameter" - tuning job optimisation strategy. + tuning job optimisation strategy. early_stopping_type (str or PipelineVariable): Specifies whether early stopping is enabled for the job. Can be either 'Auto' or 'Off' (default: 'Off'). If set to 'Off', early stopping will not be attempted. @@ -470,6 +471,9 @@ def __init__( estimator_name (str): A unique name to identify an estimator within the hyperparameter tuning job, when more than one estimator is used with the same tuning job (default: None). + random_seed (int): An initial value used to initialize a pseudo-random number generator. + Setting a random seed will make the hyperparameter tuning search strategies to + produce more consistent configurations for the same tuning job. """ if hyperparameter_ranges is None or len(hyperparameter_ranges) == 0: raise ValueError("Need to specify hyperparameter ranges") @@ -516,6 +520,7 @@ def __init__( self.latest_tuning_job = None self.warm_start_config = warm_start_config self.early_stopping_type = early_stopping_type + self.random_seed = random_seed def _prepare_for_tuning(self, job_name=None, include_cls_metadata=False): """Prepare the tuner instance for tuning (fit).""" @@ -1222,6 +1227,9 @@ def _prepare_init_params_from_job_description(cls, job_details): "base_tuning_job_name": base_from_name(job_details["HyperParameterTuningJobName"]), } + if "RandomSeed" in tuning_config: + params["random_seed"] = tuning_config["RandomSeed"] + if "HyperParameterTuningJobObjective" in tuning_config: params["objective_metric_name"] = tuning_config["HyperParameterTuningJobObjective"][ "MetricName" @@ -1483,6 +1491,7 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato warm_start_type=warm_start_type, parents=all_parents ), early_stopping_type=self.early_stopping_type, + random_seed=self.random_seed, ) if len(self.estimator_dict) > 1: @@ -1508,6 +1517,7 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato max_parallel_jobs=self.max_parallel_jobs, warm_start_config=WarmStartConfig(warm_start_type=warm_start_type, parents=all_parents), early_stopping_type=self.early_stopping_type, + random_seed=self.random_seed, ) @classmethod @@ -1526,6 +1536,7 @@ def create( tags=None, warm_start_config=None, early_stopping_type="Off", + random_seed=None, ): """Factory method to create a ``HyperparameterTuner`` instance. @@ -1569,7 +1580,7 @@ def create( strategy (str): Strategy to be used for hyperparameter estimations (default: 'Bayesian'). strategy_config (dict): The configuration for a training job launched by a - hyperparameter tuning job. + hyperparameter tuning job. objective_type (str): The type of the objective metric for evaluating training jobs. This value can be either 'Minimize' or 'Maximize' (default: 'Maximize'). max_jobs (int): Maximum total number of training jobs to start for the hyperparameter @@ -1586,6 +1597,9 @@ def create( Can be either 'Auto' or 'Off' (default: 'Off'). If set to 'Off', early stopping will not be attempted. 
            If set to 'Auto', early stopping of some training jobs may happen,
            but is not guaranteed to.
+            random_seed (int): An initial value used to initialize a pseudo-random number
+                generator. Setting a random seed will make the hyperparameter tuning search
+                strategies produce more consistent configurations for the same tuning job.
 
         Returns:
             sagemaker.tuner.HyperparameterTuner: a new ``HyperparameterTuner`` object that can
@@ -1624,6 +1638,7 @@ def create(
             tags=tags,
             warm_start_config=warm_start_config,
             early_stopping_type=early_stopping_type,
+            random_seed=random_seed,
         )
 
         for estimator_name in estimator_names[1:]:
@@ -1775,8 +1790,11 @@ def _get_tuner_args(cls, tuner, inputs):
             "early_stopping_type": tuner.early_stopping_type,
         }
 
+        if tuner.random_seed is not None:
+            tuning_config["random_seed"] = tuner.random_seed
+
         if tuner.strategy_config is not None:
-            tuning_config["strategy_config"] = tuner.strategy_config
+            tuning_config["strategy_config"] = tuner.strategy_config.to_input_req()
 
         if tuner.objective_metric_name is not None:
             tuning_config["objective_type"] = tuner.objective_type
diff --git a/src/sagemaker/utilities/search_expression.py b/src/sagemaker/utilities/search_expression.py
new file mode 100644
index 0000000000..5b2aaf3226
--- /dev/null
+++ b/src/sagemaker/utilities/search_expression.py
@@ -0,0 +1,133 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
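To make the new ``random_seed`` feature concrete, a minimal usage sketch follows. It is not part of the patch: the image URI, role ARN, metric name, and hyperparameter range are hypothetical placeholders.

.. code:: python

    # Sketch only: estimator settings below are hypothetical examples.
    from sagemaker.estimator import Estimator
    from sagemaker.tuner import ContinuousParameter, HyperparameterTuner

    estimator = Estimator(
        image_uri="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",  # hypothetical
        role="arn:aws:iam::123456789012:role/SageMakerRole",  # hypothetical
        instance_count=1,
        instance_type="ml.m5.xlarge",
    )

    tuner = HyperparameterTuner(
        estimator=estimator,
        objective_metric_name="validation:accuracy",  # hypothetical metric
        hyperparameter_ranges={"learning_rate": ContinuousParameter(0.01, 0.2)},
        max_jobs=10,
        max_parallel_jobs=2,
        random_seed=42,  # new parameter: reproducible search configurations
    )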
+"""Simplify Search Expression by provide a simplified DSL""" +from __future__ import absolute_import + +from enum import Enum, unique + +from sagemaker.apiutils._base_types import ApiObject + + +# TODO: we should update the lineage to use search expressions +# defined here in a separate change +@unique +class Operator(Enum): + """Search operators""" + + EQUALS = "Equals" + NOT_EQUALS = "NotEquals" + GREATER_THAN = "GreaterThan" + GREATER_THAN_OR_EQUAL = "GreaterThanOrEqualTo" + LESS_THAN = "LessThan" + LESS_THAN_OR_EQUAL = "LessThanOrEqualTo" + CONTAINS = "Contains" + EXISTS = "Exists" + NOT_EXISTS = "NotExists" + + +@unique +class BooleanOperator(Enum): + """Boolean search operation enum""" + + AND = "And" + OR = "Or" + + +class SearchObject(ApiObject): + """Search Object""" + + def to_boto(self): + """Convert a search object to boto""" + return ApiObject.to_boto(self) + + +class Filter(SearchObject): + """A Python class represent a Search Filter object.""" + + name = None + operator = None + value = None + + def __init__(self, name, operator=None, value=None, **kwargs): + """Construct a Filter object + + Args: + name (str): filter field name + operator (Operator): one of Operator enum + value (str): value of the field + """ + super().__init__(**kwargs) + self.name = name + self.operator = None if operator is None else operator.value + self.value = value + + +class NestedFilter(SearchObject): + """A Python class represent a Nested Filter object.""" + + nested_property_name = None + filters = None + + def __init__(self, property_name, filters, **kwargs): + """Construct a Nested Filter object + + Args: + property_name (str): nested property name + filters (List[Filter]): list of Filter objects + """ + super().__init__(**kwargs) + self.nested_property_name = property_name + self.filters = list(map(lambda x: x.to_boto(), filters)) + + +class SearchExpression(SearchObject): + """A Python class representation of a Search Expression object. 
+
+    A sample search expression is defined here:
+    https://boto3.amazonaws.com/v1/documentation/api/1.12.8/reference/services/sagemaker.html#SageMaker.Client.search
+    """
+
+    filters = None
+    nested_filters = None
+    operator = None
+    sub_expressions = None
+
+    def __init__(
+        self,
+        filters=None,
+        nested_filters=None,
+        sub_expressions=None,
+        boolean_operator=BooleanOperator.AND,
+        **kwargs
+    ):
+        """Construct a Search Expression object.
+
+        Args:
+            filters (List[Filter]): list of Filter objects
+            nested_filters (List[NestedFilter]): list of Nested Filters objects
+            sub_expressions (List[SearchExpression]): list of Search Expression objects
+            boolean_operator (BooleanOperator): one of the boolean operator enums
+        """
+        super().__init__(**kwargs)
+        if filters is None and nested_filters is None and sub_expressions is None:
+            raise ValueError(
+                "You must specify at least one subexpression, filter, or nested filter"
+            )
+        self.filters = None if filters is None else list(map(lambda x: x.to_boto(), filters))
+        self.nested_filters = (
+            None if nested_filters is None else list(map(lambda x: x.to_boto(), nested_filters))
+        )
+        self.sub_expressions = (
+            None if sub_expressions is None else list(map(lambda x: x.to_boto(), sub_expressions))
+        )
+        self.operator = boolean_operator.value
diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py
index e668b2a8ed..9d28e3bf4e 100644
--- a/src/sagemaker/utils.py
+++ b/src/sagemaker/utils.py
@@ -29,6 +29,7 @@
 from datetime import datetime
 from typing import Optional
 
+from importlib import import_module
 import botocore
 from six.moves.urllib import parse
 
@@ -590,6 +591,27 @@ def retries(
     )
 
 
+def retry_with_backoff(callable_func, num_attempts=8):
+    """Retry with backoff until maximum attempts are reached
+
+    Args:
+        callable_func (callable): The callable function to retry.
+        num_attempts (int): The maximum number of attempts to retry.
+    """
+    if num_attempts < 1:
+        raise ValueError(
+            "The num_attempts must be >= 1, but the given value is {}.".format(num_attempts)
+        )
+    for i in range(num_attempts):
+        try:
+            return callable_func()
+        except Exception as ex:  # pylint: disable=broad-except
+            if i == num_attempts - 1:
+                raise ex
+            logger.error("Retrying in attempt %s, due to %s", (i + 1), str(ex))
+            time.sleep(2**i)
+
+
 def _botocore_resolver():
     """Get the DNS suffix for the given region.
 
@@ -874,3 +896,47 @@ def _start_waiting(waiting_time: int):
         print(progress, end="\r")
         time.sleep(interval)
     print(len(progress) * " ", end="\r")
+
+
+def get_module(module_name):
+    """Import a module.
+
+    Args:
+        module_name (str): name of the module to import.
+
+    Returns:
+        object: The imported module.
+
+    Raises:
+        Exception: when the module name is not found
+    """
+    try:
+        return import_module(module_name)
+    except ImportError:
+        raise Exception("Cannot import module {}, please try again.".format(module_name))
+
+
+def check_and_get_run_experiment_config(experiment_config: Optional[dict] = None) -> dict:
+    """Check user input experiment_config or get it from the current Run object if it exists.
+
+    Args:
+        experiment_config (dict): The experiment_config supplied by the user.
+
+    Returns:
+        dict: Return the user-supplied experiment_config if it is not None.
+        Otherwise fetch the experiment_config from the current Run object if it exists.
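To illustrate the search-expression DSL introduced in ``search_expression.py`` above, a minimal composition sketch follows. It is not part of the patch, and the filter field names and values are hypothetical examples.

.. code:: python

    # Sketch only: field names and values are hypothetical.
    from sagemaker.utilities.search_expression import (
        BooleanOperator,
        Filter,
        Operator,
        SearchExpression,
    )

    # Match entities whose ExperimentName contains "integ" AND whose
    # CreatedBy field exists.
    search_expression = SearchExpression(
        filters=[
            Filter(name="ExperimentName", operator=Operator.CONTAINS, value="integ"),
            Filter(name="CreatedBy", operator=Operator.EXISTS),
        ],
        boolean_operator=BooleanOperator.AND,
    )

    # to_boto() renders the nested dict shape expected by the SageMaker
    # Search API (see the boto3 link in the class docstring).
    print(search_expression.to_boto())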
+    """
+    from sagemaker.experiments._run_context import _RunContext
+
+    run_obj = _RunContext.get_current_run()
+    if experiment_config:
+        if run_obj:
+            logger.warning(
+                "The function is invoked within an Experiment Run context "
+                "but another experiment_config (%s) was supplied, so "
+                "ignoring the experiment_config fetched from the Run object.",
+                experiment_config,
+            )
+        return experiment_config
+
+    return run_obj.experiment_config if run_obj else None
diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py
index 8ba65f1eee..cdef9537c1 100644
--- a/src/sagemaker/workflow/_utils.py
+++ b/src/sagemaker/workflow/_utils.py
@@ -13,6 +13,7 @@
 """Scrapper utilities to support repacking of models."""
 from __future__ import absolute_import
 
+import logging
 import os
 import shutil
 import tarfile
@@ -37,6 +38,8 @@
 if TYPE_CHECKING:
     from sagemaker.workflow.step_collections import StepCollection
 
+logger = logging.getLogger(__name__)
+
 FRAMEWORK_VERSION = "0.23-1"
 INSTANCE_TYPE = "ml.m5.large"
 REPACK_SCRIPT = "_repack_model.py"
@@ -479,10 +482,19 @@ def arguments(self) -> RequestType:
         request_dict = get_create_model_package_request(**model_package_args)
         # these are not available in the workflow service and will cause rejection
+        warn_msg_template = (
+            "Popping out '%s' from the pipeline definition "
+            "since it will be overridden at pipeline execution time."
+        )
         if "CertifyForMarketplace" in request_dict:
             request_dict.pop("CertifyForMarketplace")
+            logger.warning(warn_msg_template, "CertifyForMarketplace")
         if "Description" in request_dict:
             request_dict.pop("Description")
+            logger.warning(warn_msg_template, "Description")
+        if "ModelPackageName" in request_dict:
+            request_dict.pop("ModelPackageName")
+            logger.warning(warn_msg_template, "ModelPackageName")
 
         return request_dict
diff --git a/src/sagemaker/workflow/clarify_check_step.py b/src/sagemaker/workflow/clarify_check_step.py
index 9d350b01f3..22b6fc2051 100644
--- a/src/sagemaker/workflow/clarify_check_step.py
+++ b/src/sagemaker/workflow/clarify_check_step.py
@@ -132,8 +132,8 @@ class ModelExplainabilityCheckConfig(ClarifyCheckConfig):
         model_config (ModelConfig): Config of the model and its endpoint to be created.
         explainability_config (SHAPConfig): Config of the specific explainability method.
             Currently, only SHAP is supported.
-        model_scores (str or int or ModelPredictedLabelConfig): Index or JSONPath location
-            in the model output for the predicted scores to be explained (default: None).
+        model_scores (str or int or ModelPredictedLabelConfig): Index or JMESPath expression
+            to locate the predicted scores in the model output (default: None).
             This is not required if the model output is a single score. Alternatively,
             an instance of ModelPredictedLabelConfig can be provided
             but this field CANNOT be any type of the `PipelineVariable`.
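A minimal usage sketch of the new ``retry_with_backoff`` helper added to ``src/sagemaker/utils.py`` above. The flaky operation shown is a hypothetical stand-in; only the helper itself comes from this change.

.. code:: python

    # Sketch only: the describe call is a hypothetical flaky operation.
    import boto3

    from sagemaker.utils import retry_with_backoff

    sagemaker_client = boto3.client("sagemaker")

    def flaky_operation():
        # Hypothetical stand-in for a call that can raise transient errors,
        # e.g. a throttled Describe* API call.
        return sagemaker_client.describe_experiment(ExperimentName="my-experiment")

    # Calls flaky_operation up to 5 times, sleeping 2**attempt seconds after
    # each failure, and re-raises the last exception if every attempt fails.
    result = retry_with_backoff(flaky_operation, num_attempts=5)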
diff --git a/src/sagemaker/workflow/utilities.py b/src/sagemaker/workflow/utilities.py index 89d7c5dfd9..08c170d424 100644 --- a/src/sagemaker/workflow/utilities.py +++ b/src/sagemaker/workflow/utilities.py @@ -114,11 +114,12 @@ def get_code_hash(step: Entity) -> str: if isinstance(step, ProcessingStep) and step.step_args: kwargs = step.step_args.func_kwargs source_dir = kwargs.get("source_dir") + submit_class = kwargs.get("submit_class") dependencies = get_processing_dependencies( [ kwargs.get("dependencies"), kwargs.get("submit_py_files"), - kwargs.get("submit_class"), + [submit_class] if submit_class else None, kwargs.get("submit_jars"), kwargs.get("submit_files"), ] @@ -168,7 +169,7 @@ def get_processing_code_hash(code: str, source_dir: str, dependencies: List[str] str: A hash string representing the unique code artifact(s) for the step """ - # FrameworkProcessor + # If FrameworkProcessor contains source_dir if source_dir: source_dir_url = urlparse(source_dir) if source_dir_url.scheme == "" or source_dir_url.scheme == "file": @@ -400,5 +401,5 @@ def execute_job_functions(step_args: _StepArguments): """ chained_args = step_args.func(*step_args.func_args, **step_args.func_kwargs) - if chained_args: + if isinstance(chained_args, _StepArguments): execute_job_functions(chained_args) diff --git a/tests/conftest.py b/tests/conftest.py index e92d98112b..f6682ebb8c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,6 +73,7 @@ "neo_pytorch", "neo_tensorflow", "pytorch", + "pytorch_training_compiler", "ray_pytorch", "ray_tensorflow", "sklearn", diff --git a/tests/data/experiment/inference.py b/tests/data/experiment/inference.py new file mode 100644 index 0000000000..cdb9a7b8c6 --- /dev/null +++ b/tests/data/experiment/inference.py @@ -0,0 +1,85 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +import logging +import os +import pickle as pkl + +import boto3 +import numpy as np +import sagemaker_xgboost_container.encoder as xgb_encoders + +sdk_name = "sagemaker-dev-1.0.tar.gz" +code_dir = "/opt/ml/code" + +sdk_file = f"{code_dir}/{sdk_name}" +os.system(f"pip install {sdk_file}") + +from sagemaker.session import Session +from sagemaker.experiments import load_run + +boto_session = boto3.Session(region_name=os.environ["AWS_REGION"]) +sagemaker_session = Session(boto_session=boto_session) + + +def model_fn(model_dir): + """ + Deserialize and return fitted model. 
+ """ + with load_run( + experiment_name=os.environ["EXPERIMENT_NAME"], + run_name=os.environ["RUN_NAME"], + sagemaker_session=sagemaker_session, + ) as run: + logging.info(f"Run name: {run.run_name}") + logging.info(f"Experiment name: {run.experiment_name}") + logging.info(f"Trial component name: {run._trial_component.trial_component_name}") + run.log_parameters({"p3": 3.0, "p4": 4.0}) + run.log_metric("test-job-load-log-metric", 0.1) + + model_file = "xgboost-model" + booster = pkl.load(open(os.path.join(model_dir, model_file), "rb")) + return booster + + +def input_fn(request_body, request_content_type): + """ + The SageMaker XGBoost model server receives the request data body and the content type, + and invokes the `input_fn`. + Return a DMatrix (an object that can be passed to predict_fn). + """ + if request_content_type == "text/libsvm": + return xgb_encoders.libsvm_to_dmatrix(request_body) + else: + raise ValueError("Content type {} is not supported.".format(request_content_type)) + + +def predict_fn(input_data, model): + """ + SageMaker XGBoost model server invokes `predict_fn` on the return value of `input_fn`. + Return a two-dimensional NumPy array where the first columns are predictions + and the remaining columns are the feature contributions (SHAP values) for that prediction. + """ + prediction = model.predict(input_data) + feature_contribs = model.predict(input_data, pred_contribs=True, validate_features=False) + output = np.hstack((prediction[:, np.newaxis], feature_contribs)) + return output + + +def output_fn(predictions, content_type): + """ + After invoking predict_fn, the model server invokes `output_fn`. + """ + if content_type == "text/csv" or content_type == "application/json": + return ",".join(str(x) for x in predictions[0]) + else: + raise ValueError("Content type {} is not supported.".format(content_type)) diff --git a/tests/data/experiment/process_job_script_for_run_clz.py b/tests/data/experiment/process_job_script_for_run_clz.py new file mode 100644 index 0000000000..32fd0ab4f6 --- /dev/null +++ b/tests/data/experiment/process_job_script_for_run_clz.py @@ -0,0 +1,37 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This script file runs on SageMaker processing job""" +from __future__ import absolute_import + +import logging +import os +import boto3 + +sdk_file = "sagemaker-dev-1.0.tar.gz" +os.system(f"pip install {sdk_file}") + + +from sagemaker import Session +from sagemaker.experiments import load_run + + +boto_session = boto3.Session(region_name=os.environ["AWS_REGION"]) +sagemaker_session = Session(boto_session=boto_session) + + +with load_run(sagemaker_session=sagemaker_session) as run: + logging.info(f"Run name: {run.run_name}") + logging.info(f"Experiment name: {run.experiment_name}") + logging.info(f"Trial component name: {run._trial_component.trial_component_name}") + run.log_parameters({"p3": 3.0, "p4": 4.0}) + run.log_metric("test-job-load-log-metric", 0.1) diff --git a/tests/data/experiment/train_job_script_for_run_clz.py b/tests/data/experiment/train_job_script_for_run_clz.py new file mode 100644 index 0000000000..34c86e0993 --- /dev/null +++ b/tests/data/experiment/train_job_script_for_run_clz.py @@ -0,0 +1,71 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This script file runs on SageMaker training job""" +from __future__ import absolute_import + +import logging +import time +import os +import boto3 + +sdk_file = "sagemaker-dev-1.0.tar.gz" +os.system(f"pip install {sdk_file}") + +from sagemaker import Session +from sagemaker.experiments import load_run, Run + +boto_session = boto3.Session(region_name=os.environ["AWS_REGION"]) +sagemaker_session = Session(boto_session=boto_session) + +if os.environ["RUN_OPERATION"] == "init": + logging.info("Initializing a Run") + with Run( + experiment_name=os.environ["EXPERIMENT_NAME"], + run_name=os.environ["RUN_NAME"], + sagemaker_session=sagemaker_session, + ) as run: + logging.info(f"Run name: {run.run_name}") + logging.info(f"Experiment name: {run.experiment_name}") + logging.info(f"Trial component name: {run._trial_component.trial_component_name}") + run.log_parameter("p1", 1.0) + run.log_parameter("p2", 2) + + for i in range(2): + run.log_metric("A", i) + for i in range(2): + run.log_metric("B", i) + for i in range(2): + run.log_metric("C", i) + for i in range(2): + time.sleep(0.003) + run.log_metric("D", i) + for i in range(2): + time.sleep(0.003) + run.log_metric("E", i) + time.sleep(15) + +else: + logging.info("Loading a Run") + logging.info("Invoking load_run with name arguments") + with load_run( + experiment_name=os.environ["EXPERIMENT_NAME"], + run_name=os.environ["RUN_NAME"], + sagemaker_session=sagemaker_session, + ) as run: + run.log_parameters({"p3": 3.0, "p4": 4}) + run.log_metric("test-job-load-log-metric", 0.1) + + if os.environ.get("CALL_RUN_LOAD_WITH_NO_NAME_ARGS", None) == "True": + logging.info("Invoking load_run without name arguments") + with load_run(sagemaker_session=sagemaker_session) as run: + run.log_parameters({"p5": 5.0, "p6": 6}) diff --git a/tests/data/experiment/transform_job_materials/data.csv b/tests/data/experiment/transform_job_materials/data.csv new file mode 100644 index 
0000000000..9f1b6c0bb0 --- /dev/null +++ b/tests/data/experiment/transform_job_materials/data.csv @@ -0,0 +1 @@ +-99 1:3 2:0.37 3:0.29 4:0.095 5:0.249 6:0.1045 7:0.058 8:0.067 \ No newline at end of file diff --git a/tests/data/experiment/transform_job_materials/xgb_model.tar.gz b/tests/data/experiment/transform_job_materials/xgb_model.tar.gz new file mode 100644 index 0000000000..3969bede9e Binary files /dev/null and b/tests/data/experiment/transform_job_materials/xgb_model.tar.gz differ diff --git a/tests/data/huggingface_byoc/requirements.txt b/tests/data/huggingface_byoc/requirements.txt new file mode 100644 index 0000000000..462542f1c1 --- /dev/null +++ b/tests/data/huggingface_byoc/requirements.txt @@ -0,0 +1,2 @@ +transformers +datasets diff --git a/tests/data/huggingface_byoc/run_glue.py b/tests/data/huggingface_byoc/run_glue.py new file mode 100644 index 0000000000..1060398fa4 --- /dev/null +++ b/tests/data/huggingface_byoc/run_glue.py @@ -0,0 +1,568 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Finetuning the library models for sequence classification on GLUE.""" +# You can also adapt this script on your own text classification task. Pointers for this are left as comments. + +import logging +import os +import random +import sys +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +from datasets import load_dataset, load_metric + +import transformers +from transformers import ( + AutoConfig, + AutoModelForSequenceClassification, + AutoTokenizer, + DataCollatorWithPadding, + EvalPrediction, + HfArgumentParser, + PretrainedConfig, + Trainer, + TrainingArguments, + default_data_collator, + set_seed, +) +from transformers.trainer_utils import get_last_checkpoint, is_main_process +from transformers.utils import check_min_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.5.0") + +task_to_keys = { + "cola": ("sentence", None), + "mnli": ("premise", "hypothesis"), + "mrpc": ("sentence1", "sentence2"), + "qnli": ("question", "sentence"), + "qqp": ("question1", "question2"), + "rte": ("sentence1", "sentence2"), + "sst2": ("sentence", None), + "stsb": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), +} + +logger = logging.getLogger(__name__) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + + Using `HfArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line. + """ + + task_name: Optional[str] = field( + default=None, + metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, + ) + max_seq_length: int = field( + default=128, + metadata={ + "help": "The maximum total input sequence length after tokenization. 
Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} + ) + pad_to_max_length: bool = field( + default=True, + metadata={ + "help": "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_val_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " + "value if set." + }, + ) + max_test_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of test examples to this " + "value if set." + }, + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the training data."} + ) + validation_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the validation data."} + ) + test_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the test data."} + ) + + def __post_init__(self): + if self.task_name is not None: + self.task_name = self.task_name.lower() + if self.task_name not in task_to_keys.keys(): + raise ValueError( + "Unknown task, you should pick one in " + ",".join(task_to_keys.keys()) + ) + elif self.train_file is None or self.validation_file is None: + raise ValueError("Need either a GLUE task or a training/validation file.") + else: + train_extension = self.train_file.split(".")[-1] + assert train_extension in [ + "csv", + "json", + ], "`train_file` should be a csv or a json file." + validation_extension = self.validation_file.split(".")[-1] + assert ( + validation_extension == train_extension + ), "`validation_file` should have the same extension (csv or json) as `train_file`." + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained config name or path if not the same as model_name"}, + ) + tokenizer_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={ + "help": "Where do you want to store the pretrained models downloaded from huggingface.co" + }, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={ + "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." + }, + ) + model_revision: str = field( + default="main", + metadata={ + "help": "The specific model version to use (can be a branch name, tag name or commit id)." + }, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." 
+ }, + ) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1]) + ) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Detecting last checkpoint. + last_checkpoint = None + if ( + os.path.isdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + # Set the verbosity to info of the Transformers logger (on main process only): + if is_main_process(training_args.local_rank): + transformers.utils.logging.set_verbosity_info() + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + logger.info(f"Training/evaluation parameters {training_args}") + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) + # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the + # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named + # label if at least two columns are provided. + # + # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this + # single column. You can easily tweak this behavior (see below) + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.task_name is not None: + # Downloading and loading a dataset from the hub. + datasets = load_dataset("glue", data_args.task_name) + else: + # Loading a dataset from your local files. + # CSV/JSON training and evaluation files are needed. 
+ data_files = {"train": data_args.train_file, "validation": data_args.validation_file} + + # Get the test dataset: you can provide your own CSV/JSON test file (see below) + # when you use `do_predict` without specifying a GLUE benchmark task. + if training_args.do_predict: + if data_args.test_file is not None: + train_extension = data_args.train_file.split(".")[-1] + test_extension = data_args.test_file.split(".")[-1] + assert ( + test_extension == train_extension + ), "`test_file` should have the same extension (csv or json) as `train_file`." + data_files["test"] = data_args.test_file + else: + raise ValueError("Need either a GLUE task or a test file for `do_predict`.") + + for key in data_files.keys(): + logger.info(f"load a local file for {key}: {data_files[key]}") + + if data_args.train_file.endswith(".csv"): + # Loading a dataset from local csv files + datasets = load_dataset("csv", data_files=data_files) + else: + # Loading a dataset from local json files + datasets = load_dataset("json", data_files=data_files) + # See more about loading any type of standard or custom dataset at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Labels + if data_args.task_name is not None: + is_regression = data_args.task_name == "stsb" + if not is_regression: + label_list = datasets["train"].features["label"].names + num_labels = len(label_list) + else: + num_labels = 1 + else: + # Trying to have good defaults here, don't hesitate to tweak to your needs. + is_regression = datasets["train"].features["label"].dtype in ["float32", "float64"] + if is_regression: + num_labels = 1 + else: + # A useful fast method: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique + label_list = datasets["train"].unique("label") + label_list.sort() # Let's sort it for determinism + num_labels = len(label_list) + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + num_labels=num_labels, + finetuning_task=data_args.task_name, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + model = AutoModelForSequenceClassification.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + # Preprocessing the datasets + if data_args.task_name is not None: + sentence1_key, sentence2_key = task_to_keys[data_args.task_name] + else: + # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 
+ non_label_column_names = [ + name for name in datasets["train"].column_names if name != "label" + ] + if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: + sentence1_key, sentence2_key = "sentence1", "sentence2" + else: + if len(non_label_column_names) >= 2: + sentence1_key, sentence2_key = non_label_column_names[:2] + else: + sentence1_key, sentence2_key = non_label_column_names[0], None + + # Padding strategy + if data_args.pad_to_max_length: + padding = "max_length" + else: + # We will pad later, dynamically at batch creation, to the max sequence length in each batch + padding = False + + # Some models have set the order of the labels to use, so let's make sure we do use it. + label_to_id = None + if ( + model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id + and data_args.task_name is not None + and not is_regression + ): + # Some have all caps in their config, some don't. + label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} + if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): + label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)} + else: + logger.warn( + "Your model seems to have been trained with labels, but they don't match the dataset: ", + f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." + "\nIgnoring the model labels as a result.", + ) + elif data_args.task_name is None and not is_regression: + label_to_id = {v: i for i, v in enumerate(label_list)} + + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warn( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 
+ ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + def preprocess_function(examples): + # Tokenize the texts + args = ( + (examples[sentence1_key],) + if sentence2_key is None + else (examples[sentence1_key], examples[sentence2_key]) + ) + result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) + + # Map labels to IDs (not necessary for GLUE tasks) + if label_to_id is not None and "label" in examples: + result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] + return result + + datasets = datasets.map( + preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache + ) + if training_args.do_train: + if "train" not in datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = datasets["train"] + if data_args.max_train_samples is not None: + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + + if training_args.do_eval: + if "validation" not in datasets and "validation_matched" not in datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = datasets[ + "validation_matched" if data_args.task_name == "mnli" else "validation" + ] + if data_args.max_val_samples is not None: + eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) + + if ( + training_args.do_predict + or data_args.task_name is not None + or data_args.test_file is not None + ): + if "test" not in datasets and "test_matched" not in datasets: + raise ValueError("--do_predict requires a test dataset") + test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"] + if data_args.max_test_samples is not None: + test_dataset = test_dataset.select(range(data_args.max_test_samples)) + + # Log a few random samples from the training set: + if training_args.do_train: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # Get the metric function + if data_args.task_name is not None: + metric = load_metric("glue", data_args.task_name) + # TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from + # compute_metrics + + # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a + # predictions and label_ids field) and has to return a dictionary string to float. + def compute_metrics(p: EvalPrediction): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) + if data_args.task_name is not None: + result = metric.compute(predictions=preds, references=p.label_ids) + if len(result) > 1: + result["combined_score"] = np.mean(list(result.values())).item() + return result + elif is_regression: + return {"mse": ((preds - p.label_ids) ** 2).mean().item()} + else: + return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} + + # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding. 
+ if data_args.pad_to_max_length: + data_collator = default_data_collator + elif training_args.fp16: + data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) + else: + data_collator = None + + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + compute_metrics=compute_metrics, + tokenizer=tokenizer, + data_collator=data_collator, + ) + + # Training + if training_args.do_train: + checkpoint = None + if last_checkpoint is not None: + checkpoint = last_checkpoint + elif os.path.isdir(model_args.model_name_or_path): + # Check the config from that potential checkpoint has the right number of labels before using it as a + # checkpoint. + if AutoConfig.from_pretrained(model_args.model_name_or_path).num_labels == num_labels: + checkpoint = model_args.model_name_or_path + + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples + if data_args.max_train_samples is not None + else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) + + trainer.save_model() # Saves the tokenizer too for easy upload + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluation + if training_args.do_eval: + logger.info("*** Evaluate ***") + + # Loop to handle MNLI double evaluation (matched, mis-matched) + tasks = [data_args.task_name] + eval_datasets = [eval_dataset] + if data_args.task_name == "mnli": + tasks.append("mnli-mm") + eval_datasets.append(datasets["validation_mismatched"]) + + for eval_dataset, task in zip(eval_datasets, tasks): + metrics = trainer.evaluate(eval_dataset=eval_dataset) + + max_val_samples = ( + data_args.max_val_samples + if data_args.max_val_samples is not None + else len(eval_dataset) + ) + metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + if training_args.do_predict: + logger.info("*** Test ***") + + # Loop to handle MNLI double evaluation (matched, mis-matched) + tasks = [data_args.task_name] + test_datasets = [test_dataset] + if data_args.task_name == "mnli": + tasks.append("mnli-mm") + test_datasets.append(datasets["test_mismatched"]) + + for test_dataset, task in zip(test_datasets, tasks): + # Removing the `label` columns because it contains -1 and Trainer won't like that. 
+            test_dataset.remove_columns_("label")
+            predictions = trainer.predict(test_dataset=test_dataset).predictions
+            predictions = (
+                np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
+            )
+
+            output_test_file = os.path.join(training_args.output_dir, f"test_results_{task}.txt")
+            if trainer.is_world_process_zero():
+                with open(output_test_file, "w") as writer:
+                    logger.info(f"***** Test results {task} *****")
+                    writer.write("index\tprediction\n")
+                    for index, item in enumerate(predictions):
+                        if is_regression:
+                            writer.write(f"{index}\t{item:3.3f}\n")
+                        else:
+                            item = label_list[item]
+                            writer.write(f"{index}\t{item}\n")
+
+
+def _mp_fn(index):
+    # For xla_spawn (TPUs)
+    main()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/data/huggingface_byoc/train/dummy.csv b/tests/data/huggingface_byoc/train/dummy.csv
new file mode 100644
index 0000000000..fb1539d552
--- /dev/null
+++ b/tests/data/huggingface_byoc/train/dummy.csv
@@ -0,0 +1 @@
+# dummy data
\ No newline at end of file
diff --git a/tests/data/multimodel/container/Dockerfile b/tests/data/multimodel/container/Dockerfile
index 4792a429c1..71c38a6605 100644
--- a/tests/data/multimodel/container/Dockerfile
+++ b/tests/data/multimodel/container/Dockerfile
@@ -1,4 +1,5 @@
-FROM public.ecr.aws/ubuntu/ubuntu:18.04
+# added latest image from https://gallery.ecr.aws/lts/ubuntu
+FROM public.ecr.aws/ubuntu/ubuntu:22.04
 
 # Set a docker label to advertise multi-model support on the container
 LABEL com.amazonaws.sagemaker.capabilities.multi-models=true
@@ -15,7 +16,7 @@ RUN apt-get update && \
     curl \
     vim \
     && rm -rf /var/lib/apt/lists/* \
-    && curl -O https://bootstrap.pypa.io/pip/3.6/get-pip.py \
+    && curl -O https://bootstrap.pypa.io/pip/get-pip.py \
     && python3 get-pip.py
 
 RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1
diff --git a/tests/data/pipeline/test_source_dir/script_1.py b/tests/data/pipeline/test_source_dir/script_1.py
new file mode 100644
index 0000000000..4a427b1898
--- /dev/null
+++ b/tests/data/pipeline/test_source_dir/script_1.py
@@ -0,0 +1,11 @@
+"""
+Integ test file script_1.py
+"""
+import pathlib
+
+if __name__ == "__main__":
+
+    print("writing file to /opt/ml/processing/test/test.py...")
+    pathlib.Path("/opt/ml/processing/test").mkdir(parents=True, exist_ok=True)
+    with open("/opt/ml/processing/test/test.py", "w") as f:
+        f.write('print("test...")')
diff --git a/tests/data/pipeline/test_source_dir/script_2.py b/tests/data/pipeline/test_source_dir/script_2.py
new file mode 100644
index 0000000000..6245dac987
--- /dev/null
+++ b/tests/data/pipeline/test_source_dir/script_2.py
@@ -0,0 +1,9 @@
+"""
+Integ test file script_2.py
+"""
+
+if __name__ == "__main__":
+
+    print("reading file: /opt/ml/processing/test/test.py")
+    with open("/opt/ml/processing/test/test.py", "r") as f:
+        print(f.read())
diff --git a/tests/data/pipeline/test_source_dir_2/script_2.py b/tests/data/pipeline/test_source_dir_2/script_2.py
new file mode 100644
index 0000000000..6245dac987
--- /dev/null
+++ b/tests/data/pipeline/test_source_dir_2/script_2.py
@@ -0,0 +1,9 @@
+"""
+Integ test file script_2.py
+"""
+
+if __name__ == "__main__":
+
+    print("reading file: /opt/ml/processing/test/test.py")
+    with open("/opt/ml/processing/test/test.py", "r") as f:
+        print(f.read())
diff --git a/tests/data/pytorch_neo/code/inference.py b/tests/data/pytorch_neo/code/inference.py
index 5b89c2bebc..79fe66d716 100644
--- a/tests/data/pytorch_neo/code/inference.py
+++ b/tests/data/pytorch_neo/code/inference.py
@@ -71,8
+71,8 @@ def model_fn(model_dir): logger.info("model_fn") neopytorch.config(model_dir=model_dir, neo_runtime=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # The compiled model is saved as "model.pth" - model = torch.jit.load(os.path.join(model_dir, "model.pth"), map_location=device) + # The compiled model is saved as "model.pth" or "model.pt" + model = torch.jit.load(os.path.join(model_dir, "model.pt"), map_location=device) # It is recommended to run warm-up inference during model load sample_input_path = os.path.join(model_dir, "sample_input.pkl") diff --git a/tests/data/spark/code/java/TestJarFile.jar b/tests/data/spark/code/java/TestJarFile.jar new file mode 100644 index 0000000000..d528331d55 Binary files /dev/null and b/tests/data/spark/code/java/TestJarFile.jar differ diff --git a/tests/data/spark/code/java/hello-java-spark/HelloJavaSparkApp.jar b/tests/data/spark/code/java/hello-java-spark/HelloJavaSparkApp.jar new file mode 100644 index 0000000000..056675146d Binary files /dev/null and b/tests/data/spark/code/java/hello-java-spark/HelloJavaSparkApp.jar differ diff --git a/tests/integ/__init__.py b/tests/integ/__init__.py index 00ed09577b..9133fc8904 100644 --- a/tests/integ/__init__.py +++ b/tests/integ/__init__.py @@ -158,7 +158,7 @@ "ap-northeast-1", "eu-central-1", ] -# TODO: SM Training Compiler team to add all supported regions. + TRAINING_COMPILER_SUPPORTED_REGIONS = [ "af-south-1", "ap-east-1", diff --git a/tests/integ/sagemaker/experiments/__init__.py b/tests/integ/sagemaker/experiments/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integ/sagemaker/experiments/conftest.py b/tests/integ/sagemaker/experiments/conftest.py new file mode 100644 index 0000000000..ca40a3ba6d --- /dev/null +++ b/tests/integ/sagemaker/experiments/conftest.py @@ -0,0 +1,177 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import glob +import logging +import os +import shutil +import tempfile +import time +import uuid + +import boto3 +import pytest + +from sagemaker.experiments import Run +from tests.integ import DATA_DIR + +from sagemaker.experiments import trial_component, trial, experiment +from sagemaker.utils import retry_with_backoff, unique_name_from_base +from tests.integ.sagemaker.experiments.helpers import name, names + +TAGS = [{"Key": "some-key", "Value": "some-value"}] +EXP_NAME_BASE_IN_LOCAL = "Job-Exp-in-Local" +RUN_NAME_IN_LOCAL = "job-run-in-local" + + +@pytest.fixture(scope="module") +def run_obj(sagemaker_session): + run = Run( + experiment_name=unique_name_from_base(EXP_NAME_BASE_IN_LOCAL), + run_name=RUN_NAME_IN_LOCAL, + sagemaker_session=sagemaker_session, + ) + try: + yield run + time.sleep(0.5) + finally: + exp = experiment._Experiment.load( + experiment_name=run.experiment_name, sagemaker_session=sagemaker_session + ) + exp._delete_all(action="--force") + + +@pytest.fixture(scope="module") +def trial_component_obj(sagemaker_session): + trial_component_obj = trial_component._TrialComponent.create( + trial_component_name=name(), + sagemaker_session=sagemaker_session, + tags=TAGS, + ) + yield trial_component_obj + time.sleep(0.5) + _delete_associations(trial_component_obj.trial_component_arn, sagemaker_session) + retry_with_backoff(trial_component_obj.delete) + + +@pytest.fixture(scope="module") +def experiment_obj(sagemaker_session): + client = sagemaker_session.sagemaker_client + description = "{}-{}".format("description", str(uuid.uuid4())) + boto3.set_stream_logger("", logging.INFO) + experiment_name = name() + experiment_obj = experiment._Experiment.create( + experiment_name=experiment_name, + description=description, + sagemaker_session=sagemaker_session, + tags=TAGS, + ) + yield experiment_obj + time.sleep(0.5) + experiment_obj.delete() + with pytest.raises(client.exceptions.ResourceNotFound): + client.describe_experiment(ExperimentName=experiment_name) + + +@pytest.fixture(scope="module") +def trial_obj(sagemaker_session, experiment_obj): + trial_obj = trial._Trial.create( + trial_name=name(), + experiment_name=experiment_obj.experiment_name, + tags=TAGS, + sagemaker_session=sagemaker_session, + ) + yield trial_obj + time.sleep(0.5) + trial_obj.delete() + + +@pytest.fixture(scope="module") +def trials(experiment_obj, sagemaker_session): + trial_objs = [] + for trial_name in names(): + next_trial = trial._Trial.create( + trial_name=trial_name, + experiment_name=experiment_obj.experiment_name, + sagemaker_session=sagemaker_session, + ) + trial_objs.append(next_trial) + time.sleep(0.5) + yield trial_objs + for trial_obj in trial_objs: + trial_obj.delete() + + +@pytest.fixture(scope="module") +def trial_component_with_force_disassociation_obj(trials, sagemaker_session): + trial_component_obj = trial_component._TrialComponent.create( + trial_component_name=name(), sagemaker_session=sagemaker_session + ) + for trial_obj in trials: + sagemaker_session.sagemaker_client.associate_trial_component( + TrialName=trial_obj.trial_name, + TrialComponentName=trial_component_obj.trial_component_name, + ) + yield trial_component_obj + time.sleep(0.5) + trial_component_obj.delete(force_disassociate=True) + + +@pytest.fixture(scope="module") +def trial_components(sagemaker_session): + trial_component_objs = [ + trial_component._TrialComponent.create( + trial_component_name=trial_component_name, + sagemaker_session=sagemaker_session, + ) + for 
trial_component_name in names() + ] + yield trial_component_objs + for trial_component_obj in trial_component_objs: + trial_component_obj.delete() + + +@pytest.fixture(scope="module") +def tempdir(): + temp_dir = tempfile.mkdtemp() + yield temp_dir + shutil.rmtree(temp_dir) + + +_EXP_PLUS_SDK_TAR = "sagemaker-dev-1.0.tar.gz" + + +@pytest.fixture(scope="module") +def dev_sdk_tar(): + resource_dir = os.path.join(DATA_DIR, "experiment") + os.system("python setup.py sdist") + sdist_path = max(glob.glob("dist/sagemaker-*"), key=os.path.getctime) + sdk_file = os.path.join(resource_dir, _EXP_PLUS_SDK_TAR) + shutil.copy(sdist_path, sdk_file) + return sdk_file + + +def _delete_associations(arn, sagemaker_session): + client = sagemaker_session.sagemaker_client + outgoing_associations = client.list_associations(SourceArn=arn)["AssociationSummaries"] + incoming_associations = client.list_associations(DestinationArn=arn)["AssociationSummaries"] + associations = [] + if outgoing_associations: + associations.extend(outgoing_associations) + if incoming_associations: + associations.extend(incoming_associations) + for association in associations: + source_arn = association["SourceArn"] + destination_arn = association["DestinationArn"] + client.delete_association(SourceArn=source_arn, DestinationArn=destination_arn) diff --git a/tests/integ/sagemaker/experiments/helpers.py b/tests/integ/sagemaker/experiments/helpers.py new file mode 100644 index 0000000000..b5e8064b08 --- /dev/null +++ b/tests/integ/sagemaker/experiments/helpers.py @@ -0,0 +1,42 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from contextlib import contextmanager + +from sagemaker import utils +from sagemaker.experiments.experiment import _Experiment + +EXP_INTEG_TEST_NAME_PREFIX = "experiments-integ" + + +def name(): + return utils.unique_name_from_base(EXP_INTEG_TEST_NAME_PREFIX) + + +def names(): + return [utils.unique_name_from_base(EXP_INTEG_TEST_NAME_PREFIX) for i in range(3)] + + +def to_seconds(dt): + return int(dt.timestamp()) + + +@contextmanager +def cleanup_exp_resources(exp_names, sagemaker_session): + try: + yield + finally: + for exp_name in exp_names: + exp = _Experiment.load(experiment_name=exp_name, sagemaker_session=sagemaker_session) + exp._delete_all(action="--force") diff --git a/tests/integ/sagemaker/experiments/test_experiment.py b/tests/integ/sagemaker/experiments/test_experiment.py new file mode 100644 index 0000000000..ff7d5fac37 --- /dev/null +++ b/tests/integ/sagemaker/experiments/test_experiment.py @@ -0,0 +1,56 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from sagemaker.experiments import experiment
+from tests.integ.sagemaker.experiments.helpers import name
+
+
+def test_create_delete(experiment_obj):
+    # The fixture creates and deletes the experiment; just ensure it is used at least once
+    assert experiment_obj.experiment_name
+
+
+def test_create_tags(experiment_obj, sagemaker_session):
+    client = sagemaker_session.sagemaker_client
+    # Tags may take a moment to materialize; poll until they show up.
+    while True:
+        actual_tags = client.list_tags(ResourceArn=experiment_obj.experiment_arn)["Tags"]
+        if actual_tags:
+            break
+    # Filter out auto-generated AWS tags (avoid mutating the list while iterating over it).
+    actual_tags = [tag for tag in actual_tags if "aws:tag" not in tag.get("Key")]
+    assert actual_tags == experiment_obj.tags
+
+
+def test_save(experiment_obj):
+    description = name()
+    experiment_obj.description = description
+    experiment_obj.save()
+
+
+def test_save_load(experiment_obj, sagemaker_session):
+    experiment_obj_two = experiment._Experiment.load(
+        experiment_name=experiment_obj.experiment_name, sagemaker_session=sagemaker_session
+    )
+    assert experiment_obj.experiment_name == experiment_obj_two.experiment_name
+    assert experiment_obj.description == experiment_obj_two.description
+
+    experiment_obj.description = name()
+    experiment_obj.display_name = name()
+    experiment_obj.save()
+    experiment_obj_three = experiment._Experiment.load(
+        experiment_name=experiment_obj.experiment_name, sagemaker_session=sagemaker_session
+    )
+    assert experiment_obj.description == experiment_obj_three.description
+    assert experiment_obj.display_name == experiment_obj_three.display_name
diff --git a/tests/integ/sagemaker/experiments/test_metrics.py b/tests/integ/sagemaker/experiments/test_metrics.py
new file mode 100644
index 0000000000..15c0c2f9dc
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/test_metrics.py
@@ -0,0 +1,39 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
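+"""Integration test for logging trial component metrics via _MetricsManager."""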
+from __future__ import absolute_import
+import random
+from sagemaker.experiments._metrics import _MetricsManager
+from sagemaker.experiments.trial_component import _TrialComponent
+from sagemaker.utils import retry_with_backoff
+
+
+def test_end_to_end(trial_component_obj, sagemaker_session):
+    # The fixture creates and deletes the trial component; this test logs metrics against it
+    with _MetricsManager(trial_component_obj.trial_component_name, sagemaker_session) as mm:
+        for i in range(100):
+            mm.log_metric("test-x-step", random.random(), step=i)
+            mm.log_metric("test-x-timestamp", random.random())
+
+    def verify_metrics():
+        updated_tc = _TrialComponent.load(
+            trial_component_name=trial_component_obj.trial_component_name,
+            sagemaker_session=sagemaker_session,
+        )
+        metrics = updated_tc.metrics
+        # TODO: revert to len(metrics) == 2 once backend fix reaches prod
+        assert len(metrics) > 0
+        assert list(filter(lambda x: x.metric_name == "test-x-step", metrics))
+        assert list(filter(lambda x: x.metric_name == "test-x-timestamp", metrics))
+
+    # allow time for the metrics to propagate to the metrics (eureka) backend
+    retry_with_backoff(verify_metrics)
diff --git a/tests/integ/sagemaker/experiments/test_run.py b/tests/integ/sagemaker/experiments/test_run.py
new file mode 100644
index 0000000000..713a6a3792
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/test_run.py
@@ -0,0 +1,662 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
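+"""Integration tests for sagemaker.experiments.Run: runs created or loaded
+locally and from within training, processing, and transform jobs."""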
+from __future__ import absolute_import + +import datetime +import os + +import pytest + +from tests.integ.sagemaker.experiments.conftest import TAGS +from sagemaker.experiments._api_types import _TrialComponentStatusType +from sagemaker.experiments._utils import is_run_trial_component +from sagemaker.processing import FrameworkProcessor +from sagemaker.pytorch import PyTorch +from sagemaker.s3 import S3Uploader +from sagemaker.xgboost import XGBoostModel +from tests.integ import DATA_DIR +from sagemaker.experiments._metrics import BATCH_SIZE +from sagemaker.experiments.trial_component import _TrialComponent +from sagemaker.sklearn import SKLearn +from sagemaker.utils import retry_with_backoff, unique_name_from_base +from tests.integ.sagemaker.experiments.helpers import name, cleanup_exp_resources +from sagemaker.experiments.run import ( + RUN_NAME_BASE, + DELIMITER, +) +from sagemaker.experiments import Run, load_run, list_runs +from sagemaker.experiments._helper import _DEFAULT_ARTIFACT_PREFIX + + +# when running integration tests locally modify this to your test account's execution role +EXECUTION_ROLE = "SageMakerRole" + + +@pytest.fixture +def artifact_file_path(tempdir): + file_contents = "test artifact file" + file_path = os.path.join(tempdir, "artifact_file.txt") + with open(file_path, "w") as foo_file: + foo_file.write(file_contents) + return file_path + + +artifact_name = unique_name_from_base("Test-Artifact") +file_artifact_name = f"File-Artifact-{name()}" +metric_name = "Test-Local-Init-Log-Metric" + + +def test_local_run_with_load(sagemaker_session, artifact_file_path): + exp_name = f"My-Local-Exp-{name()}" + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + # Run name is not provided, will create a new TC + with Run(experiment_name=exp_name, sagemaker_session=sagemaker_session) as run1: + run1_name = run1.run_name + assert RUN_NAME_BASE in run1_name + _local_run_log_behaviors( + artifact_file_path=artifact_file_path, + sagemaker_session=sagemaker_session, + ) + + def verify_load_run(): + with load_run( + experiment_name=exp_name, + run_name=run1_name, + sagemaker_session=sagemaker_session, + ) as run2: + assert run2.run_name == run1_name + assert ( + run2._trial_component.trial_component_name + == f"{run2.experiment_name}{DELIMITER}{run1_name}" + ) + _check_run_from_local_end_result( + sagemaker_session=sagemaker_session, tc=run2._trial_component + ) + + # Add retry to make sure metrics -> eureka propagation is consistent + retry_with_backoff(verify_load_run, 4) + + +def test_two_local_run_init_with_same_run_name_and_different_exp_names(sagemaker_session): + exp_name1 = f"my-two-local-exp1-{name()}" + exp_name2 = f"my-two-local-exp2-{name()}" + run_name = "test-run" + with cleanup_exp_resources( + exp_names=[exp_name1, exp_name2], sagemaker_session=sagemaker_session + ): + # Run name is not provided, will create a new TC + with Run( + experiment_name=exp_name1, run_name=run_name, sagemaker_session=sagemaker_session + ) as run1: + pass + with Run( + experiment_name=exp_name2, run_name=run_name, sagemaker_session=sagemaker_session + ) as run2: + pass + + assert run1.experiment_name != run2.experiment_name + assert run1.run_name == run2.run_name + assert ( + run1._trial_component.trial_component_name != run2._trial_component.trial_component_name + ) + assert run1._trial_component.trial_component_name == f"{exp_name1}{DELIMITER}{run_name}" + assert run2._trial_component.trial_component_name == f"{exp_name2}{DELIMITER}{run_name}" + + 
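+# For orientation, the round trip that the local-run tests above and below
+# exercise looks roughly like this (names are illustrative, not the unique
+# names the suite generates):
+#
+#     with Run(experiment_name="my-exp", run_name="my-run",
+#              sagemaker_session=sagemaker_session) as run:
+#         run.log_parameter("lr", 0.1)  # stored on the run's trial component
+#         run.log_metric(name="loss", value=0.42, step=0)
+#
+#     # Later, possibly from another process or job, reopen the run by name:
+#     with load_run(experiment_name="my-exp", run_name="my-run",
+#                   sagemaker_session=sagemaker_session) as run:
+#         run.log_parameter("epochs", 3)
+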
+@pytest.mark.parametrize( + "input_names", + [ + (f"my-local-exp-{name()}", "test-run", None), # both have delimiter - + ("my-test-1", "my-test-1", None), # exp_name equals run_name + ("my-test-3", "my-test-3-run", None), # is subset of run_name + ("x" * 59, "test-run", None), # long exp_name + ("test-exp", "y" * 59, None), # long run_name + ("e" * 59, "y" * 59, None), # long exp_name and run_name + ("my-test4", "test-run", "run-display-name-test"), # with supplied display name + ], +) +def test_run_name_vs_trial_component_name_edge_cases(sagemaker_session, input_names): + exp_name, run_name, run_display_name = input_names + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + with Run( + experiment_name=exp_name, + sagemaker_session=sagemaker_session, + run_name=run_name, + run_display_name=run_display_name, + ) as run1: + assert not run1._experiment.tags + assert not run1._trial.tags + is_run_tc = is_run_trial_component( + trial_component_name=run1._trial_component.trial_component_name, + sagemaker_session=sagemaker_session, + ) + assert is_run_tc + + with load_run( + experiment_name=exp_name, + run_name=run_name, + sagemaker_session=sagemaker_session, + ) as run2: + assert run2.experiment_name == exp_name + assert run2.run_name == run_name + assert run2._trial_component.trial_component_name == f"{exp_name}{DELIMITER}{run_name}" + assert run2._trial_component.display_name in ( + run_display_name, + run2._trial_component.trial_component_name, + ) + + +_EXP_NAME_BASE_IN_SCRIPT = "job-exp-in-script" +_RUN_NAME_IN_SCRIPT = "job-run-in-script" + +_EXP_DIR = os.path.join(DATA_DIR, "experiment") +_ENTRY_POINT_PATH = os.path.join(_EXP_DIR, "train_job_script_for_run_clz.py") +_PYTHON_PROCESS_SCRIPT = "process_job_script_for_run_clz.py" +_TRANSFORM_MATERIALS = os.path.join(_EXP_DIR, "transform_job_materials") + +_RUN_INIT = "init" +_RUN_LOAD = "load" + + +def test_run_from_local_and_train_job_and_all_exp_cfg_match(sagemaker_session, dev_sdk_tar): + # Notes: + # 1. The 1st Run TC created locally and its exp config was auto passed to the job + # 2. In training job, the same exp and run names are given in the Run constructor + # which will load the 1st Run TC in training job and log parameters + # and metrics there + # 3. In a different training job, load the same Run TC and log more parameters there. 
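+    # (While the local `with Run(...)` block below is active, _RunContext injects
+    # an experiment_config into estimator.fit(), roughly of the form
+    # {"ExperimentName": ..., "TrialName": ..., "RunName": ...}, which is how the
+    # in-job Run/load_run resolves the same trial component.)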
+ exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + estimator = _generate_estimator( + sdk_tar=dev_sdk_tar, sagemaker_session=sagemaker_session, exp_name=exp_name + ) + tc_name = Run._generate_trial_component_name( + experiment_name=exp_name, run_name=_RUN_NAME_IN_SCRIPT + ) + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + with Run( + experiment_name=exp_name, + run_name=_RUN_NAME_IN_SCRIPT, + sagemaker_session=sagemaker_session, + ) as run: + init_start_time = _check_tc_status_when_entering(run._trial_component) + _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session) + # experiment_config is auto passed in by _RunContext + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + old_end_time = _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + sagemaker_session=sagemaker_session, + ) + + _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + old_end_time=old_end_time, + sagemaker_session=sagemaker_session, + ) + assert run.experiment_name == exp_name + assert run.run_name == _RUN_NAME_IN_SCRIPT + _check_run_from_local_end_result( + tc=run._trial_component, + sagemaker_session=sagemaker_session, + is_complete_log=False, + ) + _check_run_from_job_result( + tc_name=tc_name, + sagemaker_session=sagemaker_session, + ) + + with run: + estimator.environment["RUN_OPERATION"] = _RUN_LOAD + estimator.environment["CALL_RUN_LOAD_WITH_NO_NAME_ARGS"] = "True" + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + + old_end_time = _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + old_end_time=old_end_time, + sagemaker_session=sagemaker_session, + ) + + _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + old_end_time=old_end_time, + sagemaker_session=sagemaker_session, + ) + _check_run_from_job_result( + tc_name=tc_name, + sagemaker_session=sagemaker_session, + is_init=False, + has_extra_load=True, + ) + + +def test_run_from_local_and_train_job_and_exp_cfg_not_match(sagemaker_session, dev_sdk_tar): + # Notes: + # 1. The 1st Run TC created locally and its exp config was auto passed to the job + # 2. In training job, different exp and run names (i.e. 2nd Run TC) are given + # in the Run constructor which will create a Run TC according to the run_name + # passed in there and ignore the exp config in the job + # 3. Both metrics and parameters are logged in the Run TC created in job + # 4. In a different training job, load the 2nd Run TC and log more parameters there. 
+ exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + exp_name2 = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + estimator = _generate_estimator( + sdk_tar=dev_sdk_tar, sagemaker_session=sagemaker_session, exp_name=exp_name + ) + tc_name = Run._generate_trial_component_name( + experiment_name=exp_name, run_name=_RUN_NAME_IN_SCRIPT + ) + + with cleanup_exp_resources( + exp_names=[exp_name, exp_name2], sagemaker_session=sagemaker_session + ): + with Run( + experiment_name=exp_name2, + run_name=f"{_RUN_NAME_IN_SCRIPT}2", + sagemaker_session=sagemaker_session, + ) as run: + init_start_time = _check_tc_status_when_entering(run._trial_component) + # experiment_config is auto passed in by _RunContext + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + _check_tc_status_intermediate( + trial_component=run._trial_component, + sagemaker_session=sagemaker_session, + init_start_time=init_start_time, + ) + + old_end_time = _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + sagemaker_session=sagemaker_session, + ) + assert run.experiment_name != exp_name + assert run.run_name != _RUN_NAME_IN_SCRIPT + _check_run_from_job_result( + tc_name=tc_name, + sagemaker_session=sagemaker_session, + ) + + with run: + estimator.environment["RUN_OPERATION"] = _RUN_LOAD + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + _check_tc_status_intermediate( + trial_component=run._trial_component, + sagemaker_session=sagemaker_session, + init_start_time=init_start_time, + old_end_time=old_end_time, + ) + + _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + old_end_time=old_end_time, + sagemaker_session=sagemaker_session, + ) + _check_run_from_job_result( + tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False + ) + + +def test_run_from_train_job_only(sagemaker_session, dev_sdk_tar): + # Notes: + # 1. No Run TC created locally or specified in experiment config + # 2. In training job, Run is initialized + # which will create a Run TC according to the run_name passed in there + # 3. Both metrics and parameters are logged in the Run TC created in job + # 4. In a different training job, load the same Run TC and log more parameters there. 
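+    # (No local run context exists in this test, so nothing is auto-passed by
+    # _RunContext; the in-job Run is driven entirely by the EXPERIMENT_NAME /
+    # RUN_NAME environment variables set in _generate_estimator.)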
+ exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + estimator = _generate_estimator( + sdk_tar=dev_sdk_tar, + sagemaker_session=sagemaker_session, + exp_name=exp_name, + ) + tc_name = Run._generate_trial_component_name( + experiment_name=exp_name, run_name=_RUN_NAME_IN_SCRIPT + ) + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + _check_run_from_job_result( + tc_name=tc_name, + sagemaker_session=sagemaker_session, + ) + + estimator.environment["RUN_OPERATION"] = _RUN_LOAD + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + _check_run_from_job_result( + tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False + ) + + +# dev_sdk_tar is required to trigger generating the dev SDK tar +def test_run_from_processing_job_and_override_default_exp_config( + sagemaker_session, dev_sdk_tar, run_obj +): + # Notes: + # 1. The 1st Run TC (run) created locally + # 2. Within the 2nd Run TC (run_obj)'s context, invoke processor.run + # but override the default experiment config in context of 2nd Run TC + # with the experiment config of the 1st Run TC + # 3. In the processing job script, load the 1st Run TC via the experiment config + # fetched from the job env + # 4. All data are logged in the Run TC either locally or in the processing job + exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + processor = FrameworkProcessor( + estimator_cls=PyTorch, + framework_version="1.10", + py_version="py38", + instance_count=1, + instance_type="ml.m5.xlarge", + role=EXECUTION_ROLE, + sagemaker_session=sagemaker_session, + ) + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + with Run( + experiment_name=exp_name, + run_name=_RUN_NAME_IN_SCRIPT, + sagemaker_session=sagemaker_session, + ) as run: + _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session) + + with run_obj: + # Override the default experiment_config in _RunContext of run_obj + # with the experiment_config of run + processor.run( + code=_PYTHON_PROCESS_SCRIPT, + source_dir=_EXP_DIR, + job_name=f"process-job-{name()}", + wait=True, # wait the job to finish + logs=False, + experiment_config=run.experiment_config, + ) + + assert run_obj.experiment_name != run.experiment_name + assert run_obj.run_name != run.run_name + _check_run_from_local_end_result( + tc=run._trial_component, + sagemaker_session=sagemaker_session, + is_complete_log=False, + ) + tc_name = Run._generate_trial_component_name( + experiment_name=run.experiment_name, run_name=run.run_name + ) + _check_run_from_job_result( + tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False + ) + + with run_obj: + # Not to override the exp config and use the default one in the context + processor.run( + code=_PYTHON_PROCESS_SCRIPT, + source_dir=_EXP_DIR, + job_name=f"process-job-{name()}", + wait=True, # wait the job to finish + logs=False, + ) + + tc_name = Run._generate_trial_component_name( + experiment_name=run_obj.experiment_name, run_name=run_obj.run_name + ) + _check_run_from_job_result( + tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False + ) + + +# dev_sdk_tar is required to trigger generating the dev SDK tar +def 
test_run_from_transform_job(sagemaker_session, dev_sdk_tar, run_obj, xgboost_latest_version): + # Notes: + # 1. The 1st Run TC (run) created locally + # 2. In the inference script running in a transform job, load the 1st Run TC + # via explicitly passing the experiment_name and run_name of the 1st Run TC + # TODO: once we're able to retrieve exp config from the transform job env, + # we should expand this test and add the load_run() without explicitly supplying the names + # 3. All data are logged in the Run TC either locally or in the transform job + xgb_model_data_s3 = sagemaker_session.upload_data( + path=os.path.join(_TRANSFORM_MATERIALS, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + xgboost_model = XGBoostModel( + sagemaker_session=sagemaker_session, + model_data=xgb_model_data_s3, + role=EXECUTION_ROLE, + entry_point="inference.py", + source_dir=_EXP_DIR, + framework_version=xgboost_latest_version, + env={ + "EXPERIMENT_NAME": run_obj.experiment_name, + "RUN_NAME": run_obj.run_name, + }, + ) + transformer = xgboost_model.transformer( + instance_count=1, + instance_type="ml.m5.4xlarge", + max_concurrent_transforms=5, + max_payload=1, + strategy="MultiRecord", + ) + uri = "s3://{}/{}/input/data/{}".format( + sagemaker_session.default_bucket(), + "transform-test", + unique_name_from_base("json-data"), + ) + input_data = S3Uploader.upload( + os.path.join(_TRANSFORM_MATERIALS, "data.csv"), uri, sagemaker_session=sagemaker_session + ) + + with run_obj: + _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session) + transformer.transform( + data=input_data, + content_type="text/libsvm", + split_type="Line", + wait=True, + job_name=f"transform-job-{name()}", + ) + + _check_run_from_local_end_result( + tc=run_obj._trial_component, + sagemaker_session=sagemaker_session, + is_complete_log=False, + ) + tc_name = Run._generate_trial_component_name( + experiment_name=run_obj.experiment_name, run_name=run_obj.run_name + ) + _check_run_from_job_result(tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False) + + +def test_list(run_obj, sagemaker_session): + tc1 = _TrialComponent.create( + trial_component_name=f"non-run-tc1-{name()}", + sagemaker_session=sagemaker_session, + ) + tc2 = _TrialComponent.create( + trial_component_name=f"non-run-tc2-{name()}", + sagemaker_session=sagemaker_session, + tags=TAGS, + ) + run_obj._trial.add_trial_component(tc1) + run_obj._trial.add_trial_component(tc2) + + run_tcs = list_runs( + experiment_name=run_obj.experiment_name, sagemaker_session=sagemaker_session + ) + assert len(run_tcs) == 1 + assert run_tcs[0].run_name == run_obj.run_name + assert run_tcs[0].experiment_name == run_obj.experiment_name + assert run_tcs[0].experiment_config == run_obj.experiment_config + + +def _generate_estimator(exp_name, sdk_tar, sagemaker_session): + return SKLearn( + framework_version="0.23-1", + entry_point=_ENTRY_POINT_PATH, + dependencies=[sdk_tar], + role=EXECUTION_ROLE, + instance_type="ml.m5.large", + instance_count=1, + volume_size=10, + max_run=900, + enable_sagemaker_metrics=True, + environment={ + "EXPERIMENT_NAME": exp_name, + "RUN_NAME": _RUN_NAME_IN_SCRIPT, + "RUN_OPERATION": _RUN_INIT, + }, + sagemaker_session=sagemaker_session, + ) + + +def _local_run_log_behaviors( + sagemaker_session, + artifact_file_path=None, + is_complete_log=True, +): + with load_run(sagemaker_session=sagemaker_session) as run: + run.log_parameter("pa", 1.0) + run.log_parameter("pb", "p2-value") + run.log_parameters({"pc": 2.0, 
"pd": "p4-value"}) + + if is_complete_log: + run.log_file(file_path=artifact_file_path, name=file_artifact_name) + run.log_artifact(name=artifact_name, value="s3://Output") + run.log_artifact(name=artifact_name, value="s3://Input", is_output=False) + + for i in range(BATCH_SIZE): + run.log_metric(name=metric_name, value=i, step=i) + + +def _check_run_from_local_end_result(sagemaker_session, tc, is_complete_log=True): + assert tc.parameters == {"pa": 1.0, "pb": "p2-value", "pc": 2.0, "pd": "p4-value"} + + if not is_complete_log: + return + + s3_prefix = f"s3://{sagemaker_session.default_bucket()}/{_DEFAULT_ARTIFACT_PREFIX}" + assert s3_prefix in tc.output_artifacts[file_artifact_name].value + assert "text/plain" == tc.output_artifacts[file_artifact_name].media_type + assert "s3://Output" == tc.output_artifacts[artifact_name].value + assert not tc.output_artifacts[artifact_name].media_type + assert "s3://Input" == tc.input_artifacts[artifact_name].value + assert not tc.input_artifacts[artifact_name].media_type + + # TODO: revert to len(tc.metrics) == 1 once backend fix reaches prod + assert len(tc.metrics) > 0 + metric_summary = tc.metrics[0] + assert metric_summary.metric_name == metric_name + assert metric_summary.max == 9.0 + assert metric_summary.min == 0.0 + + +def _check_run_from_job_result(sagemaker_session, tc_name=None, is_init=True, has_extra_load=False): + def validate_tc_updated_in_init(): + assert tc.start_time + assert tc.end_time + assert tc.status.primary_status == _TrialComponentStatusType.Completed.value + assert tc.parameters["p1"] == 1.0 + assert tc.parameters["p2"] == 2.0 + # TODO: revert to assert len(tc.metrics) == 5 once + # backend fix hits prod + assert len(tc.metrics) > 0 + for metric_summary in tc.metrics: + # metrics deletion is not supported at this point + # so its count would accumulate + assert metric_summary.count > 0 + assert metric_summary.min == 0.0 + assert metric_summary.max == 1.0 + + def validate_tc_updated_in_load(): + assert tc.parameters["p3"] == 3.0 + assert tc.parameters["p4"] == 4.0 + assert len(tc.metrics) > 0 + for metric_summary in tc.metrics: + if metric_summary.metric_name != "test-job-load-log-metric": + continue + assert metric_summary.last == 0.1 + assert metric_summary.max == 0.1 + assert metric_summary.min == 0.1 + if has_extra_load: + assert tc.parameters["p5"] == 5.0 + assert tc.parameters["p6"] == 6.0 + + tc = _TrialComponent.load(trial_component_name=tc_name, sagemaker_session=sagemaker_session) + if is_init: + # Add retry since the load behavior is inconsistent sometimes + retry_with_backoff(validate_tc_updated_in_init, 4) + else: + retry_with_backoff(validate_tc_updated_in_load, 4) + + +def _check_tc_status_when_entering(trial_component): + assert isinstance(trial_component.start_time, datetime.datetime) + assert not trial_component.end_time + assert trial_component.status.primary_status == _TrialComponentStatusType.InProgress.value + return trial_component.start_time + + +def _check_tc_status_when_exiting( + trial_component_name, sagemaker_session, init_start_time, old_end_time=None +): + tc = _TrialComponent.load( + trial_component_name=trial_component_name, sagemaker_session=sagemaker_session + ) + # There will be deviation (< 1s) caused by different TS precisions used in Backend and SDK + assert abs(tc.start_time.timestamp() - init_start_time.timestamp()) < 1 + assert tc.status.primary_status == _TrialComponentStatusType.Completed.value + assert isinstance(tc.end_time, datetime.datetime) + if old_end_time: + assert 
tc.end_time > old_end_time
+    return tc.end_time
+
+
+def _check_tc_status_intermediate(
+    trial_component, sagemaker_session, init_start_time, old_end_time=None
+):
+    tc_load = _TrialComponent.load(
+        trial_component_name=trial_component.trial_component_name,
+        sagemaker_session=sagemaker_session,
+    )
+    assert abs(tc_load.start_time.timestamp() - init_start_time.timestamp()) < 1
+    assert tc_load.status.primary_status == _TrialComponentStatusType.InProgress.value
+    if not old_end_time:
+        assert not trial_component.end_time
+        return
+    assert isinstance(tc_load.end_time, datetime.datetime)
+    assert tc_load.end_time == old_end_time
diff --git a/tests/integ/sagemaker/experiments/test_trial.py b/tests/integ/sagemaker/experiments/test_trial.py
new file mode 100644
index 0000000000..08f646c086
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/test_trial.py
@@ -0,0 +1,75 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import logging
+
+from sagemaker.experiments import trial
+from sagemaker.utils import retry_with_backoff
+
+
+def test_create_delete(trial_obj):
+    # The fixture creates and deletes the trial; just ensure it is used at least once.
+    assert trial_obj.trial_name
+
+
+def test_create_tags(trial_obj, sagemaker_session):
+    client = sagemaker_session.sagemaker_client
+    # Tags may take a moment to materialize; poll until they show up.
+    while True:
+        actual_tags = client.list_tags(ResourceArn=trial_obj.trial_arn)["Tags"]
+        if actual_tags:
+            break
+    # Filter out auto-generated AWS tags (avoid mutating the list while iterating over it).
+    actual_tags = [tag for tag in actual_tags if "aws:tag" not in tag.get("Key")]
+    assert actual_tags == trial_obj.tags
+
+
+def test_save_load(trial_obj, sagemaker_session):
+    trial_obj.display_name = "foo"
+    trial_obj.save()
+    assert (
+        "foo"
+        == trial._Trial.load(
+            trial_name=trial_obj.trial_name,
+            sagemaker_session=sagemaker_session,
+        ).display_name
+    )
+
+
+def test_add_remove_trial_component(trial_obj, trial_component_obj):
+    trial_obj.add_trial_component(trial_component_obj)
+    logging.info(
+        f"Added trial component {trial_component_obj.trial_component_name} to trial {trial_obj.trial_name}"
+    )
+
+    def validate_add():
+        trial_components = list(trial_obj.list_trial_components())
+        assert 1 == len(
+            trial_components
+        ), "Expected the trial component to appear in the trial's list of trial components"
+
+    retry_with_backoff(validate_add)
+
+    trial_obj.remove_trial_component(trial_component_obj)
+    logging.info(
+        f"Removed trial component {trial_component_obj.trial_component_name} from trial {trial_obj.trial_name}"
+    )
+
+    def validate_remove():
+        trial_components = list(trial_obj.list_trial_components())
+        assert 0 == len(
+            trial_components
+        ), "Expected the trial component to be removed from the trial's list of trial components"
+
+    retry_with_backoff(validate_remove)
diff --git a/tests/integ/sagemaker/experiments/test_trial_component.py b/tests/integ/sagemaker/experiments/test_trial_component.py
new file mode 100644
index 0000000000..3d79e41cc4
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/test_trial_component.py
@@ -0,0 +1,144 @@
+# Copyright Amazon.com, Inc. or its affiliates.
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import datetime +import uuid + +from sagemaker.experiments._api_types import _TrialComponentStatusType +from tests.integ.sagemaker.experiments.helpers import EXP_INTEG_TEST_NAME_PREFIX +from sagemaker.experiments import _api_types, trial_component +from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression + + +def test_create_delete(trial_component_obj): + # Fixture does create / delete, just need to ensure called at least once + assert trial_component_obj.trial_component_name + assert trial_component_obj.input_artifacts == {} + assert trial_component_obj.parameters == {} + assert trial_component_obj.output_artifacts == {} + + +def test_create_tags(trial_component_obj, sagemaker_session): + client = sagemaker_session.sagemaker_client + while True: + actual_tags = client.list_tags(ResourceArn=trial_component_obj.trial_component_arn)["Tags"] + if actual_tags: + break + for tag in actual_tags: + if "aws:tag" in tag.get("Key"): + actual_tags.remove(tag) + assert actual_tags == trial_component_obj.tags + + +def test_delete_with_force_disassociate( + trial_component_with_force_disassociation_obj, sagemaker_session +): + assert trial_component_with_force_disassociation_obj.trial_component_name + trials = sagemaker_session.sagemaker_client.list_trials( + TrialComponentName=trial_component_with_force_disassociation_obj.trial_component_name + )["TrialSummaries"] + assert len(trials) == 3 + + +def test_save(trial_component_obj, sagemaker_session): + trial_component_obj.display_name = str(uuid.uuid4()) + trial_component_obj.status = _api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, message="Message" + ) + trial_component_obj.start_time = datetime.datetime.now( + datetime.timezone.utc + ) - datetime.timedelta(days=1) + trial_component_obj.end_time = datetime.datetime.now(datetime.timezone.utc) + trial_component_obj.parameters = {"foo": "bar", "whizz": 100.1} + trial_component_obj.input_artifacts = { + "snizz": _api_types.TrialComponentArtifact(value="s3:/foo/bar", media_type="text/plain"), + "snizz1": _api_types.TrialComponentArtifact(value="s3:/foo/bar2", media_type="text/plain2"), + } + trial_component_obj.output_artifacts = { + "fly": _api_types.TrialComponentArtifact(value="s3:/sky/far", media_type="away/tomorrow"), + "fly2": _api_types.TrialComponentArtifact( + value="s3:/sky/far2", media_type="away/tomorrow2" + ), + } + trial_component_obj.parameters_to_remove = ["foo"] + trial_component_obj.input_artifacts_to_remove = ["snizz"] + trial_component_obj.output_artifacts_to_remove = ["fly2"] + + trial_component_obj.save() + + loaded = trial_component._TrialComponent.load( + trial_component_name=trial_component_obj.trial_component_name, + sagemaker_session=sagemaker_session, + ) + + assert trial_component_obj.trial_component_name == loaded.trial_component_name + assert trial_component_obj.status == loaded.status + + assert trial_component_obj.start_time - loaded.start_time < 
datetime.timedelta(seconds=1) + assert trial_component_obj.end_time - loaded.end_time < datetime.timedelta(seconds=1) + + assert loaded.parameters == {"whizz": 100.1} + assert loaded.input_artifacts == { + "snizz1": _api_types.TrialComponentArtifact(value="s3:/foo/bar2", media_type="text/plain2") + } + assert loaded.output_artifacts == { + "fly": _api_types.TrialComponentArtifact(value="s3:/sky/far", media_type="away/tomorrow") + } + + +def test_load(trial_component_obj, sagemaker_session): + loaded = trial_component._TrialComponent.load( + trial_component_name=trial_component_obj.trial_component_name, + sagemaker_session=sagemaker_session, + ) + assert trial_component_obj.trial_component_arn == loaded.trial_component_arn + + +def test_list_sort(trial_components, sagemaker_session): + slack = datetime.timedelta(minutes=1) + now = datetime.datetime.now(datetime.timezone.utc) + trial_component_names = [tc.trial_component_name for tc in trial_components] + + for sort_order in ["Ascending", "Descending"]: + trial_component_names_listed = [ + s.trial_component_name + for s in trial_component._TrialComponent.list( + created_after=now - slack, + created_before=now + slack, + sort_by="CreationTime", + sort_order=sort_order, + sagemaker_session=sagemaker_session, + ) + if s.trial_component_name in trial_component_names + ] + + if sort_order == "Descending": + trial_component_names_listed = trial_component_names_listed[::-1] + assert trial_component_names == trial_component_names_listed + assert trial_component_names # sanity test + + +def test_search(sagemaker_session): + trial_component_names_searched = [] + search_filter = Filter( + name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX + ) + search_expression = SearchExpression(filters=[search_filter]) + for s in trial_component._TrialComponent.search( + search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session + ): + trial_component_names_searched.append(s.trial_component_name) + + assert len(trial_component_names_searched) > 0 + assert trial_component_names_searched # sanity test diff --git a/tests/integ/sagemaker/lineage/conftest.py b/tests/integ/sagemaker/lineage/conftest.py index 3c416ffd36..abfe6f6d0d 100644 --- a/tests/integ/sagemaker/lineage/conftest.py +++ b/tests/integ/sagemaker/lineage/conftest.py @@ -26,6 +26,7 @@ artifact, ) from sagemaker.model import ModelPackage +from sagemaker.utils import retry_with_backoff from tests.integ.sagemaker.workflow.test_workflow import ( test_end_to_end_pipeline_successful_execution, ) @@ -43,7 +44,7 @@ ) from sagemaker.lineage.lineage_trial_component import LineageTrialComponent -from tests.integ.sagemaker.lineage.helpers import name, names, retry +from tests.integ.sagemaker.lineage.helpers import name, names SLEEP_TIME_SECONDS = 1 SLEEP_TIME_TWO_SECONDS = 2 @@ -400,7 +401,7 @@ def model_obj(sagemaker_session): yield model time.sleep(SLEEP_TIME_SECONDS) - retry(lambda: model.delete(disassociate=True), num_attempts=4) + retry_with_backoff(lambda: model.delete(disassociate=True), num_attempts=4) @pytest.fixture diff --git a/tests/integ/sagemaker/lineage/helpers.py b/tests/integ/sagemaker/lineage/helpers.py index fb71d1d88c..5548c63cff 100644 --- a/tests/integ/sagemaker/lineage/helpers.py +++ b/tests/integ/sagemaker/lineage/helpers.py @@ -15,7 +15,6 @@ import uuid from datetime import datetime -import time def name(): @@ -33,19 +32,6 @@ def names(): ] -def retry(callable, num_attempts=8): - assert num_attempts >= 1 - for i in 
range(num_attempts): - try: - return callable() - except Exception as ex: - if i == num_attempts - 1: - raise ex - print("Retrying", ex) - time.sleep(2**i) - assert False, "logic error in retry" - - def traverse_graph_back(start_arn, sagemaker_session): def visit(arn, visited: set): visited.add(arn) diff --git a/tests/integ/sagemaker/lineage/test_artifact.py b/tests/integ/sagemaker/lineage/test_artifact.py index c629fcdc30..1980b51da2 100644 --- a/tests/integ/sagemaker/lineage/test_artifact.py +++ b/tests/integ/sagemaker/lineage/test_artifact.py @@ -20,7 +20,7 @@ import pytest from sagemaker.lineage import artifact -from tests.integ.sagemaker.lineage.helpers import retry +from sagemaker.utils import retry_with_backoff def test_create_delete(artifact_obj): @@ -125,7 +125,7 @@ def validate(): assert len(trials) == 1 assert trial_obj.trial_name in trials - retry(validate, num_attempts=3) + retry_with_backoff(validate, num_attempts=3) def test_downstream_trials_v2(trial_associated_artifact, trial_obj, sagemaker_session): diff --git a/tests/integ/sagemaker/utilities/__init__.py b/tests/integ/sagemaker/utilities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integ/sagemaker/utilities/test_search_expression.py b/tests/integ/sagemaker/utilities/test_search_expression.py new file mode 100644 index 0000000000..ea7f4476bf --- /dev/null +++ b/tests/integ/sagemaker/utilities/test_search_expression.py @@ -0,0 +1,67 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
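+"""Integration tests composing Search queries from Filter, NestedFilter,
+and sub-expressions in sagemaker.utilities.search_expression."""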
+from __future__ import absolute_import + +import pytest + +from tests.integ.sagemaker.experiments.helpers import EXP_INTEG_TEST_NAME_PREFIX +from sagemaker.experiments.trial_component import _TrialComponent +from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression, NestedFilter + + +def test_search(sagemaker_session): + tc_names_searched = [] + search_filter = Filter( + name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX + ) + search_expression = SearchExpression(filters=[search_filter]) + for tc in _TrialComponent.search( + search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session + ): + tc_names_searched.append(tc.trial_component_name) + + assert len(tc_names_searched) > 0 + assert tc_names_searched + + +@pytest.mark.skip(reason="failed validation, need to wait for NestedFilter bug to be fixed") +def test_nested_search(sagemaker_session): + tc_names_searched = [] + search_filter = Filter( + name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX + ) + nested_filter = NestedFilter(property_name="TrialComponentName", filters=[search_filter]) + search_expression = SearchExpression(nested_filters=[nested_filter]) + for tc in _TrialComponent.search( + search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session + ): + tc_names_searched.append(tc.trial_component_name) + + assert len(tc_names_searched) > 0 + assert tc_names_searched + + +def test_sub_expression(sagemaker_session): + tc_names_searched = [] + search_filter = Filter( + name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX + ) + sub_expression = SearchExpression(filters=[search_filter]) + search_expression = SearchExpression(sub_expressions=[sub_expression]) + for tc in _TrialComponent.search( + search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session + ): + tc_names_searched.append(tc.trial_component_name) + + assert len(tc_names_searched) > 0 + assert tc_names_searched diff --git a/tests/integ/sagemaker/workflow/test_model_steps.py b/tests/integ/sagemaker/workflow/test_model_steps.py index 31c518b100..f25723c440 100644 --- a/tests/integ/sagemaker/workflow/test_model_steps.py +++ b/tests/integ/sagemaker/workflow/test_model_steps.py @@ -112,6 +112,7 @@ def test_pytorch_training_model_registration_and_creation_without_custom_inferen inference_instances=["ml.m5.xlarge"], transform_instances=["ml.m5.xlarge"], description="test-description", + model_package_name="model-pkg-name-will-be-popped-out", ) step_model_regis = ModelStep( name="pytorch-register-model", diff --git a/tests/integ/sagemaker/workflow/test_processing_steps.py b/tests/integ/sagemaker/workflow/test_processing_steps.py index 781bce85a7..238eff6123 100644 --- a/tests/integ/sagemaker/workflow/test_processing_steps.py +++ b/tests/integ/sagemaker/workflow/test_processing_steps.py @@ -17,15 +17,18 @@ import re import subprocess from datetime import datetime +from pathlib import Path import pytest from botocore.exceptions import WaiterError +from sagemaker.workflow.utilities import hash_files_or_dirs, hash_object from sagemaker import image_uris, get_execution_role, utils from sagemaker.dataset_definition import DatasetDefinition, AthenaDatasetDefinition -from sagemaker.processing import ProcessingInput, ProcessingOutput -from sagemaker.s3 import S3Uploader -from sagemaker.sklearn import SKLearnProcessor +from sagemaker.processing import ProcessingInput, 
ProcessingOutput, FrameworkProcessor +from sagemaker.s3 import S3Uploader, S3Downloader +from sagemaker.sklearn import SKLearnProcessor, SKLearn +from sagemaker.tensorflow import TensorFlow from sagemaker.workflow.parameters import ParameterInteger, ParameterString from sagemaker.workflow.pipeline import Pipeline from sagemaker.workflow.steps import ( @@ -379,6 +382,203 @@ def test_one_step_framework_processing_pipeline( pass +def test_multi_step_framework_processing_pipeline_same_source_dir( + pipeline_session, role, pipeline_name +): + default_bucket = pipeline_session.default_bucket() + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + + SOURCE_DIR = "/pipeline/test_source_dir" + + framework_processor_tf = FrameworkProcessor( + role=role, + instance_type="ml.m5.xlarge", + instance_count=1, + estimator_cls=TensorFlow, + framework_version="2.9", + py_version="py39", + sagemaker_session=pipeline_session, + ) + + framework_processor_sk = FrameworkProcessor( + framework_version="1.0-1", + instance_type="ml.m5.xlarge", + instance_count=1, + base_job_name="my-job", + role=role, + estimator_cls=SKLearn, + sagemaker_session=pipeline_session, + ) + + step_1 = ProcessingStep( + name="Step-1", + step_args=framework_processor_tf.run( + code="script_1.py", + source_dir=DATA_DIR + SOURCE_DIR, + outputs=[ProcessingOutput(output_name="test", source="/opt/ml/processing/test")], + ), + cache_config=cache_config, + ) + + step_2 = ProcessingStep( + name="Step-2", + step_args=framework_processor_sk.run( + code="script_2.py", + source_dir=DATA_DIR + SOURCE_DIR, + inputs=[ + ProcessingInput( + source=step_1.properties.ProcessingOutputConfig.Outputs["test"].S3Output.S3Uri, + destination="/opt/ml/processing/test", + ), + ], + ), + cache_config=cache_config, + ) + + pipeline = Pipeline( + name=pipeline_name, steps=[step_1, step_2], sagemaker_session=pipeline_session + ) + try: + pipeline.create(role) + definition = json.loads(pipeline.definition()) + + source_dir_1_s3_uri, entry_point_1 = _verify_code_artifacts_of_framework_processing_step( + pipeline_session, + framework_processor_tf, + default_bucket, + pipeline_name, + definition["Steps"][0], + SOURCE_DIR, + "script_1.py", + ) + source_dir_2_s3_uri, entry_point_2 = _verify_code_artifacts_of_framework_processing_step( + pipeline_session, + framework_processor_sk, + default_bucket, + pipeline_name, + definition["Steps"][1], + SOURCE_DIR, + "script_2.py", + ) + + # the same local source_dirs should have the same s3 paths + assert source_dir_1_s3_uri == source_dir_2_s3_uri + + # verify different entry_point paths + assert entry_point_1 != entry_point_2 + + execution = pipeline.start(parameters={}) + try: + execution.wait(delay=540, max_attempts=3) + except WaiterError: + pass + + execution_steps = execution.list_steps() + assert len(execution_steps) == 2 + for step in execution_steps: + assert step["StepStatus"] == "Succeeded" + + finally: + try: + pipeline.delete() + except Exception: + pass + + +def test_multi_step_framework_processing_pipeline_different_source_dir( + pipeline_session, role, pipeline_name +): + default_bucket = pipeline_session.default_bucket() + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + + SOURCE_DIR_1 = "/pipeline/test_source_dir" + SOURCE_DIR_2 = "/pipeline/test_source_dir_2" + + framework_processor_tf = FrameworkProcessor( + role=role, + instance_type="ml.m5.xlarge", + instance_count=1, + estimator_cls=TensorFlow, + framework_version="2.9", + py_version="py39", + 
sagemaker_session=pipeline_session, + ) + + step_1 = ProcessingStep( + name="Step-1", + step_args=framework_processor_tf.run( + code="script_1.py", + source_dir=DATA_DIR + SOURCE_DIR_1, + outputs=[ProcessingOutput(output_name="test", source="/opt/ml/processing/test")], + ), + cache_config=cache_config, + ) + + step_2 = ProcessingStep( + name="Step-2", + step_args=framework_processor_tf.run( + code="script_2.py", + source_dir=DATA_DIR + SOURCE_DIR_2, + inputs=[ + ProcessingInput( + source=step_1.properties.ProcessingOutputConfig.Outputs["test"].S3Output.S3Uri, + destination="/opt/ml/processing/test", + ), + ], + ), + cache_config=cache_config, + ) + + pipeline = Pipeline( + name=pipeline_name, steps=[step_1, step_2], sagemaker_session=pipeline_session + ) + try: + pipeline.create(role) + definition = json.loads(pipeline.definition()) + + source_dir_1_s3_uri, entry_point_1 = _verify_code_artifacts_of_framework_processing_step( + pipeline_session, + framework_processor_tf, + default_bucket, + pipeline_name, + definition["Steps"][0], + SOURCE_DIR_1, + "script_1.py", + ) + source_dir_2_s3_uri, entry_point_2 = _verify_code_artifacts_of_framework_processing_step( + pipeline_session, + framework_processor_tf, + default_bucket, + pipeline_name, + definition["Steps"][1], + SOURCE_DIR_2, + "script_2.py", + ) + + # different local source_dirs should have different s3 paths + assert source_dir_1_s3_uri != source_dir_2_s3_uri + + # verify different entry_point paths + assert entry_point_1 != entry_point_2 + + execution = pipeline.start(parameters={}) + try: + execution.wait(delay=540, max_attempts=3) + except WaiterError: + pass + + execution_steps = execution.list_steps() + assert len(execution_steps) == 2 + for step in execution_steps: + assert step["StepStatus"] == "Succeeded" + + finally: + try: + pipeline.delete() + except Exception: + pass + + def test_one_step_pyspark_processing_pipeline( sagemaker_session, role, @@ -796,3 +996,46 @@ def test_two_processing_job_depends_on( pipeline.delete() except Exception: pass + + +def _verify_code_artifacts_of_framework_processing_step( + pipeline_session, processor, bucket, pipeline_name, step_definition, source_dir, entry_point +): + + source_dir_s3_uri = ( + f"s3://{bucket}/{pipeline_name}" f"/code/{hash_files_or_dirs([f'{DATA_DIR}/{source_dir}'])}" + ) + + # verify runproc.sh prefix is different from code artifact prefix + runprocs = [] + for input_obj in step_definition["Arguments"]["ProcessingInputs"]: + if input_obj["InputName"] == "entrypoint": + s3_uri = input_obj["S3Input"]["S3Uri"] + runprocs.append(s3_uri) + + assert Path(s3_uri).parent != source_dir_s3_uri + + # verify only one entrypoint generated per step + assert len(runprocs) == 1 + + expected_source_dir_tar = ( + f"{pipeline_name}" + f"/code/{hash_files_or_dirs([DATA_DIR + '/pipeline/test_source_dir'])}/sourcedir.tar.gz" + ) + + step_script = processor._generate_framework_script(entry_point) + expected_step_artifact = f"{pipeline_name}/code/{hash_object(step_script)}/runproc.sh" + + expected_prefix = f"{pipeline_name}/code" + s3_code_objects = pipeline_session.list_s3_files(bucket=bucket, key_prefix=expected_prefix) + + # verify all distinct artifacts were uploaded + assert expected_source_dir_tar in s3_code_objects + assert expected_step_artifact in s3_code_objects + + # verify runprocs contain the correct commands + step_runproc = S3Downloader.read_file( + f"s3://{bucket}/{expected_step_artifact}", pipeline_session + ) + assert f"python {entry_point}" in step_runproc + return source_dir, 
expected_step_artifact diff --git a/tests/integ/sagemaker/workflow/test_workflow.py b/tests/integ/sagemaker/workflow/test_workflow.py index 634ef752d6..bd24b653ae 100644 --- a/tests/integ/sagemaker/workflow/test_workflow.py +++ b/tests/integ/sagemaker/workflow/test_workflow.py @@ -1168,7 +1168,13 @@ def walk(): def test_caching_behavior( - pipeline_session, role, cpu_instance_type, pipeline_name, script_dir, athena_dataset_definition + pipeline_session, + role, + cpu_instance_type, + pipeline_name, + script_dir, + athena_dataset_definition, + region_name, ): default_bucket = pipeline_session.default_bucket() data_path = os.path.join(DATA_DIR, "workflow") @@ -1263,8 +1269,6 @@ def test_caching_behavior( # create pipeline pipeline.create(role) definition = json.loads(pipeline.definition()) - # delete profiler config for assertions as it will contain a timestamp - del definition["Steps"][1]["Arguments"]["ProfilerRuleConfigurations"] # verify input path expected_abalone_input_path = f"{pipeline_name}/{step_process.name}" f"/input/abalone_data" @@ -1289,7 +1293,6 @@ def test_caching_behavior( # verify no changes definition2 = json.loads(pipeline.definition()) - del definition2["Steps"][1]["Arguments"]["ProfilerRuleConfigurations"] assert definition == definition2 # add dummy file to source_dir @@ -1300,7 +1303,6 @@ def test_caching_behavior( # verify changes definition3 = json.loads(pipeline.definition()) - del definition3["Steps"][1]["Arguments"]["ProfilerRuleConfigurations"] assert definition != definition3 finally: diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index c1b84117c3..e19cebdca4 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -14,6 +14,7 @@ import json import time +import datetime from contextlib import contextmanager import boto3 @@ -24,6 +25,7 @@ from sagemaker.feature_store.feature_definition import FractionalFeatureDefinition from sagemaker.feature_store.feature_group import FeatureGroup +from sagemaker.feature_store.feature_store import FeatureStore from sagemaker.feature_store.inputs import FeatureValue, FeatureParameter, TableFormatEnum from sagemaker.session import get_execution_role, Session from tests.integ.timeout import timeout @@ -80,6 +82,11 @@ def feature_group_name(): return f"my-feature-group-{int(time.time() * 10**7)}" +@pytest.fixture +def base_name(): + return f"my-base-{int(time.time() * 10**7)}" + + @pytest.fixture def offline_store_s3_uri(feature_store_session, region_name): bucket = f"sagemaker-test-featurestore-{region_name}-{feature_store_session.account_id()}" @@ -107,6 +114,32 @@ def pandas_data_frame(): return df +@pytest.fixture +def base_dataframe(): + base_data = [ + [1, 187512346.0, 123, 128], + [2, 187512347.0, 168, 258], + [3, 187512348.0, 125, 184], + [1, 187512349.0, 195, 206], + ] + return pd.DataFrame( + base_data, columns=["base_id", "base_time", "base_feature_1", "base_feature_2"] + ) + + +@pytest.fixture +def feature_group_dataframe(): + feature_group_data = [ + [1, 187512246.0, 456, 325], + [2, 187512247.0, 729, 693], + [3, 187512348.0, 129, 901], + [1, 187512449.0, 289, 286], + ] + return pd.DataFrame( + feature_group_data, columns=["fg_id", "fg_time", "fg_feature_1", "fg_feature_2"] + ) + + @pytest.fixture def pandas_data_frame_without_string(): df = pd.DataFrame( @@ -288,6 +321,92 @@ def test_create_feature_group_glue_table_format( assert table_format == "Glue" +def test_get_record( + feature_store_session, + role, + feature_group_name, + pandas_data_frame, + 
record, +): + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group.load_feature_definitions(data_frame=pandas_data_frame) + + record_identifier_value_as_string = record[0].value_as_string + with cleanup_feature_group(feature_group): + feature_group.create( + s3_uri=False, + record_identifier_name="feature1", + event_time_feature_name="feature3", + role_arn=role, + enable_online_store=True, + ) + _wait_for_feature_group_create(feature_group) + # Ingest data + feature_group.put_record(record=record) + # Retrieve data + retrieved_record = feature_group.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + ) + record_names = list(map(lambda r: r.feature_name, record)) + assert len(retrieved_record) == len(record_names) + for feature in retrieved_record: + assert feature["FeatureName"] in record_names + removed_feature_name = record_names.pop() + # Retrieve data + retrieved_record = feature_group.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + feature_names=record_names, + ) + assert len(retrieved_record) == len(record_names) + for feature in retrieved_record: + assert feature["FeatureName"] in record_names + assert feature["FeatureName"] is not removed_feature_name + # Retrieve data + retrieved_record = feature_group.get_record( + record_identifier_value_as_string="1.0", + ) + assert retrieved_record is None + + +def test_delete_record( + feature_store_session, + role, + feature_group_name, + pandas_data_frame, + record, +): + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group.load_feature_definitions(data_frame=pandas_data_frame) + + record_identifier_value_as_string = record[0].value_as_string + with cleanup_feature_group(feature_group): + feature_group.create( + s3_uri=False, + record_identifier_name="feature1", + event_time_feature_name="feature3", + role_arn=role, + enable_online_store=True, + ) + _wait_for_feature_group_create(feature_group) + # Ingest data + feature_group.put_record(record=record) + # Retrieve data + retrieved_record = feature_group.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + ) + assert retrieved_record is not None + # Delete data + feature_group.delete_record( + record_identifier_value_as_string=record_identifier_value_as_string, + event_time=datetime.datetime.now().replace(microsecond=0).isoformat() + "Z", + ) + # Retrieve data + retrieved_record = feature_group.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + ) + assert retrieved_record is None + + def test_update_feature_group( feature_store_session, role, @@ -316,6 +435,25 @@ def test_update_feature_group( assert any([True for elem in feature_definitions if new_feature_name in elem.values()]) +def test_list_feature_groups(feature_store_session, role, feature_group_name, pandas_data_frame): + feature_store = FeatureStore(sagemaker_session=feature_store_session) + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group.load_feature_definitions(data_frame=pandas_data_frame) + + with cleanup_feature_group(feature_group): + feature_group.create( + s3_uri=False, + record_identifier_name="feature1", + event_time_feature_name="feature3", + role_arn=role, + enable_online_store=True, + ) + _wait_for_feature_group_create(feature_group) + output = feature_store.list_feature_groups(name_contains=feature_group_name) 
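+        # list_feature_groups returns the raw ListFeatureGroups response: a dict
+        # holding "FeatureGroupSummaries" (and "NextToken" when results paginate).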
+ + assert output["FeatureGroupSummaries"][0]["FeatureGroupName"] == feature_group_name + + def test_feature_metadata( feature_store_session, role, @@ -420,6 +558,242 @@ def test_ingest_multi_process( assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}") +def test_create_dataset_with_feature_group_base( + feature_store_session, + region_name, + role, + base_name, + feature_group_name, + offline_store_s3_uri, + base_dataframe, + feature_group_dataframe, +): + base = FeatureGroup(name=base_name, sagemaker_session=feature_store_session) + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + with cleanup_feature_group(base), cleanup_feature_group(feature_group): + _create_feature_group_and_ingest_data( + base, base_dataframe, offline_store_s3_uri, "base_id", "base_time", role + ) + _create_feature_group_and_ingest_data( + feature_group, feature_group_dataframe, offline_store_s3_uri, "fg_id", "fg_time", role + ) + base_table_name = _get_athena_table_name_after_data_replication( + feature_store_session, base, offline_store_s3_uri + ) + feature_group_table_name = _get_athena_table_name_after_data_replication( + feature_store_session, feature_group, offline_store_s3_uri + ) + + with timeout(minutes=10) and cleanup_offline_store( + base_table_name, feature_store_session + ) and cleanup_offline_store(feature_group_table_name, feature_store_session): + feature_store = FeatureStore(sagemaker_session=feature_store_session) + df, query_string = ( + feature_store.create_dataset(base=base, output_path=offline_store_s3_uri) + .with_number_of_recent_records_by_record_identifier(4) + .with_feature_group(feature_group) + .to_dataframe() + ) + sorted_df = df.sort_values(by=list(df.columns)).reset_index(drop=True) + merged_df = base_dataframe.merge( + feature_group_dataframe, left_on="base_id", right_on="fg_id" + ) + + expect_df = merged_df.sort_values(by=list(merged_df.columns)).reset_index(drop=True) + + expect_df.rename( + columns={ + "fg_id": "fg_id.1", + "fg_time": "fg_time.1", + "fg_feature_1": "fg_feature_1.1", + "fg_feature_2": "fg_feature_2.1", + }, + inplace=True, + ) + + assert sorted_df.equals(expect_df) + assert ( + query_string + == "WITH fg_base AS (WITH table_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."base_id", origin_base."base_time"\n' + + 'ORDER BY origin_base."api_invocation_time" DESC, origin_base."write_time" DESC\n' + + ") AS dedup_row_base\n" + + f'FROM "sagemaker_featurestore"."{base_table_name}" origin_base\n' + + ")\n" + + "WHERE dedup_row_base = 1\n" + + "),\n" + + "deleted_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."base_id"\n' + + 'ORDER BY origin_base."base_time" DESC,' + ' origin_base."api_invocation_time" DESC,' + ' origin_base."write_time" DESC\n' + + ") AS deleted_row_base\n" + + f'FROM "sagemaker_featurestore"."{base_table_name}" origin_base\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_base = 1\n" + + ")\n" + + 'SELECT table_base."base_id", table_base."base_time",' + ' table_base."base_feature_1", table_base."base_feature_2"\n' + + "FROM (\n" + + 'SELECT table_base."base_id", table_base."base_time",' + ' table_base."base_feature_1", table_base."base_feature_2",' + ' table_base."write_time"\n' + + "FROM table_base\n" + + "LEFT JOIN deleted_base\n" + + 'ON table_base."base_id" = deleted_base."base_id"\n' + + 'WHERE deleted_base."base_id" IS NULL\n' + + 
"UNION ALL\n" + + 'SELECT table_base."base_id", table_base."base_time",' + ' table_base."base_feature_1", table_base."base_feature_2",' + ' table_base."write_time"\n' + + "FROM deleted_base\n" + + "JOIN table_base\n" + + 'ON table_base."base_id" = deleted_base."base_id"\n' + + "AND (\n" + + 'table_base."base_time" > deleted_base."base_time"\n' + + 'OR (table_base."base_time" = deleted_base."base_time" AND' + ' table_base."api_invocation_time" >' + ' deleted_base."api_invocation_time")\n' + + 'OR (table_base."base_time" = deleted_base."base_time" AND' + ' table_base."api_invocation_time" =' + ' deleted_base."api_invocation_time" AND' + ' table_base."write_time" > deleted_base."write_time")\n' + + ")\n" + + ") AS table_base\n" + + "),\n" + + "fg_0 AS (WITH table_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."fg_id", origin_0."fg_time"\n' + + 'ORDER BY origin_0."api_invocation_time" DESC, origin_0."write_time" DESC\n' + + ") AS dedup_row_0\n" + + f'FROM "sagemaker_featurestore"."{feature_group_table_name}" origin_0\n' + + ")\n" + + "WHERE dedup_row_0 = 1\n" + + "),\n" + + "deleted_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."fg_id"\n' + + 'ORDER BY origin_0."fg_time" DESC, origin_0."api_invocation_time" DESC,' + ' origin_0."write_time" DESC\n' + + ") AS deleted_row_0\n" + + f'FROM "sagemaker_featurestore"."{feature_group_table_name}" origin_0\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_0 = 1\n" + + ")\n" + + 'SELECT table_0."fg_id", table_0."fg_time", table_0."fg_feature_1",' + ' table_0."fg_feature_2"\n' + + "FROM (\n" + + 'SELECT table_0."fg_id", table_0."fg_time",' + ' table_0."fg_feature_1", table_0."fg_feature_2",' + ' table_0."write_time"\n' + + "FROM table_0\n" + + "LEFT JOIN deleted_0\n" + + 'ON table_0."fg_id" = deleted_0."fg_id"\n' + + 'WHERE deleted_0."fg_id" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_0."fg_id", table_0."fg_time",' + ' table_0."fg_feature_1", table_0."fg_feature_2",' + ' table_0."write_time"\n' + + "FROM deleted_0\n" + + "JOIN table_0\n" + + 'ON table_0."fg_id" = deleted_0."fg_id"\n' + + "AND (\n" + + 'table_0."fg_time" > deleted_0."fg_time"\n' + + 'OR (table_0."fg_time" = deleted_0."fg_time" AND' + ' table_0."api_invocation_time" >' + ' deleted_0."api_invocation_time")\n' + + 'OR (table_0."fg_time" = deleted_0."fg_time" AND' + ' table_0."api_invocation_time" =' + ' deleted_0."api_invocation_time" AND table_0."write_time" >' + ' deleted_0."write_time")\n' + + ")\n" + + ") AS table_0\n" + + ")\n" + + "SELECT base_id, base_time, base_feature_1, base_feature_2," + ' "fg_id.1", "fg_time.1", "fg_feature_1.1",' + ' "fg_feature_2.1"\n' + "FROM (\n" + "SELECT fg_base.base_id, fg_base.base_time," + " fg_base.base_feature_1, fg_base.base_feature_2," + ' fg_0."fg_id" as "fg_id.1", fg_0."fg_time" as "fg_time.1",' + ' fg_0."fg_feature_1" as "fg_feature_1.1",' + ' fg_0."fg_feature_2" as "fg_feature_2.1", row_number()' + " OVER (\n" + + 'PARTITION BY fg_base."base_id"\n' + + 'ORDER BY fg_base."base_time" DESC, fg_0."fg_time" DESC\n' + + ") AS row_recent\n" + + "FROM fg_base\n" + + "JOIN fg_0\n" + + 'ON fg_base."base_id" = fg_0."fg_id"\n' + + ")\n" + + "WHERE row_recent <= 4" + ) + + +def _create_feature_group_and_ingest_data( + feature_group: FeatureGroup, + dataframe: DataFrame, + offline_store_s3_uri: str, + record_identifier_name: str, + event_time_name: str, + role: str, +): + 
feature_group.load_feature_definitions(data_frame=dataframe) + feature_group.create( + s3_uri=offline_store_s3_uri, + record_identifier_name=record_identifier_name, + event_time_feature_name=event_time_name, + role_arn=role, + enable_online_store=True, + ) + _wait_for_feature_group_create(feature_group) + + ingestion_manager = feature_group.ingest(data_frame=dataframe, max_workers=3, wait=False) + ingestion_manager.wait() + assert 0 == len(ingestion_manager.failed_rows) + + +def _get_athena_table_name_after_data_replication( + feature_store_session, feature_group: FeatureGroup, offline_store_s3_uri +): + feature_group_metadata = feature_group.describe() + resolved_output_s3_uri = ( + feature_group_metadata.get("OfflineStoreConfig", None) + .get("S3StorageConfig", None) + .get("ResolvedOutputS3Uri", None) + ) + s3_prefix = resolved_output_s3_uri.replace(f"{offline_store_s3_uri}/", "") + region_name = feature_store_session.boto_session.region_name + s3_client = feature_store_session.boto_session.client( + service_name="s3", region_name=region_name + ) + while True: + objects_in_bucket = s3_client.list_objects( + Bucket=offline_store_s3_uri.replace("s3://", ""), Prefix=s3_prefix + ) + if "Contents" in objects_in_bucket and len(objects_in_bucket["Contents"]) > 1: + break + else: + print(f"Waiting for {feature_group.name} data in offline store...") + time.sleep(60) + print(f"{feature_group.name} data available.") + return ( + feature_group_metadata.get("OfflineStoreConfig", None) + .get("DataCatalogConfig", None) + .get("TableName", None) + ) + + def _wait_for_feature_group_create(feature_group: FeatureGroup): status = feature_group.describe().get("FeatureGroupStatus") while status == "Creating": @@ -451,5 +825,31 @@ def cleanup_feature_group(feature_group: FeatureGroup): finally: try: feature_group.delete() + print(f"{feature_group.name} is deleted") except Exception: raise RuntimeError(f"Failed to delete feature group with name {feature_group.name}") + + +@contextmanager +def cleanup_offline_store(table_name: str, feature_store_session: Session): + try: + yield + finally: + try: + region_name = feature_store_session.boto_session.region_name + s3_client = feature_store_session.boto_session.client( + service_name="s3", region_name=region_name + ) + account_id = feature_store_session.account_id() + bucket_name = f"sagemaker-test-featurestore-{region_name}-{account_id}" + response = s3_client.list_objects_v2( + Bucket=bucket_name, + Prefix=f"{account_id}/sagemaker/{region_name}/offline-store/{table_name}/", + ) + files_in_folder = response["Contents"] + files_to_delete = [] + for f in files_in_folder: + files_to_delete.append({"Key": f["Key"]}) + s3_client.delete_objects(Bucket=bucket_name, Delete={"Objects": files_to_delete}) + except Exception: + raise RuntimeError(f"Failed to delete data under {table_name}") diff --git a/tests/integ/test_inference_pipeline.py b/tests/integ/test_inference_pipeline.py index 53d966fe9b..a26d8c9101 100644 --- a/tests/integ/test_inference_pipeline.py +++ b/tests/integ/test_inference_pipeline.py @@ -50,6 +50,7 @@ ) +@pytest.mark.skip(reason="Test has likely been failing for a while. 
Suspected bad XGB model.") def test_inference_pipeline_batch_transform(sagemaker_session, cpu_instance_type): sparkml_model_data = sagemaker_session.upload_data( path=os.path.join(SPARKML_DATA_PATH, "mleap_model.tar.gz"), diff --git a/tests/integ/test_marketplace.py b/tests/integ/test_marketplace.py index b9ff13c50e..28b537c1ea 100644 --- a/tests/integ/test_marketplace.py +++ b/tests/integ/test_marketplace.py @@ -23,6 +23,7 @@ import sagemaker import tests.integ +from tests.integ.utils import create_repository from sagemaker import AlgorithmEstimator, ModelPackage, Model from sagemaker.serializers import CSVSerializer from sagemaker.tuner import IntegerParameter, HyperparameterTuner @@ -33,7 +34,6 @@ from tests.integ.test_multidatamodel import ( _ecr_image_uri, _ecr_login, - _create_repository, _delete_repository, ) from tests.integ.retry import retries @@ -214,7 +214,7 @@ def iris_image(sagemaker_session): rm=True, ) image.tag(ecr_image, tag="latest") - _create_repository(ecr_client, algorithm_name) + create_repository(ecr_client, algorithm_name) # Retry docker image push for _ in retries(3, "Upload docker image to ECR repo", seconds_to_sleep=10): diff --git a/tests/integ/test_multidatamodel.py b/tests/integ/test_multidatamodel.py index 78ba62c3db..d6c14037a7 100644 --- a/tests/integ/test_multidatamodel.py +++ b/tests/integ/test_multidatamodel.py @@ -19,8 +19,8 @@ import docker import numpy import pytest -from botocore.exceptions import ClientError +from tests.integ.utils import create_repository from sagemaker import utils from sagemaker.amazon.randomcutforest import RandomCutForest from sagemaker.deserializers import StringDeserializer @@ -59,7 +59,7 @@ def container_image(sagemaker_session): image.tag(ecr_image, tag="latest") # Create AWS ECR and push the local docker image to it - _create_repository(ecr_client, algorithm_name) + create_repository(ecr_client, algorithm_name) # Retry docker image push for _ in retries(3, "Upload docker image to ECR repo", seconds_to_sleep=10): @@ -90,23 +90,6 @@ def _ecr_image_uri(sagemaker_session, algorithm_name): return "{}.dkr.{}/{}:latest".format(account_id, endpoint_data["hostname"], algorithm_name) -def _create_repository(ecr_client, repository_name): - """ - Creates an ECS Repository (ECR). When a new transform is being registered, - we'll need a repository to push the image (and composed model images) to - """ - try: - response = ecr_client.create_repository(repositoryName=repository_name) - return response["repository"]["repositoryUri"] - except ClientError as e: - # Handle when the repository already exists - if "RepositoryAlreadyExistsException" == e.response.get("Error", {}).get("Code"): - response = ecr_client.describe_repositories(repositoryNames=[repository_name]) - return response["repositories"][0]["repositoryUri"] - else: - raise - - def _delete_repository(ecr_client, repository_name): """ Deletes an ECS Repository (ECR). 
After the integration test completes diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index bddd53e20c..7d3fdb2d7b 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -13,7 +13,6 @@ from __future__ import absolute_import import os -import re import time import uuid @@ -22,7 +21,6 @@ from sagemaker.debugger import ( DebuggerHookConfig, FrameworkProfile, - get_rule_container_image_uri, ProfilerConfig, ProfilerRule, Rule, @@ -93,8 +91,6 @@ def test_mxnet_with_default_profiler_config_and_profiler_rule( ) job_description = mx.latest_training_job.describe() - if "DisableProfiler" in job_description["ProfilerConfig"]: - job_description["ProfilerConfig"].pop("DisableProfiler") assert ( job_description["ProfilerConfig"] == ProfilerConfig( @@ -103,13 +99,6 @@ def test_mxnet_with_default_profiler_config_and_profiler_rule( ) assert job_description.get("ProfilingStatus") == "Enabled" - profiler_rule_configuration = job_description.get("ProfilerRuleConfigurations")[0] - assert re.match(r"ProfilerReport-\d*", profiler_rule_configuration["RuleConfigurationName"]) - assert profiler_rule_configuration["RuleEvaluatorImage"] == get_rule_container_image_uri( - mx.sagemaker_session.boto_region_name - ) - assert profiler_rule_configuration["RuleParameters"] == {"rule_to_invoke": "ProfilerReport"} - with pytest.raises(ValueError) as error: mx.enable_default_profiling() assert "Debugger monitoring is already enabled." in str(error) @@ -155,18 +144,9 @@ def test_mxnet_with_custom_profiler_config_then_update_rule_and_config( ) job_description = mx.latest_training_job.describe() - if "DisableProfiler" in job_description["ProfilerConfig"]: - job_description["ProfilerConfig"].pop("DisableProfiler") assert job_description.get("ProfilerConfig") == profiler_config._to_request_dict() assert job_description.get("ProfilingStatus") == "Enabled" - profiler_rule_configuration = job_description.get("ProfilerRuleConfigurations")[0] - assert re.match(r"ProfilerReport-\d*", profiler_rule_configuration["RuleConfigurationName"]) - assert profiler_rule_configuration["RuleEvaluatorImage"] == get_rule_container_image_uri( - mx.sagemaker_session.boto_region_name - ) - assert profiler_rule_configuration["RuleParameters"] == {"rule_to_invoke": "ProfilerReport"} - _wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name) mx.update_profiler( @@ -178,13 +158,6 @@ def test_mxnet_with_custom_profiler_config_then_update_rule_and_config( assert job_description["ProfilerConfig"]["S3OutputPath"] == profiler_config.s3_output_path assert job_description["ProfilerConfig"]["ProfilingIntervalInMilliseconds"] == 500 - profiler_report_rule_config = job_description.get("ProfilerRuleConfigurations")[0] - assert re.match(r"ProfilerReport-\d*", profiler_report_rule_config["RuleConfigurationName"]) - assert profiler_report_rule_config["RuleEvaluatorImage"] == get_rule_container_image_uri( - mx.sagemaker_session.boto_region_name - ) - assert profiler_report_rule_config["RuleParameters"] == {"rule_to_invoke": "ProfilerReport"} - def test_mxnet_with_built_in_profiler_rule_with_custom_parameters( sagemaker_session, @@ -225,8 +198,6 @@ def test_mxnet_with_built_in_profiler_rule_with_custom_parameters( ) job_description = mx.latest_training_job.describe() - if "DisableProfiler" in job_description["ProfilerConfig"]: - job_description["ProfilerConfig"].pop("DisableProfiler") assert job_description.get("ProfilingStatus") == "Enabled" assert ( job_description.get("ProfilerConfig") @@ 
-298,8 +269,6 @@ def test_mxnet_with_profiler_and_debugger_then_disable_framework_metrics( ) job_description = mx.latest_training_job.describe() - if "DisableProfiler" in job_description["ProfilerConfig"]: - job_description["ProfilerConfig"].pop("DisableProfiler") assert job_description["ProfilerConfig"] == profiler_config._to_request_dict() assert job_description["DebugHookConfig"] == debugger_hook_config._to_request_dict() assert job_description.get("ProfilingStatus") == "Enabled" @@ -387,13 +356,6 @@ def test_mxnet_with_enable_framework_metrics_then_update_framework_metrics( == updated_framework_profile.profiling_parameters ) - profiler_rule_configuration = job_description.get("ProfilerRuleConfigurations")[0] - assert re.match(r"ProfilerReport-\d*", profiler_rule_configuration["RuleConfigurationName"]) - assert profiler_rule_configuration["RuleEvaluatorImage"] == get_rule_container_image_uri( - mx.sagemaker_session.boto_region_name - ) - assert profiler_rule_configuration["RuleParameters"] == {"rule_to_invoke": "ProfilerReport"} - def test_mxnet_with_disable_profiler_then_enable_default_profiling( sagemaker_session, @@ -431,12 +393,10 @@ def test_mxnet_with_disable_profiler_then_enable_default_profiling( ) job_description = mx.latest_training_job.describe() - assert job_description.get("ProfilerConfig") is None assert job_description.get("ProfilerRuleConfigurations") is None assert job_description.get("ProfilingStatus") == "Disabled" _wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name) - mx.enable_default_profiling() job_description = mx.latest_training_job.describe() diff --git a/tests/integ/test_training_compiler.py b/tests/integ/test_training_compiler.py index 67de050ed1..724cd8890c 100644 --- a/tests/integ/test_training_compiler.py +++ b/tests/integ/test_training_compiler.py @@ -20,6 +20,8 @@ from sagemaker.huggingface import TrainingCompilerConfig as HFTrainingCompilerConfig from sagemaker.tensorflow import TensorFlow from sagemaker.tensorflow import TrainingCompilerConfig as TFTrainingCompilerConfig +from sagemaker.pytorch import PyTorch +from sagemaker.pytorch import TrainingCompilerConfig as PTTrainingCompilerConfig from tests import integ from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES @@ -48,8 +50,7 @@ def imagenet_val_set(request, sagemaker_session, tmpdir_factory): key_prefix="Imagenet/TFRecords/validation", ) train_input = sagemaker_session.upload_data( - path=local_path, - key_prefix="integ-test-data/trcomp/tensorflow/imagenet/val", + path=local_path, key_prefix="integ-test-data/trcomp/tensorflow/imagenet/val" ) return train_input @@ -84,8 +85,8 @@ def skip_if_incompatible(gpu_instance_type, request): @pytest.mark.parametrize( "gpu_instance_type,instance_count", [ - ("ml.p3.2xlarge", 1), - ("ml.p3.16xlarge", 2), + pytest.param("ml.p3.2xlarge", 1, marks=pytest.mark.release), + pytest.param("ml.p3.16xlarge", 2), ], ) def test_huggingface_pytorch( @@ -129,27 +130,32 @@ def test_huggingface_pytorch( hf.fit(huggingface_dummy_dataset) -@pytest.mark.release -def test_huggingface_pytorch_release( +@pytest.mark.parametrize( + "gpu_instance_type,instance_count", + [ + pytest.param("ml.p3.2xlarge", 1, marks=pytest.mark.release), + pytest.param("ml.p3.16xlarge", 2), + ], +) +def test_pytorch( sagemaker_session, gpu_instance_type, - huggingface_training_compiler_latest_version, - huggingface_training_compiler_pytorch_latest_version, + instance_count, + pytorch_training_compiler_latest_version, huggingface_dummy_dataset, ): """ - 
Test the HuggingFace estimator with PyTorch + Test the PyTorch estimator """ with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - data_path = os.path.join(DATA_DIR, "huggingface") - hf = HuggingFace( + hf = PyTorch( py_version="py38", - entry_point=os.path.join(data_path, "run_glue.py"), + source_dir=os.path.join(DATA_DIR, "huggingface_byoc"), + entry_point="run_glue.py", role="SageMakerRole", - transformers_version=huggingface_training_compiler_latest_version, - pytorch_version=huggingface_training_compiler_pytorch_latest_version, - instance_count=1, + framework_version=pytorch_training_compiler_latest_version, + instance_count=instance_count, instance_type=gpu_instance_type, hyperparameters={ "model_name_or_path": "distilbert-base-cased", @@ -163,7 +169,8 @@ def test_huggingface_pytorch_release( }, sagemaker_session=sagemaker_session, disable_profiler=True, - compiler_config=HFTrainingCompilerConfig(), + compiler_config=PTTrainingCompilerConfig(), + distribution={"pytorchxla": {"enabled": True}} if instance_count > 1 else None, ) hf.fit(huggingface_dummy_dataset) @@ -209,10 +216,7 @@ def test_huggingface_tensorflow( @pytest.mark.release def test_tensorflow( - sagemaker_session, - gpu_instance_type, - tensorflow_training_latest_version, - imagenet_val_set, + sagemaker_session, gpu_instance_type, tensorflow_training_latest_version, imagenet_val_set ): """ Test the TensorFlow estimator @@ -264,8 +268,4 @@ def test_tensorflow( compiler_config=TFTrainingCompilerConfig(), ) - tf.fit( - inputs=imagenet_val_set, - logs=True, - wait=True, - ) + tf.fit(inputs=imagenet_val_set, logs=True, wait=True) diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index a0e37ffc77..1de333b987 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -25,6 +25,7 @@ from sagemaker.transformer import Transformer from sagemaker.estimator import Estimator from sagemaker.inputs import BatchDataCaptureConfig +from sagemaker.xgboost import XGBoostModel from sagemaker.utils import unique_name_from_base from tests.integ import ( datasets, @@ -36,7 +37,7 @@ from tests.integ.timeout import timeout, timeout_and_delete_model_with_transformer from tests.integ.vpc_test_utils import get_or_create_vpc_resources -from sagemaker.model_monitor import DatasetFormat, Statistics +from sagemaker.model_monitor import DatasetFormat, Statistics, Constraints from sagemaker.workflow.check_job_config import CheckJobConfig from sagemaker.workflow.quality_check_step import ( @@ -645,3 +646,66 @@ def _create_transformer_and_transform_job( job_name=unique_name_from_base("test-transform"), ) return transformer + + +def test_transformer_and_monitoring_job( + pipeline_session, + sagemaker_session, + role, + pipeline_name, + check_job_config, + data_bias_check_config, +): + xgb_model_data_s3 = pipeline_session.upload_data( + path=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + data_bias_supplied_baseline_constraints = Constraints.from_file_path( + constraints_file_path=os.path.join( + DATA_DIR, "pipeline/clarify_check_step/data_bias/good_cases/analysis.json" + ), + sagemaker_session=sagemaker_session, + ).file_s3_uri + + xgb_model = XGBoostModel( + model_data=xgb_model_data_s3, + framework_version="1.3-1", + role=role, + sagemaker_session=sagemaker_session, + entry_point=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "inference.py"), + enable_network_isolation=True, + ) + + 
xgb_model.deploy(_INSTANCE_COUNT, _INSTANCE_TYPE) + + transform_output = f"s3://{sagemaker_session.default_bucket()}/{pipeline_name}Transform" + transformer = Transformer( + model_name=xgb_model.name, + strategy="SingleRecord", + instance_type="ml.m5.xlarge", + instance_count=1, + output_path=transform_output, + sagemaker_session=pipeline_session, + ) + + transform_input = pipeline_session.upload_data( + path=os.path.join(DATA_DIR, "xgboost_abalone", "abalone"), + key_prefix="integ-test-data/xgboost_abalone/abalone", + ) + + execution = transformer.transform_with_monitoring( + monitoring_config=data_bias_check_config, + monitoring_resource_config=check_job_config, + data=transform_input, + content_type="text/libsvm", + supplied_baseline_constraints=data_bias_supplied_baseline_constraints, + role=role, + ) + + execution_steps = execution.list_steps() + assert len(execution_steps) == 2 + + for execution_step in execution_steps: + assert execution_step["StepStatus"] == "Succeeded" + + xgb_model.delete_model() diff --git a/tests/integ/test_xgboost.py b/tests/integ/test_xgboost.py index 733ab4665a..df06a8863a 100644 --- a/tests/integ/test_xgboost.py +++ b/tests/integ/test_xgboost.py @@ -40,6 +40,26 @@ def xgboost_training_job( ) +def test_sourcedir_naming( + sagemaker_session, + xgboost_latest_version, + xgboost_latest_py_version, + cpu_instance_type, +): + with pytest.raises(RuntimeError): + processor = XGBoostProcessor( + framework_version=xgboost_latest_version, + role=ROLE, + instance_count=1, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + ) + processor.run( + source_dir="s3://bucket/deps.tar.gz", + code="main_script.py", + ) + + @pytest.mark.release def test_framework_processing_job_with_deps( sagemaker_session, diff --git a/tests/integ/utils.py b/tests/integ/utils.py index 53440f96f5..d7891321f2 100644 --- a/tests/integ/utils.py +++ b/tests/integ/utils.py @@ -14,6 +14,8 @@ import logging from functools import wraps +from botocore.exceptions import ClientError + from tests.conftest import NO_P3_REGIONS, NO_M4_REGIONS from sagemaker.exceptions import CapacityError @@ -69,3 +71,21 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +def create_repository(ecr_client, repository_name): + """Creates an ECS Repository (ECR). + + When a new transform is being registered, + we'll need a repository to push the image (and composed model images) to + """ + try: + response = ecr_client.create_repository(repositoryName=repository_name) + return response["repository"]["repositoryUri"] + except ClientError as e: + # Handle when the repository already exists + if "RepositoryAlreadyExistsException" == e.response.get("Error", {}).get("Code"): + response = ecr_client.describe_repositories(repositoryNames=[repository_name]) + return response["repositories"][0]["repositoryUri"] + else: + raise diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 0000000000..21fe49cc97 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,66 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import sagemaker + +from mock import Mock, PropertyMock + +_ROLE = "DummyRole" +_REGION = "us-west-2" +_DEFAULT_BUCKET = "my-bucket" + + +@pytest.fixture(scope="session") +def client(): + """Mock client. + + Considerations when appropriate: + + * utilize botocore.stub.Stubber + * separate runtime client from client + """ + client_mock = Mock() + client_mock._client_config.user_agent = ( + "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" + ) + return client_mock + + +@pytest.fixture(scope="session") +def boto_session(client): + role_mock = Mock() + type(role_mock).arn = PropertyMock(return_value=_ROLE) + + resource_mock = Mock() + resource_mock.Role.return_value = role_mock + + session_mock = Mock(region_name=_REGION) + session_mock.resource.return_value = resource_mock + session_mock.client.return_value = client + + return session_mock + + +@pytest.fixture(scope="session") +def sagemaker_session(boto_session, client): + # ideally this would mock Session instead of instantiating it + # most unit tests do mock the session correctly + return sagemaker.session.Session( + boto_session=boto_session, + sagemaker_client=client, + sagemaker_runtime_client=client, + default_bucket=_DEFAULT_BUCKET, + sagemaker_metrics_client=client, + ) diff --git a/tests/unit/sagemaker/experiments/__init__.py b/tests/unit/sagemaker/experiments/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/experiments/conftest.py b/tests/unit/sagemaker/experiments/conftest.py new file mode 100644 index 0000000000..4d33ad759d --- /dev/null +++ b/tests/unit/sagemaker/experiments/conftest.py @@ -0,0 +1,86 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import unittest +from unittest.mock import patch, MagicMock, Mock + +import pytest + +from sagemaker import Session +from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments.run import RUN_NAME_BASE +from sagemaker.experiments import Run +from tests.unit.sagemaker.experiments.helpers import ( + mock_tc_load_or_create_func, + mock_trial_load_or_create_func, + TEST_EXP_NAME, +) + + +@pytest.fixture +def client(): + """Mock client. 
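+
+    A fresh Mock per test; the experiments tests stub its SageMaker API
+    methods (return_value/side_effect) as needed.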
+ + Considerations when appropriate: + + * utilize botocore.stub.Stubber + * separate runtime client from client + """ + client_mock = unittest.mock.Mock() + client_mock._client_config.user_agent = ( + "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" + ) + return client_mock + + +@pytest.fixture +def sagemaker_session(client): + return Session( + sagemaker_client=client, + ) + + +@pytest.fixture +def run_obj(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.update_trial_component.return_value = {} + client.associate_trial_component.return_value = {} + with patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock( + return_value=_Experiment( + experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session + ) + ), + ): + with patch( + "sagemaker.experiments.run._TrialComponent._load_or_create", + MagicMock(side_effect=mock_tc_load_or_create_func), + ): + with patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), + ): + run = Run( + experiment_name=TEST_EXP_NAME, + sagemaker_session=sagemaker_session, + ) + run._artifact_uploader = Mock() + run._lineage_artifact_tracker = Mock() + run._metrics_manager = Mock() + + assert run.run_name.startswith(RUN_NAME_BASE) + assert run.run_group_name == Run._generate_trial_name(TEST_EXP_NAME) + + return run diff --git a/tests/unit/sagemaker/experiments/helpers.py b/tests/unit/sagemaker/experiments/helpers.py new file mode 100644 index 0000000000..b7914010e5 --- /dev/null +++ b/tests/unit/sagemaker/experiments/helpers.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from sagemaker.experiments.trial import _Trial +from sagemaker.experiments.trial_component import _TrialComponent + + +TEST_EXP_NAME = "my-experiment" +TEST_RUN_NAME = "my-run" + + +def mock_tc_load_or_create_func( + trial_component_name, display_name=None, tags=None, sagemaker_session=None +): + tc = _TrialComponent( + trial_component_name=trial_component_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + return tc, True + + +def mock_trial_load_or_create_func( + experiment_name, trial_name, display_name=None, tags=None, sagemaker_session=None +): + return _Trial( + trial_name=trial_name, + experiment_name=experiment_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) diff --git a/tests/unit/sagemaker/experiments/test_environment.py b/tests/unit/sagemaker/experiments/test_environment.py new file mode 100644 index 0000000000..8bb23db7b6 --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_environment.py @@ -0,0 +1,107 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json +import os +import shutil +import tempfile +import unittest.mock + +import pytest + +from sagemaker.experiments import _environment +from sagemaker.utils import retry_with_backoff + + +@pytest.fixture +def tempdir(): + dir = tempfile.mkdtemp() + yield dir + shutil.rmtree(dir) + + +@pytest.fixture +def training_job_env(): + old_value = os.environ.get("TRAINING_JOB_ARN") + os.environ["TRAINING_JOB_ARN"] = "arn:1234aBcDe" + yield os.environ + del os.environ["TRAINING_JOB_ARN"] + if old_value: + os.environ["TRAINING_JOB_ARN"] = old_value + + +@pytest.fixture +def transform_job_env(): + old_value = os.environ.get("SAGEMAKER_BATCH") + os.environ["SAGEMAKER_BATCH"] = "true" + yield os.environ + del os.environ["SAGEMAKER_BATCH"] + if old_value: + os.environ["SAGEMAKER_BATCH"] = old_value + + +def test_processing_job_environment(tempdir): + config_path = os.path.join(tempdir, "config.json") + with open(config_path, "w") as f: + f.write(json.dumps({"ProcessingJobArn": "arn:1234aBcDe"})) + environment = _environment._RunEnvironment.load(processing_job_config_path=config_path) + + assert _environment._EnvironmentType.SageMakerProcessingJob == environment.environment_type + assert "arn:1234aBcDe" == environment.source_arn + + +def test_training_job_environment(training_job_env): + environment = _environment._RunEnvironment.load() + assert _environment._EnvironmentType.SageMakerTrainingJob == environment.environment_type + assert "arn:1234aBcDe" == environment.source_arn + + +def test_transform_job_environment(transform_job_env): + environment = _environment._RunEnvironment.load() + assert _environment._EnvironmentType.SageMakerTransformJob == environment.environment_type + # TODO: update if we figure out how to get source_arn from the transform job + assert not environment.source_arn + + +def test_no_environment(): + assert _environment._RunEnvironment.load() is None + + +def test_resolve_trial_component(training_job_env, sagemaker_session): + trial_component_name = "foo-bar" + client = sagemaker_session.sagemaker_client + client.list_trial_components.return_value = { + "TrialComponentSummaries": [{"TrialComponentName": trial_component_name}] + } + client.describe_trial_component.return_value = {"TrialComponentName": trial_component_name} + environment = _environment._RunEnvironment.load() + tc = environment.get_trial_component(sagemaker_session) + + assert trial_component_name == tc.trial_component_name + client.describe_trial_component.assert_called_with(TrialComponentName=trial_component_name) + client.list_trial_components.assert_called_with(SourceArn="arn:1234abcde") + + +@unittest.mock.patch("sagemaker.experiments._environment.retry_with_backoff") +def test_resolve_trial_component_fails(mock_retry, sagemaker_session, training_job_env): + mock_retry.side_effect = lambda func: retry_with_backoff(func, 2) + client = sagemaker_session.sagemaker_client + client.list_trial_components.side_effect = Exception("Failed test") + environment = _environment._RunEnvironment.load() + assert environment.get_trial_component(sagemaker_session) is None + + +def 
test_resolve_transform_job_trial_component_fail(transform_job_env, sagemaker_session): + environment = _environment._RunEnvironment.load() + assert environment.get_trial_component(sagemaker_session) is None diff --git a/tests/unit/sagemaker/experiments/test_experiment.py b/tests/unit/sagemaker/experiments/test_experiment.py new file mode 100644 index 0000000000..b0ad55c27f --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_experiment.py @@ -0,0 +1,306 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import unittest.mock +import datetime + +from unittest.mock import patch + +from sagemaker import Session +from sagemaker.experiments import experiment +from sagemaker.experiments._api_types import TrialSummary + + +@pytest.fixture +def datetime_obj(): + return datetime.datetime(2017, 6, 16, 15, 55, 0) + + +def test_load(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.describe_experiment.return_value = {"Description": "description-value"} + experiment_obj = experiment._Experiment.load( + experiment_name="name-value", sagemaker_session=sagemaker_session + ) + assert experiment_obj.experiment_name == "name-value" + assert experiment_obj.description == "description-value" + + client.describe_experiment.assert_called_with(ExperimentName="name-value") + + +def test_create(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_experiment.return_value = {"Arn": "arn:aws:1234"} + experiment_obj = experiment._Experiment.create( + experiment_name="name-value", sagemaker_session=sagemaker_session + ) + assert experiment_obj.experiment_name == "name-value" + client.create_experiment.assert_called_with(ExperimentName="name-value") + + +def test_create_with_tags(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_experiment.return_value = {"Arn": "arn:aws:1234"} + tags = [{"Key": "foo", "Value": "bar"}] + experiment_obj = experiment._Experiment.create( + experiment_name="name-value", sagemaker_session=sagemaker_session, tags=tags + ) + assert experiment_obj.experiment_name == "name-value" + client.create_experiment.assert_called_with(ExperimentName="name-value", Tags=tags) + + +def test_save(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar") + client.update_experiment.return_value = {} + obj.save() + client.update_experiment.assert_called_with(ExperimentName="foo", Description="bar") + + +def test_delete(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar") + client.delete_experiment.return_value = {} + obj.delete() + client.delete_experiment.assert_called_with(ExperimentName="foo") + + +@patch("sagemaker.experiments.experiment._Experiment.load") +def test_load_or_create_when_exist(mock_load, sagemaker_session): + exp_name = "exp_name" + 
experiment._Experiment._load_or_create( + experiment_name=exp_name, sagemaker_session=sagemaker_session + ) + mock_load.assert_called_once_with(exp_name, sagemaker_session) + + +@patch("sagemaker.experiments.experiment._Experiment.load") +@patch("sagemaker.experiments.experiment._Experiment.create") +def test_load_or_create_when_not_exist(mock_create, mock_load): + sagemaker_session = Session() + client = sagemaker_session.sagemaker_client + exp_name = "exp_name" + not_found_err = client.exceptions.ResourceNotFound( + error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}}, + operation_name="foo", + ) + mock_load.side_effect = not_found_err + + experiment._Experiment._load_or_create( + experiment_name=exp_name, sagemaker_session=sagemaker_session + ) + + mock_create.assert_called_once_with( + experiment_name=exp_name, + display_name=None, + description=None, + tags=None, + sagemaker_session=sagemaker_session, + ) + + +def test_list_trials_empty(sagemaker_session): + sagemaker_session.sagemaker_client.list_trials.return_value = {"TrialSummaries": []} + experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session) + assert list(experiment_obj.list_trials()) == [] + + +def test_list_trials_single(sagemaker_session, datetime_obj): + experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session) + sagemaker_session.sagemaker_client.list_trials.return_value = { + "TrialSummaries": [ + {"Name": "trial-foo", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj} + ] + } + + assert list(experiment_obj.list_trials()) == [ + TrialSummary(name="trial-foo", creation_time=datetime_obj, last_modified_time=datetime_obj) + ] + + +def test_list_trials_two_values(sagemaker_session, datetime_obj): + experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session) + sagemaker_session.sagemaker_client.list_trials.return_value = { + "TrialSummaries": [ + {"Name": "trial-foo-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}, + {"Name": "trial-foo-2", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}, + ] + } + + assert list(experiment_obj.list_trials()) == [ + TrialSummary( + name="trial-foo-1", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + name="trial-foo-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + + +def test_next_token(sagemaker_session, datetime_obj): + experiment_obj = experiment._Experiment(sagemaker_session) + client = sagemaker_session.sagemaker_client + client.list_trials.side_effect = [ + { + "TrialSummaries": [ + { + "Name": "trial-foo-1", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + { + "Name": "trial-foo-2", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + ], + "NextToken": "foo", + }, + { + "TrialSummaries": [ + { + "Name": "trial-foo-3", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + } + ] + }, + ] + + assert list(experiment_obj.list_trials()) == [ + TrialSummary( + name="trial-foo-1", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + name="trial-foo-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + name="trial-foo-3", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + + client.list_trials.assert_any_call(**{}) + client.list_trials.assert_any_call(NextToken="foo") + + +def test_list_trials_call_args(sagemaker_session): + client = 
sagemaker_session.sagemaker_client + created_before = datetime.datetime(1999, 10, 12, 0, 0, 0) + created_after = datetime.datetime(1990, 10, 12, 0, 0, 0) + experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session) + client.list_trials.return_value = {} + assert [] == list( + experiment_obj.list_trials(created_after=created_after, created_before=created_before) + ) + client.list_trials.assert_called_with(CreatedBefore=created_before, CreatedAfter=created_after) + + +def test_delete_all_with_incorrect_action_name(sagemaker_session): + obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar") + with pytest.raises(ValueError) as err: + obj._delete_all(action="abc") + + assert "Must confirm with string '--force'" in str(err) + + +def test_delete_all(sagemaker_session): + obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar") + client = sagemaker_session.sagemaker_client + client.list_trials.return_value = { + "TrialSummaries": [ + { + "TrialName": "trial-1", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + { + "TrialName": "trial-2", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + ] + } + client.describe_trial.side_effect = [ + {"Trialname": "trial-1", "ExperimentName": "experiment-name-value"}, + {"Trialname": "trial-2", "ExperimentName": "experiment-name-value"}, + ] + client.list_trial_components.side_effect = [ + { + "TrialComponentSummaries": [ + { + "TrialComponentName": "trial-component-1", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + { + "TrialComponentName": "trial-component-2", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + ] + }, + { + "TrialComponentSummaries": [ + { + "TrialComponentName": "trial-component-3", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + { + "TrialComponentName": "trial-component-4", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + ] + }, + ] + + client.describe_trial_component.side_effect = [ + {"TrialComponentName": "trial-component-1"}, + {"TrialComponentName": "trial-component-2"}, + {"TrialComponentName": "trial-component-3"}, + {"TrialComponentName": "trial-component-4"}, + ] + + client.delete_trial_component.return_value = {} + client.delete_trial.return_value = {} + client.delete_experiment.return_value = {} + + obj._delete_all(action="--force") + + client.delete_experiment.assert_called_with(ExperimentName="foo") + + delete_trial_expected_calls = [ + unittest.mock.call(TrialName="trial-1"), + unittest.mock.call(TrialName="trial-2"), + ] + assert delete_trial_expected_calls == client.delete_trial.mock_calls + + delete_trial_component_expected_calls = [ + unittest.mock.call(TrialComponentName="trial-component-1"), + unittest.mock.call(TrialComponentName="trial-component-2"), + unittest.mock.call(TrialComponentName="trial-component-3"), + unittest.mock.call(TrialComponentName="trial-component-4"), + ] + assert delete_trial_component_expected_calls == client.delete_trial_component.mock_calls + + +def test_delete_all_fail(sagemaker_session): + obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar") + sagemaker_session.sagemaker_client.list_trials.side_effect = Exception + with pytest.raises(Exception) as e: + obj._delete_all(action="--force") + + assert str(e.value) == "Failed to delete, please try again." 
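The pagination behavior pinned down by test_next_token above follows the standard boto contract: list_trials is first called with no token, then re-called with NextToken until the response omits it. A minimal client-side loop implementing that contract might look like the following sketch (hypothetical helper, not part of the SDK):

    def iter_trial_summaries(sagemaker_client, **kwargs):
        # Follow NextToken until the service stops returning one.
        while True:
            response = sagemaker_client.list_trials(**kwargs)
            for summary in response.get("TrialSummaries", []):
                yield summary
            token = response.get("NextToken")
            if not token:
                return
            kwargs["NextToken"] = token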
diff --git a/tests/unit/sagemaker/experiments/test_helper.py b/tests/unit/sagemaker/experiments/test_helper.py new file mode 100644 index 0000000000..a11f67389b --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_helper.py @@ -0,0 +1,195 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json +import os +import shutil +import tempfile + +from mock import Mock, PropertyMock, call +import pytest + +from src.sagemaker.experiments._helper import ( + _LineageArtifactTracker, + _ArtifactUploader, +) +from src.sagemaker.experiments._utils import resolve_artifact_name +from src.sagemaker.session import Session + + +@pytest.fixture +def client(): + """Mock client. + + Considerations when appropriate: + + * utilize botocore.stub.Stubber + * separate runtime client from client + """ + client_mock = Mock() + client_mock._client_config.user_agent = ( + "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" + ) + return client_mock + + +@pytest.fixture +def boto_session(client): + role_mock = Mock() + type(role_mock).arn = PropertyMock(return_value="DummyRole") + + resource_mock = Mock() + resource_mock.Role.return_value = role_mock + + session_mock = Mock(region_name="us-west-2") + session_mock.resource.return_value = resource_mock + session_mock.client.return_value = client + + return session_mock + + +@pytest.fixture +def sagemaker_session(client, boto_session): + return Session( + sagemaker_client=client, + boto_session=boto_session, + ) + + +@pytest.fixture +def lineage_artifact_tracker(sagemaker_session): + return _LineageArtifactTracker("test_trial_component_arn", sagemaker_session) + + +def test_lineage_artifact_tracker(lineage_artifact_tracker, sagemaker_session): + client = sagemaker_session.sagemaker_client + lineage_artifact_tracker.add_input_artifact( + "input_name", "input_source_uri", "input_etag", "text/plain" + ) + lineage_artifact_tracker.add_output_artifact( + "output_name", "output_source_uri", "output_etag", "text/plain" + ) + client.create_artifact.side_effect = [ + {"ArtifactArn": "created_arn_1"}, + {"ArtifactArn": "created_arn_2"}, + ] + + lineage_artifact_tracker.save() + + expected_calls = [ + call( + ArtifactName="input_name", + ArtifactType="text/plain", + Source={ + "SourceUri": "input_source_uri", + "SourceTypes": [{"SourceIdType": "S3ETag", "Value": "input_etag"}], + }, + ), + call( + ArtifactName="output_name", + ArtifactType="text/plain", + Source={ + "SourceUri": "output_source_uri", + "SourceTypes": [{"SourceIdType": "S3ETag", "Value": "output_etag"}], + }, + ), + ] + assert expected_calls == client.create_artifact.mock_calls + + expected_calls = [ + call( + SourceArn="created_arn_1", + DestinationArn="test_trial_component_arn", + AssociationType="ContributedTo", + ), + call( + SourceArn="test_trial_component_arn", + DestinationArn="created_arn_2", + AssociationType="Produced", + ), + ] + assert expected_calls == client.add_association.mock_calls + + +@pytest.fixture +def 
artifact_uploader(sagemaker_session): + return _ArtifactUploader( + trial_component_name="trial_component_name", + artifact_bucket="artifact_bucket", + artifact_prefix="artifact_prefix", + sagemaker_session=sagemaker_session, + ) + + +@pytest.fixture +def tempdir(): + tmp_dir = tempfile.mkdtemp() + yield tmp_dir + shutil.rmtree(tmp_dir) + + +def test_artifact_uploader_init(artifact_uploader): + assert "trial_component_name" == artifact_uploader.trial_component_name + assert "artifact_bucket" == artifact_uploader.artifact_bucket + assert "artifact_prefix" == artifact_uploader.artifact_prefix + + +def test_artifact_uploader_upload_artifact_file_not_exists(tempdir, artifact_uploader): + not_exist_file = os.path.join(tempdir, "not.exists") + with pytest.raises(ValueError) as error: + artifact_uploader.upload_artifact(not_exist_file) + assert "does not exist or is not a file" in str(error) + + +def test_artifact_uploader_upload_artifact(tempdir, artifact_uploader): + path = os.path.join(tempdir, "exists") + with open(path, "a") as f: + f.write("boo") + + name = resolve_artifact_name(path) + artifact_uploader._s3_client.head_object.return_value = {"ETag": "etag_value"} + + s3_uri, etag = artifact_uploader.upload_artifact(path) + expected_key = "{}/{}/{}".format( + artifact_uploader.artifact_prefix, artifact_uploader.trial_component_name, name + ) + + artifact_uploader._s3_client.upload_file.assert_called_with( + path, artifact_uploader.artifact_bucket, expected_key + ) + + expected_uri = "s3://{}/{}".format(artifact_uploader.artifact_bucket, expected_key) + assert expected_uri == s3_uri + + +def test_artifact_uploader_upload_object_artifact(tempdir, artifact_uploader): + artifact_uploader._s3_client.head_object.return_value = {"ETag": "etag_value"} + + artifact_name = "my-artifact" + artifact_object = {"key": "value"} + file_extension = ".csv" + s3_uri, etag = artifact_uploader.upload_object_artifact( + artifact_name, artifact_object, file_extension + ) + name = artifact_name + file_extension + expected_key = "{}/{}/{}".format( + artifact_uploader.artifact_prefix, artifact_uploader.trial_component_name, name + ) + + artifact_uploader._s3_client.put_object.assert_called_with( + Body=json.dumps(artifact_object), Bucket=artifact_uploader.artifact_bucket, Key=expected_key + ) + + expected_uri = "s3://{}/{}".format(artifact_uploader.artifact_bucket, expected_key) + assert expected_uri == s3_uri diff --git a/tests/unit/sagemaker/experiments/test_metrics.py b/tests/unit/sagemaker/experiments/test_metrics.py new file mode 100644 index 0000000000..21556f70fd --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_metrics.py @@ -0,0 +1,178 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
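+#
+# A note on the timestamp tests below: _RawMetricData normalizes every accepted
+# input (tz-aware datetime, naive local datetime, or raw epoch seconds) to a
+# single epoch-seconds Timestamp. For a naive datetime that means: attach the
+# local timezone, subtract the UTC offset, and take .timestamp(), which is
+# exactly what the assertions reconstruct.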
+from __future__ import absolute_import + +import os +import pytest +import tempfile +import shutil +import datetime +import dateutil +import json +import time + +from sagemaker.experiments._metrics import ( + _RawMetricData, + _SageMakerFileMetricsWriter, + SageMakerMetricsWriterException, +) + + +@pytest.fixture +def tempdir(): + dir = tempfile.mkdtemp() + yield dir + shutil.rmtree(dir) + + +@pytest.fixture +def filepath(tempdir): + return os.path.join(tempdir, "foo.json") + + +@pytest.fixture +def timestamp(): + return datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1) + + +def test_raw_metric_data_utc_timestamp(): + utcnow = datetime.datetime.now(datetime.timezone.utc) + assert utcnow.tzinfo + metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=utcnow) + assert utcnow.timestamp() == metric.Timestamp + + +def test_raw_metric_data_utc_(): + utcnow = datetime.datetime.now(datetime.timezone.utc) + assert utcnow.tzinfo + metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=utcnow) + assert utcnow.timestamp() == metric.Timestamp + + +def test_raw_metric_data_aware_timestamp(): + aware_datetime = datetime.datetime.now(dateutil.tz.gettz("America/Chicago")) + assert aware_datetime.tzinfo + metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=aware_datetime) + assert (aware_datetime - aware_datetime.utcoffset()).replace( + tzinfo=datetime.timezone.utc + ).timestamp() == metric.Timestamp + + +def test_raw_metric_data_naive_timestamp(): + naive_datetime = datetime.datetime.now() + assert naive_datetime.tzinfo is None + metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=naive_datetime) + local_datetime = naive_datetime.replace(tzinfo=dateutil.tz.tzlocal()) + assert (local_datetime - local_datetime.utcoffset()).replace( + tzinfo=datetime.timezone.utc + ).timestamp() == metric.Timestamp + + +def test_raw_metric_data_number_timestamp(): + time_now = time.time() + metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=time_now) + assert time_now == metric.Timestamp + + +def test_raw_metric_data_request_item(): + time_now = time.time() + metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=time_now, step=10) + expected = { + "MetricName": "foo", + "Value": 1.0, + "Timestamp": str(int(time_now)), + "Step": 10, + } + assert expected == metric.to_raw_metric_data() + + +def test_raw_metric_data_invalid_timestamp(): + with pytest.raises(ValueError) as error1: + _RawMetricData(metric_name="IFail", value=100, timestamp=time.time() - 2000000) + assert "Timestamps must be between two weeks before and two hours from now" in str(error1) + + with pytest.raises(ValueError) as error2: + _RawMetricData(metric_name="IFail", value=100, timestamp=time.time() + 10000) + assert "Timestamps must be between two weeks before and two hours from now" in str(error2) + + +def test_file_metrics_writer_log_metric(timestamp, filepath): + now = datetime.datetime.now(datetime.timezone.utc) + writer = _SageMakerFileMetricsWriter(filepath) + writer.log_metric(metric_name="foo", value=1.0) + writer.log_metric(metric_name="foo", value=2.0, step=1) + writer.log_metric(metric_name="foo", value=3.0, timestamp=timestamp) + writer.log_metric(metric_name="foo", value=4.0, timestamp=timestamp, step=2) + writer.close() + + lines = [x for x in open(filepath).read().split("\n") if x] + [entry_one, entry_two, entry_three, entry_four] = [json.loads(line) for line in lines] + + assert "foo" == entry_one["MetricName"] + assert 1.0 == entry_one["Value"] + assert 
(now.timestamp() - entry_one["Timestamp"]) < 1 + assert "Step" not in entry_one + + assert 1 == entry_two["Step"] + assert timestamp.timestamp() == entry_three["Timestamp"] + assert 2 == entry_four["Step"] + + +def test_file_metrics_writer_flushes_buffer_every_line_log_metric(filepath): + writer = _SageMakerFileMetricsWriter(filepath) + + writer.log_metric(metric_name="foo", value=1.0) + + lines = [x for x in open(filepath).read().split("\n") if x] + [entry_one] = [json.loads(line) for line in lines] + assert "foo" == entry_one["MetricName"] + assert 1.0 == entry_one["Value"] + + writer.log_metric(metric_name="bar", value=2.0) + lines = [x for x in open(filepath).read().split("\n") if x] + [entry_one, entry_two] = [json.loads(line) for line in lines] + assert "bar" == entry_two["MetricName"] + assert 2.0 == entry_two["Value"] + + writer.log_metric(metric_name="biz", value=3.0) + lines = [x for x in open(filepath).read().split("\n") if x] + [entry_one, entry_two, entry_three] = [json.loads(line) for line in lines] + assert "biz" == entry_three["MetricName"] + assert 3.0 == entry_three["Value"] + + writer.close() + + +def test_file_metrics_writer_context_manager(timestamp, filepath): + with _SageMakerFileMetricsWriter(filepath) as writer: + writer.log_metric("foo", value=1.0, timestamp=timestamp) + entry = json.loads(open(filepath, "r").read().strip()) + assert { + "MetricName": "foo", + "Value": 1.0, + "Timestamp": timestamp.timestamp(), + }.items() <= entry.items() + + +def test_file_metrics_writer_fail_write_on_close(filepath): + writer = _SageMakerFileMetricsWriter(filepath) + writer.log_metric(metric_name="foo", value=1.0) + writer.close() + with pytest.raises(SageMakerMetricsWriterException): + writer.log_metric(metric_name="foo", value=1.0) + + +def test_file_metrics_writer_no_write(filepath): + writer = _SageMakerFileMetricsWriter(filepath) + writer.close() + assert not os.path.exists(filepath) diff --git a/tests/unit/sagemaker/experiments/test_run.py b/tests/unit/sagemaker/experiments/test_run.py new file mode 100644 index 0000000000..0e4ebee181 --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_run.py @@ -0,0 +1,941 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
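+# Note on the test setup below: rather than calling SageMaker, these tests
+# stub out each backend dependency of Run with a recurring decorator stack,
+# roughly:
+#
+#   _Experiment._load_or_create     -> a plain _Experiment(TEST_EXP_NAME)
+#   _Trial._load_or_create          -> mock_trial_load_or_create_func
+#   _TrialComponent._load_or_create -> mock_tc_load_or_create_func
+#   _Trial.add_trial_component and _TrialComponent.save -> no-op mocks
+#
+# so each test only observes how Run wires up names, experiment_config, and
+# its context flags (_in_load, _inside_load_context, _inside_init_context).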
+from __future__ import absolute_import
+
+import datetime
+import unittest
+from math import inf, nan
+from unittest.mock import patch, Mock, MagicMock
+
+import dateutil
+import pytest
+
+from sagemaker.experiments import _environment, SortOrderType
+from sagemaker.experiments._api_types import (
+    TrialComponentArtifact,
+    TrialComponentSummary,
+    TrialComponentStatus,
+    _TrialComponentStatusType,
+    TrialComponentSearchResult,
+)
+from sagemaker.experiments.experiment import _Experiment
+from sagemaker.experiments.run import (
+    TRIAL_NAME_TEMPLATE,
+    MAX_RUN_TC_ARTIFACTS_LEN,
+    MAX_NAME_LEN_IN_BACKEND,
+    EXPERIMENT_NAME,
+    RUN_NAME,
+    TRIAL_NAME,
+    DELIMITER,
+    RUN_TC_TAG,
+    SortByType,
+)
+from sagemaker.experiments import Run, load_run, list_runs
+from sagemaker.experiments.trial import _Trial
+from sagemaker.experiments.trial_component import _TrialComponent
+from tests.unit.sagemaker.experiments.helpers import (
+    mock_trial_load_or_create_func,
+    mock_tc_load_or_create_func,
+    TEST_EXP_NAME,
+    TEST_RUN_NAME,
+)
+
+
+@patch(
+    "sagemaker.experiments.run._Experiment._load_or_create",
+    MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+    "sagemaker.experiments.run._Trial._load_or_create",
+    MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+    "sagemaker.experiments.run._TrialComponent._load_or_create",
+    MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+@patch.object(_TrialComponent, "save")
+def test_run_init(mock_tc_save, sagemaker_session):
+    with Run(
+        experiment_name=TEST_EXP_NAME, run_name=TEST_RUN_NAME, sagemaker_session=sagemaker_session
+    ) as run_obj:
+        assert not run_obj._in_load
+        assert not run_obj._inside_load_context
+        assert run_obj._inside_init_context
+        assert not run_obj._trial_component.parameters
+
+        expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
+        assert run_obj.experiment_name == TEST_EXP_NAME
+        assert run_obj.run_name == TEST_RUN_NAME
+        assert run_obj.run_group_name == TRIAL_NAME_TEMPLATE.format(TEST_EXP_NAME)
+        assert run_obj._trial_component.trial_component_name == expected_tc_name
+        assert run_obj._trial.trial_name == TRIAL_NAME_TEMPLATE.format(TEST_EXP_NAME)
+        assert run_obj._experiment.experiment_name == TEST_EXP_NAME
+        assert run_obj.experiment_config == {
+            EXPERIMENT_NAME: TEST_EXP_NAME,
+            TRIAL_NAME: run_obj.run_group_name,
+            RUN_NAME: expected_tc_name,
+        }
+
+    # trial_component.save is called when entering/exiting the with block
+    mock_tc_save.assert_called()
+
+
+def test_run_init_name_length_exceed_limit(sagemaker_session):
+    invalid_name = "x" * MAX_NAME_LEN_IN_BACKEND
+
+    # experiment_name exceeds
+    with pytest.raises(ValueError) as err:
+        Run(
+            experiment_name=invalid_name,
+            run_name=TEST_RUN_NAME,
+            sagemaker_session=sagemaker_session,
+        )
+
+    assert (
+        f"The experiment_name (length: {MAX_NAME_LEN_IN_BACKEND}) must have length less than"
+        in str(err)
+    )
+
+    # run_name exceeds
+    with pytest.raises(ValueError) as err:
+        Run(
+            experiment_name=TEST_EXP_NAME,
+            run_name=invalid_name,
+            sagemaker_session=sagemaker_session,
+        )
+
+    assert f"The run_name (length: {MAX_NAME_LEN_IN_BACKEND}) must have length less than" in str(
+        err
+    )
+
+
+@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
+@patch(
+    "sagemaker.experiments.run._Experiment._load_or_create",
+    MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+    "sagemaker.experiments.run._Trial._load_or_create",
+    MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+    "sagemaker.experiments.run._TrialComponent._load_or_create",
+    MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+@patch("sagemaker.experiments.run._RunEnvironment")
+def test_run_load_no_run_name_and_in_train_job(mock_run_env, sagemaker_session):
+    client = sagemaker_session.sagemaker_client
+    job_name = "my-train-job"
+    rv = Mock()
+    rv.source_arn = f"arn:1234/{job_name}"
+    rv.environment_type = _environment._EnvironmentType.SageMakerTrainingJob
+    mock_run_env.load.return_value = rv
+
+    expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
+    exp_config = {
+        EXPERIMENT_NAME: TEST_EXP_NAME,
+        TRIAL_NAME: Run._generate_trial_name(TEST_EXP_NAME),
+        RUN_NAME: expected_tc_name,
+    }
+    client.describe_training_job.return_value = {
+        "TrainingJobName": "train-job-experiments",
+        # The Run object has been created elsewhere
+        "ExperimentConfig": exp_config,
+    }
+    with load_run(sagemaker_session=sagemaker_session) as run_obj:
+        assert run_obj._in_load
+        assert not run_obj._inside_init_context
+        assert run_obj._inside_load_context
+        assert run_obj.run_name == TEST_RUN_NAME
+        assert run_obj._trial_component.trial_component_name == expected_tc_name
+        assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME)
+        assert run_obj._trial
+        assert run_obj.experiment_name == TEST_EXP_NAME
+        assert run_obj._experiment
+        assert run_obj.experiment_config == exp_config
+
+    client.describe_training_job.assert_called_once_with(TrainingJobName=job_name)
+
+
+@patch("sagemaker.experiments.run._RunEnvironment")
+def test_run_load_no_run_name_and_in_train_job_but_fail_to_get_exp_cfg(
+    mock_run_env, sagemaker_session
+):
+    rv = Mock()
+    rv.source_arn = "arn:1234/my-train-job"
+    rv.environment_type = _environment._EnvironmentType.SageMakerTrainingJob
+    mock_run_env.load.return_value = rv
+
+    # No Run object has been created elsewhere
+    sagemaker_session.sagemaker_client.describe_training_job.return_value = {
+        "TrainingJobName": "train-job-experiments",
+    }
+
+    with pytest.raises(RuntimeError) as err:
+        with load_run(sagemaker_session=sagemaker_session):
+            pass
+
+    assert "Not able to fetch RunName in ExperimentConfig of the sagemaker job" in str(err)
+
+
+def test_run_load_no_run_name_and_not_in_train_job(run_obj, sagemaker_session):
+    with run_obj:
+        with load_run(sagemaker_session=sagemaker_session) as run:
+            assert run_obj == run
+
+
+def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemaker_session):
+    with pytest.raises(RuntimeError) as err:
+        with load_run(sagemaker_session=sagemaker_session):
+            pass
+
+    assert "Failed to load a Run object" in str(err)
+
+    # experiment_name is given but, without a run_name alongside it, it is ignored.
+    with pytest.raises(RuntimeError) as err:
+        with load_run(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session):
+            pass
+
+    assert "Failed to load a Run object" in str(err)
+
+
+@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
+@patch(
+    "sagemaker.experiments.run._Experiment._load_or_create",
+    MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+    "sagemaker.experiments.run._Trial._load_or_create",
+    MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+    "sagemaker.experiments.run._TrialComponent._load_or_create",
+    MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+def test_run_load_with_run_name_and_exp_name(sagemaker_session):
+    with load_run(
+        run_name=TEST_RUN_NAME,
+        experiment_name=TEST_EXP_NAME,
+        sagemaker_session=sagemaker_session,
+    ) as run_obj:
+        expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
+        expected_exp_config = {
+            EXPERIMENT_NAME: TEST_EXP_NAME,
+            TRIAL_NAME: Run._generate_trial_name(TEST_EXP_NAME),
+            RUN_NAME: expected_tc_name,
+        }
+
+        assert run_obj.run_name == TEST_RUN_NAME
+        assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME)
+        assert run_obj.experiment_name == TEST_EXP_NAME
+        assert run_obj._trial_component.trial_component_name == expected_tc_name
+        assert run_obj._trial
+        assert run_obj._experiment
+        assert run_obj.experiment_config == expected_exp_config
+
+
+def test_run_load_with_run_name_but_no_exp_name(sagemaker_session):
+    with pytest.raises(ValueError) as err:
+        with load_run(
+            run_name=TEST_RUN_NAME,
+            sagemaker_session=sagemaker_session,
+        ):
+            pass
+
+    assert "Invalid input: experiment_name is missing" in str(err)
+
+
+@patch(
+    "sagemaker.experiments.run._Experiment._load_or_create",
+    MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+    "sagemaker.experiments.run._Trial._load_or_create",
+    MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+    "sagemaker.experiments.run._TrialComponent._load_or_create",
+    MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
+@patch("sagemaker.experiments.run._RunEnvironment")
+def test_run_load_in_sm_processing_job(mock_run_env, sagemaker_session):
+    client = sagemaker_session.sagemaker_client
+    job_name = "my-process-job"
+    rv = unittest.mock.Mock()
+    rv.source_arn = f"arn:1234/{job_name}"
+    rv.environment_type = _environment._EnvironmentType.SageMakerProcessingJob
+    mock_run_env.load.return_value = rv
+
+    expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
+    exp_config = {
+        EXPERIMENT_NAME: TEST_EXP_NAME,
+        TRIAL_NAME: Run._generate_trial_name(TEST_EXP_NAME),
+        RUN_NAME: expected_tc_name,
+    }
+    client.describe_processing_job.return_value = {
+        "ProcessingJobName": "process-job-experiments",
+        # The Run object has been created elsewhere
+        "ExperimentConfig": exp_config,
+    }
+
+    with load_run(sagemaker_session=sagemaker_session):
+        pass
+
+    client.describe_processing_job.assert_called_once_with(ProcessingJobName=job_name)
+
+
+@patch("sagemaker.experiments.run._RunEnvironment")
+def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session):
+    # TODO: update this test once we figure out how to get source_arn from transform job
+    rv = unittest.mock.Mock()
+    rv.environment_type = 
_environment._EnvironmentType.SageMakerTransformJob + rv.source_arn = "" + mock_run_env.load.return_value = rv + + with pytest.raises(RuntimeError) as err: + with load_run(sagemaker_session=sagemaker_session): + pass + + assert ( + "loading experiment config from transform job environment is not currently supported" + ) in str(err) + + +def test_log_parameter_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_parameter("foo", "bar") + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_parameter(run_obj): + with run_obj: + run_obj.log_parameter("foo", "bar") + assert run_obj._trial_component.parameters["foo"] == "bar" + run_obj.log_parameter("whizz", 1) + assert run_obj._trial_component.parameters["whizz"] == 1 + + +def test_log_parameter_skip_invalid_value(run_obj): + with run_obj: + run_obj.log_parameter("key", nan) + assert "key" not in run_obj._trial_component.parameters + + +def test_log_parameters_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_parameters({"a": "b", "c": "d", "e": 5}) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_parameters(run_obj): + with run_obj: + run_obj.log_parameters({"a": "b", "c": "d", "e": 5}) + assert run_obj._trial_component.parameters == {"a": "b", "c": "d", "e": 5} + + +def test_log_parameters_skip_invalid_values(run_obj): + with run_obj: + run_obj.log_parameters({"a": "b", "c": "d", "e": 5, "f": nan}) + assert run_obj._trial_component.parameters == {"a": "b", "c": "d", "e": 5} + + +def test_log_input_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_artifact("foo", "baz", "text/text", False) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_input(run_obj): + with run_obj: + run_obj.log_artifact("foo", "baz", "text/text", False) + assert run_obj._trial_component.input_artifacts == { + "foo": TrialComponentArtifact(value="baz", media_type="text/text") + } + + +def test_log_output_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_artifact("foo", "baz", "text/text") + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_output(run_obj): + with run_obj: + run_obj.log_artifact("foo", "baz", "text/text") + assert run_obj._trial_component.output_artifacts == { + "foo": TrialComponentArtifact(value="baz", media_type="text/text") + } + + +def test_log_metric_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_metric(name="foo", value=1.0, step=1) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_metric(run_obj): + now = datetime.datetime.now() + with run_obj: + run_obj.log_metric(name="foo", value=1.0, step=1, timestamp=now) + run_obj._metrics_manager.log_metric.assert_called_with( + metric_name="foo", value=1.0, step=1, timestamp=now + ) + + +def test_log_metric_skip_invalid_value(run_obj): + with run_obj: + run_obj.log_metric(None, nan, None, None) + assert not run_obj._metrics_manager.log_metric.called + + +def test_log_metric_attribute_error(run_obj): + now = datetime.datetime.now() + with run_obj: + run_obj._metrics_manager.log_metric.side_effect = AttributeError + + with pytest.raises(AttributeError): + run_obj.log_metric("foo", 1.0, 1, now) + + +def test_log_output_artifact_outside_run_context(run_obj): + with 
pytest.raises(RuntimeError) as err: + run_obj.log_file("foo.txt", "name", "whizz/bang") + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_output_artifact(run_obj): + run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value") + with run_obj: + run_obj.log_file("foo.txt", "name", "whizz/bang") + run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt") + assert "whizz/bang" == run_obj._trial_component.output_artifacts["name"].media_type + + run_obj.log_file("foo.txt") + run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt") + assert "foo.txt" in run_obj._trial_component.output_artifacts + assert "text/plain" == run_obj._trial_component.output_artifacts["foo.txt"].media_type + + +def test_log_input_artifact_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_input_artifact(run_obj): + run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value") + with run_obj: + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False) + run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt") + assert "whizz/bang" == run_obj._trial_component.input_artifacts["name"].media_type + + run_obj.log_file("foo.txt", is_output=False) + run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt") + assert "foo.txt" in run_obj._trial_component.input_artifacts + assert "text/plain" == run_obj._trial_component.input_artifacts["foo.txt"].media_type + + +def test_log_multiple_inputs(run_obj): + with run_obj: + for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN): + file_path = "foo" + str(index) + ".txt" + run_obj._trial_component.input_artifacts[file_path] = { + "foo": TrialComponentArtifact(value="baz" + str(index), media_type="text/text") + } + with pytest.raises(ValueError) as error: + run_obj.log_artifact("foo.txt", "name", "whizz/bang", False) + assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} input_artifacts" in str(error) + + +def test_log_multiple_outputs(run_obj): + with run_obj: + for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN): + file_path = "foo" + str(index) + ".txt" + run_obj._trial_component.output_artifacts[file_path] = { + "foo": TrialComponentArtifact(value="baz" + str(index), media_type="text/text") + } + with pytest.raises(ValueError) as error: + run_obj.log_artifact("foo.txt", "name", "whizz/bang") + assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} output_artifacts" in str(error) + + +def test_log_multiple_input_artifacts(run_obj): + with run_obj: + for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN): + file_path = "foo" + str(index) + ".txt" + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value" + str(index), + "etag_value" + str(index), + ) + run_obj.log_file( + file_path, "name" + str(index), "whizz/bang" + str(index), is_output=False + ) + run_obj._artifact_uploader.upload_artifact.assert_called_with(file_path) + + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + + # log an output artifact, should be fine + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=True) + + # log an extra input artifact, should raise exception + with pytest.raises(ValueError) as error: + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False) + assert 
f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} input_artifacts" in str(error) + + +def test_log_multiple_output_artifacts(run_obj): + with run_obj: + for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN): + file_path = "foo" + str(index) + ".txt" + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value" + str(index), + "etag_value" + str(index), + ) + run_obj.log_file(file_path, "name" + str(index), "whizz/bang" + str(index)) + run_obj._artifact_uploader.upload_artifact.assert_called_with(file_path) + + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + + # log an input artifact, should be fine + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False) + + # log an extra output artifact, should raise exception + with pytest.raises(ValueError) as error: + run_obj.log_file("foo.txt", "name", "whizz/bang") + assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} output_artifacts" in str(error) + + +def test_log_precision_recall_outside_run_context(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35, 0.8] + no_skill = 0.1 + title = "TestPrecisionRecall" + + with pytest.raises(RuntimeError) as err: + run_obj.log_precision_recall( + y_true, y_scores, 0, title=title, no_skill=no_skill, is_output=False + ) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_precision_recall(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35, 0.8] + no_skill = 0.1 + title = "TestPrecisionRecall" + + run_obj._artifact_uploader.upload_object_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + with run_obj: + run_obj.log_precision_recall( + y_true, y_scores, 0, title=title, no_skill=no_skill, is_output=False + ) + + expected_data = { + "type": "PrecisionRecallCurve", + "version": 0, + "title": title, + "precision": [0.5, 0.3333333333333333, 0.5, 0.0, 1.0], + "recall": [1.0, 0.5, 0.5, 0.0, 0.0], + "averagePrecisionScore": 0.5, + "noSkill": 0.1, + } + run_obj._artifact_uploader.upload_object_artifact.assert_called_with( + title, expected_data, file_extension="json" + ) + + run_obj._lineage_artifact_tracker.add_input_artifact.assert_called_with( + name=title, + source_uri="s3uri_value", + etag="etag_value", + artifact_type="PrecisionRecallCurve", + ) + + +def test_log_precision_recall_invalid_input(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35] + no_skill = 0.1 + + with run_obj: + with pytest.raises(ValueError) as error: + run_obj.log_precision_recall( + y_true, y_scores, 0, title="TestPrecisionRecall", no_skill=no_skill, is_output=False + ) + assert "Lengths mismatch between true labels and predicted probabilities" in str(error) + + +def test_log_confusion_matrix_outside_run_context(run_obj): + y_true = [2, 0, 2, 2, 0, 1] + y_pred = [0, 0, 2, 2, 0, 2] + + with pytest.raises(RuntimeError) as err: + run_obj.log_confusion_matrix(y_true, y_pred, title="TestConfusionMatrix") + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_confusion_matrix(run_obj): + y_true = [2, 0, 2, 2, 0, 1] + y_pred = [0, 0, 2, 2, 0, 2] + + run_obj._artifact_uploader.upload_object_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + with run_obj: + run_obj.log_confusion_matrix(y_true, y_pred, title="TestConfusionMatrix") + + expected_data = { + "type": "ConfusionMatrix", + "version": 0, + "title": "TestConfusionMatrix", + "confusionMatrix": [[2, 0, 0], [0, 0, 1], [1, 0, 2]], + } + + 
run_obj._artifact_uploader.upload_object_artifact.assert_called_with( + "TestConfusionMatrix", expected_data, file_extension="json" + ) + + run_obj._lineage_artifact_tracker.add_output_artifact.assert_called_with( + name="TestConfusionMatrix", + source_uri="s3uri_value", + etag="etag_value", + artifact_type="ConfusionMatrix", + ) + + +def test_log_confusion_matrix_invalid_input(run_obj): + y_true = [2, 0, 2, 2, 0, 1] + y_pred = [0, 0, 2, 2, 0] + + with run_obj: + with pytest.raises(ValueError) as error: + run_obj.log_confusion_matrix(y_true, y_pred, title="TestConfusionMatrix") + assert "Lengths mismatch between true labels and predicted labels" in str(error) + + +def test_log_roc_curve_outside_run_context(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35, 0.8] + + with pytest.raises(RuntimeError) as err: + run_obj.log_roc_curve(y_true, y_scores, title="TestROCCurve", is_output=False) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_roc_curve(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35, 0.8] + with run_obj: + run_obj._artifact_uploader.upload_object_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + + run_obj.log_roc_curve(y_true, y_scores, title="TestROCCurve", is_output=False) + + expected_data = { + "type": "ROCCurve", + "version": 0, + "title": "TestROCCurve", + "falsePositiveRate": [0.0, 0.0, 0.5, 0.5, 1.0], + "truePositiveRate": [0.0, 0.5, 0.5, 1.0, 1.0], + "areaUnderCurve": 0.75, + } + run_obj._artifact_uploader.upload_object_artifact.assert_called_with( + "TestROCCurve", expected_data, file_extension="json" + ) + + run_obj._lineage_artifact_tracker.add_input_artifact.assert_called_with( + name="TestROCCurve", + source_uri="s3uri_value", + etag="etag_value", + artifact_type="ROCCurve", + ) + + +def test_log_roc_curve_invalid_input(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35] + + with run_obj: + with pytest.raises(ValueError) as error: + run_obj.log_roc_curve(y_true, y_scores, title="TestROCCurve", is_output=False) + assert "Lengths mismatch between true labels and predicted scores" in str(error) + + +@patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), +) +@patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch("sagemaker.experiments.run._TrialComponent._load_or_create") +@patch("sagemaker.experiments.run._TrialComponent.list") +@patch("sagemaker.experiments.run._TrialComponent.search") +def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_session): + start_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1) + end_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=2) + creation_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=3) + last_modified_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=4) + tc_list_len = 20 + tc_list_len_half = int(tc_list_len / 2) + mock_tc_search.side_effect = [ + [ + TrialComponentSearchResult( + trial_component_name=Run._generate_trial_component_name( + "a" + str(i), TEST_EXP_NAME + ), + trial_component_arn="b" + str(i), + display_name="C" + str(i), + creation_time=creation_time + datetime.timedelta(hours=i), + last_modified_time=last_modified_time + 
datetime.timedelta(hours=i), + last_modified_by={}, + tags=[RUN_TC_TAG] if i < tc_list_len_half else None, + ) + ] + for i in range(tc_list_len) + ] + mock_tc_list.return_value = [ + TrialComponentSummary( + trial_component_name=Run._generate_trial_component_name("A" + str(i), TEST_EXP_NAME), + trial_component_arn="b" + str(i), + display_name="C" + str(i), + source_arn="D" + str(i), + status=TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i) + ), + start_time=start_time + datetime.timedelta(hours=i), + end_time=end_time + datetime.timedelta(hours=i), + creation_time=creation_time + datetime.timedelta(hours=i), + last_modified_time=last_modified_time + datetime.timedelta(hours=i), + last_modified_by={}, + ) + for i in range(tc_list_len) + ] + mock_tc_load.side_effect = [ + ( + _TrialComponent( + trial_component_name=Run._generate_trial_component_name( + "a" + str(i), TEST_EXP_NAME + ), + trial_component_arn="b" + str(i), + display_name="C" + str(i), + source_arn="D" + str(i), + status=TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i) + ), + start_time=start_time + datetime.timedelta(hours=i), + end_time=end_time + datetime.timedelta(hours=i), + creation_time=creation_time + datetime.timedelta(hours=i), + last_modified_time=last_modified_time + datetime.timedelta(hours=i), + last_modified_by={}, + ), + True, + ) + for i in range(tc_list_len_half) + ] + + run_list = list_runs( + experiment_name=TEST_EXP_NAME, + sort_by=SortByType.CREATION_TIME, + sort_order=SortOrderType.ASCENDING, + sagemaker_session=sagemaker_session, + ) + + mock_tc_list.assert_called_once_with( + experiment_name=TEST_EXP_NAME, + created_before=None, + created_after=None, + sort_by="CreationTime", + sort_order="Ascending", + sagemaker_session=sagemaker_session, + max_results=None, + next_token=None, + ) + assert len(run_list) == tc_list_len_half + for i in range(tc_list_len_half): + run = run_list[i] + assert run.experiment_name == TEST_EXP_NAME + assert run.run_name == "a" + str(i) + assert run._experiment + assert run._trial + assert isinstance(run._trial_component, _TrialComponent) + assert run._trial_component.trial_component_name == Run._generate_trial_component_name( + "a" + str(i), TEST_EXP_NAME + ) + assert run._in_load is False + assert run._inside_load_context is False + assert run._inside_init_context is False + assert run._artifact_uploader + assert run._lineage_artifact_tracker + assert run._metrics_manager + + +@patch("sagemaker.experiments.run._TrialComponent.list") +def test_list_empty(mock_tc_list, sagemaker_session): + mock_tc_list.return_value = [] + assert [] == list_runs(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session) + + +@patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), +) +@patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch("sagemaker.experiments.run._TrialComponent._load_or_create") +def test_enter_exit_locally(mock_load_tc, sagemaker_session, run_obj): + mock_load_tc.return_value = run_obj._trial_component, False + sagemaker_session.sagemaker_client.update_trial_component.return_value = {} + _verify_tc_status_before_enter_init(run_obj._trial_component) + + with run_obj: + _verify_tc_status_when_entering(run_obj._trial_component) + 
init_start_time = run_obj._trial_component.start_time + + with load_run(sagemaker_session=sagemaker_session): + _verify_tc_status_when_entering( + trial_component=run_obj._trial_component, + init_start_time=init_start_time, + ) + + old_end_time = _verify_tc_status_when_successfully_exit( + trial_component=run_obj._trial_component, + ) + + old_end_time = _verify_tc_status_when_successfully_exit( + trial_component=run_obj._trial_component, + old_end_time=old_end_time, + ) + + # Re-load to verify: + # 1. if it works when load_run and with are not in one line + # 2. if re-entering the load will change the "Completed" TC status + # to "InProgress" + # 3. when exiting the load, the end_time and status will be overridden again + run_load = load_run( + experiment_name=run_obj.experiment_name, + run_name=run_obj.run_name, + sagemaker_session=sagemaker_session, + ) + with run_load: + _verify_tc_status_when_entering( + trial_component=run_obj._trial_component, + init_start_time=init_start_time, + has_completed=True, + ) + _verify_tc_status_when_successfully_exit( + trial_component=run_obj._trial_component, old_end_time=old_end_time + ) + + +def test_exit_fail(sagemaker_session, run_obj): + sagemaker_session.sagemaker_client.update_trial_component.return_value = {} + try: + with run_obj: + raise ValueError("Foo") + except ValueError: + pass + + assert run_obj._trial_component.status.primary_status == _TrialComponentStatusType.Failed.value + assert run_obj._trial_component.status.message + assert isinstance(run_obj._trial_component.end_time, datetime.datetime) + + +@pytest.mark.parametrize( + "metric_value", + [1.3, "nan", "inf", "-inf", None], +) +def test_is_input_valid(run_obj, metric_value): + assert run_obj._is_input_valid("metric", "Name", metric_value) + + +@pytest.mark.parametrize( + "metric_value", + [nan, inf, -inf], +) +def test_is_input_valid_false(run_obj, metric_value): + assert not run_obj._is_input_valid("parameter", "Name", metric_value) + + +def test_generate_trial_name(): + base_name = "x" * MAX_NAME_LEN_IN_BACKEND + trial_name = Run._generate_trial_name(base_name=base_name) + assert len(trial_name) <= MAX_NAME_LEN_IN_BACKEND + + +def test_append_run_tc_label_to_tags(): + expected_tc_tag = RUN_TC_TAG + + tags = None + ret = Run._append_run_tc_label_to_tags(tags) + assert len(ret) == 1 + assert expected_tc_tag in ret + + tags = [] + ret = Run._append_run_tc_label_to_tags(tags) + assert len(ret) == 1 + assert expected_tc_tag in ret + + tags = [{"Key": "foo", "Value": "bar"}] + ret = Run._append_run_tc_label_to_tags(tags) + assert len(ret) == 2 + assert expected_tc_tag in ret + + +def _verify_tc_status_before_enter_init(trial_component): + assert not trial_component.start_time + assert not trial_component.end_time + assert not trial_component.status + + +def _verify_tc_status_when_entering(trial_component, init_start_time=None, has_completed=False): + if not init_start_time: + assert isinstance(trial_component.start_time, datetime.datetime) + now = datetime.datetime.now(dateutil.tz.tzlocal()) + assert (now.timestamp() - trial_component.start_time.timestamp()) < 1 + else: + assert trial_component.start_time == init_start_time + + if not has_completed: + assert not trial_component.end_time + assert trial_component.status.primary_status == _TrialComponentStatusType.InProgress.value + + +def _verify_tc_status_when_successfully_exit(trial_component, old_end_time=None): + assert trial_component.status.primary_status == _TrialComponentStatusType.Completed.value + assert 
isinstance(trial_component.start_time, datetime.datetime) + assert isinstance(trial_component.end_time, datetime.datetime) + if old_end_time: + assert trial_component.end_time > old_end_time + return trial_component.end_time diff --git a/tests/unit/sagemaker/experiments/test_run_context.py b/tests/unit/sagemaker/experiments/test_run_context.py new file mode 100644 index 0000000000..7e068136a1 --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_run_context.py @@ -0,0 +1,191 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest.mock import patch, MagicMock + +import pytest + +from sagemaker.estimator import Estimator, _TrainingJob +from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments.run import _RunContext +from sagemaker.experiments import load_run, Run +from sagemaker.experiments.trial import _Trial +from tests.unit.sagemaker.experiments.helpers import ( + TEST_EXP_NAME, + mock_trial_load_or_create_func, + mock_tc_load_or_create_func, +) + +_bucket = "my-bucket" +_train_input_path = f"s3://{_bucket}/data.csv" +_train_output_path = f"s3://{_bucket}" + + +@patch.object(_TrainingJob, "start_new") +def test_auto_pass_in_exp_config_to_train_job(mock_start_job, run_obj, sagemaker_session): + mock_start_job.return_value = _TrainingJob(sagemaker_session, "my-job") + with run_obj: + estimator = Estimator( + role="arn:my-role", + image_uri="my-image", + sagemaker_session=sagemaker_session, + output_path=_train_output_path, + ) + estimator.fit( + inputs=_train_input_path, + wait=False, + ) + + assert _RunContext.get_current_run() == run_obj + + expected_exp_config = run_obj.experiment_config + mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config) + + # _RunContext is cleaned up after exiting the with statement + assert not _RunContext.get_current_run() + + +@patch.object(_TrainingJob, "start_new") +def test_user_supply_exp_config_to_train_job(mock_start_job, run_obj, sagemaker_session): + mock_start_job.return_value = _TrainingJob(sagemaker_session, "my-job") + supplied_exp_cfg = { + "ExperimentName": "my-supplied-exp-name", + "TrialName": "my-supplied-run-group-name", + "RunName": "my-supplied-run-name", + } + with run_obj: + estimator = Estimator( + role="arn:my-role", + image_uri="my-image", + sagemaker_session=sagemaker_session, + output_path=_train_output_path, + ) + estimator.fit( + experiment_config=supplied_exp_cfg, + inputs=_train_input_path, + wait=False, + ) + + assert _RunContext.get_current_run() == run_obj + + mock_start_job.assert_called_once_with(estimator, _train_input_path, supplied_exp_cfg) + + # _RunContext is cleaned up after exiting the with statement + assert not _RunContext.get_current_run() + + +def test_auto_fetch_created_run_obj_from_context(run_obj, sagemaker_session): + assert not run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert not _RunContext.get_current_run() + + def 
train(): + with load_run(sagemaker_session=sagemaker_session) as run_load: + assert run_load == run_obj + assert run_obj._inside_init_context + assert run_obj._inside_load_context + assert run_obj._in_load + + run_load.log_parameter("foo", "bar") + run_load.log_parameter("whizz", 1) + + with run_obj: + assert run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert _RunContext.get_current_run() + + train() + + assert run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert _RunContext.get_current_run() + + run_obj.log_parameters({"a": "b", "c": 2}) + + assert run_obj._trial_component.parameters["foo"] == "bar" + assert run_obj._trial_component.parameters["whizz"] == 1 + assert run_obj._trial_component.parameters["a"] == "b" + assert run_obj._trial_component.parameters["c"] == 2 + + # Verify separating load_run and with statement in different lines still work + run_load2 = load_run(sagemaker_session=sagemaker_session) + with run_load2: + assert run_load2 == run_obj + assert run_obj._inside_init_context + assert run_obj._inside_load_context + assert run_obj._in_load + + assert run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert _RunContext.get_current_run() + + assert not run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert not _RunContext.get_current_run() + + +def test_nested_run_init_context_on_same_run_object(run_obj, sagemaker_session): + assert not _RunContext.get_current_run() + + with pytest.raises(RuntimeError) as err: + with run_obj: + assert _RunContext.get_current_run() + + with run_obj: + pass + assert "It is not allowed to use nested 'with' statements on the Run" in str(err) + + +@patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), +) +@patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch( + "sagemaker.experiments.run._TrialComponent._load_or_create", + MagicMock(side_effect=mock_tc_load_or_create_func), +) +def test_nested_run_init_context_on_different_run_object(run_obj, sagemaker_session): + assert not _RunContext.get_current_run() + + with pytest.raises(RuntimeError) as err: + with Run(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session): + assert _RunContext.get_current_run() + + with run_obj: + pass + assert "It is not allowed to use nested 'with' statements on the Run" in str(err) + + +def test_nested_run_load_context(run_obj, sagemaker_session): + assert not _RunContext.get_current_run() + + with pytest.raises(RuntimeError) as err: + with run_obj: + assert _RunContext.get_current_run() + + with load_run(): + run_load = load_run() + with run_load: + pass + assert "It is not allowed to use nested 'with' statements on the load_run" in str(err) diff --git a/tests/unit/sagemaker/experiments/test_trial.py b/tests/unit/sagemaker/experiments/test_trial.py new file mode 100644 index 0000000000..f6996fefc3 --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_trial.py @@ -0,0 +1,276 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest + +import datetime + +from unittest.mock import patch + +from sagemaker import Session +from sagemaker.experiments._api_types import TrialSummary +from sagemaker.experiments.trial import _Trial +from sagemaker.experiments.trial_component import _TrialComponent + + +@pytest.fixture +def datetime_obj(): + return datetime.datetime(2017, 6, 16, 15, 55, 0) + + +def test_load(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.describe_trial.return_value = {"ExperimentName": "experiment-name-value"} + trial_obj = _Trial.load(trial_name="name-value", sagemaker_session=sagemaker_session) + assert trial_obj.trial_name == "name-value" + assert trial_obj.experiment_name == "experiment-name-value" + client.describe_trial.assert_called_with(TrialName="name-value") + + +def test_create(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_trial.return_value = { + "Arn": "arn:aws:1234", + "TrialName": "name-value", + } + trial_obj = _Trial.create( + trial_name="name-value", + experiment_name="experiment-name-value", + sagemaker_session=sagemaker_session, + ) + assert trial_obj.trial_name == "name-value" + client.create_trial.assert_called_with( + TrialName="name-value", ExperimentName="experiment-name-value" + ) + + +def test_create_with_tags(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_trial.return_value = { + "Arn": "arn:aws:1234", + "TrialName": "name-value", + } + tags = [{"Key": "foo", "Value": "bar"}] + trial_obj = _Trial.create( + trial_name="name-value", + experiment_name="experiment-name-value", + sagemaker_session=sagemaker_session, + tags=tags, + ) + assert trial_obj.trial_name == "name-value" + client.create_trial.assert_called_with( + TrialName="name-value", + ExperimentName="experiment-name-value", + Tags=[{"Key": "foo", "Value": "bar"}], + ) + + +def test_delete(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _Trial(sagemaker_session, trial_name="foo") + client.delete_trial.return_value = {} + obj.delete() + client.delete_trial.assert_called_with(TrialName="foo") + + +def test_save(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _Trial( + sagemaker_session, + trial_name="foo", + experiment_name="whizz", + display_name="bar", + tags=[{"Key": "foo", "Value": "bar"}], + ) + client.update_trial.return_value = {} + obj.save() + + client.update_trial.assert_called_with( + TrialName="foo", + DisplayName="bar", + ) + + +def test_add_trial_component(sagemaker_session): + client = sagemaker_session.sagemaker_client + trial = _Trial(sagemaker_session=sagemaker_session) + trial.trial_name = "bar" + trial.add_trial_component("foo") + client.associate_trial_component.assert_called_with(TrialName="bar", TrialComponentName="foo") + + tc = _TrialComponent(trial_component_name="tc-foo", sagemaker_session=sagemaker_session) + trial.add_trial_component(tc) + client.associate_trial_component.assert_called_with( + TrialName="bar", TrialComponentName=tc.trial_component_name + ) + + +def test_remove_trial_component(sagemaker_session): + client = 
sagemaker_session.sagemaker_client
+    trial = _Trial(sagemaker_session=sagemaker_session)
+    trial.trial_name = "bar"
+    trial.remove_trial_component("foo")
+    client.disassociate_trial_component.assert_called_with(
+        TrialName="bar", TrialComponentName="foo"
+    )
+
+    tc = _TrialComponent(trial_component_name="tc-foo", sagemaker_session=sagemaker_session)
+    trial.remove_trial_component(tc)
+    client.disassociate_trial_component.assert_called_with(
+        TrialName="bar", TrialComponentName=tc.trial_component_name
+    )
+
+
+@patch("sagemaker.experiments.trial._Trial.load")
+def test_load_or_create_when_exist(mock_load):
+    sagemaker_session = Session()
+    trial_name = "trial_name"
+    exp_name = "exp_name"
+
+    # The trial exists and the experiment matches
+    mock_load.return_value = _Trial(
+        trial_name=trial_name,
+        experiment_name=exp_name,
+        sagemaker_session=sagemaker_session,
+    )
+    _Trial._load_or_create(
+        trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session
+    )
+    mock_load.assert_called_once_with(trial_name, sagemaker_session)
+
+    # The trial exists but the experiment does not match
+    mock_load.return_value = _Trial(
+        trial_name=trial_name,
+        experiment_name="another_exp_name",
+        sagemaker_session=sagemaker_session,
+    )
+    with pytest.raises(ValueError) as err:
+        _Trial._load_or_create(
+            trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session
+        )
+    assert "The given experiment_name {} does not match that in the loaded trial".format(
+        exp_name
+    ) in str(err)
+
+
+@patch("sagemaker.experiments.trial._Trial.load")
+@patch("sagemaker.experiments.trial._Trial.create")
+def test_load_or_create_when_not_exist(mock_create, mock_load):
+    sagemaker_session = Session()
+    client = sagemaker_session.sagemaker_client
+    trial_name = "trial_name"
+    exp_name = "exp_name"
+    not_found_err = client.exceptions.ResourceNotFound(
+        error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}},
+        operation_name="foo",
+    )
+    mock_load.side_effect = not_found_err
+
+    _Trial._load_or_create(
+        trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session
+    )
+
+    mock_create.assert_called_once_with(
+        trial_name=trial_name,
+        experiment_name=exp_name,
+        display_name=None,
+        tags=None,
+        sagemaker_session=sagemaker_session,
+    )
+
+
+def test_list_trials_without_experiment_name(sagemaker_session, datetime_obj):
+    client = sagemaker_session.sagemaker_client
+    client.list_trials.return_value = {
+        "TrialSummaries": [
+            {
+                "TrialName": "trial-1",
+                "CreationTime": datetime_obj,
+                "LastModifiedTime": datetime_obj,
+            },
+            {
+                "TrialName": "trial-2",
+                "CreationTime": datetime_obj,
+                "LastModifiedTime": datetime_obj,
+            },
+        ]
+    }
+    expected = [
+        TrialSummary(
+            trial_name="trial-1", creation_time=datetime_obj, last_modified_time=datetime_obj
+        ),
+        TrialSummary(
+            trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj
+        ),
+    ]
+    assert expected == list(_Trial.list(sagemaker_session=sagemaker_session))
+    client.list_trials.assert_called_with(**{})
+
+
+def test_list_trials_with_experiment_name(sagemaker_session, datetime_obj):
+    client = sagemaker_session.sagemaker_client
+    client.list_trials.return_value = {
+        "TrialSummaries": [
+            {
+                "TrialName": "trial-1",
+                "CreationTime": datetime_obj,
+                "LastModifiedTime": datetime_obj,
+            },
+            {
+                "TrialName": "trial-2",
+                "CreationTime": datetime_obj,
+                "LastModifiedTime": datetime_obj,
+            },
+        ]
+    }
+    expected = [
+        TrialSummary(
+            trial_name="trial-1", 
creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + assert expected == list(_Trial.list(experiment_name="foo", sagemaker_session=sagemaker_session)) + client.list_trials.assert_called_with(ExperimentName="foo") + + +def test_list_trials_with_trial_component_name(sagemaker_session, datetime_obj): + client = sagemaker_session.sagemaker_client + client.list_trials.return_value = { + "TrialSummaries": [ + { + "TrialName": "trial-1", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + { + "TrialName": "trial-2", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + ] + } + expected = [ + TrialSummary( + trial_name="trial-1", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + assert expected == list( + _Trial.list(trial_component_name="tc-foo", sagemaker_session=sagemaker_session) + ) + client.list_trials.assert_called_with(TrialComponentName="tc-foo") diff --git a/tests/unit/sagemaker/experiments/test_trial_component.py b/tests/unit/sagemaker/experiments/test_trial_component.py new file mode 100644 index 0000000000..c14663893e --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_trial_component.py @@ -0,0 +1,384 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
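+# These tests exercise _TrialComponent against a mocked sagemaker_client
+# only. In particular, test_list replays two pages of list_trial_components
+# responses (the first carrying NextToken="100") to check that pagination is
+# followed transparently, and test_delete_with_force_disassociate checks
+# that every associated trial, across pages of list_trials results, is
+# disassociated before the component itself is deleted.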
+from __future__ import absolute_import
+
+import datetime
+import unittest.mock
+
+from unittest.mock import patch
+
+from sagemaker import Session
+from sagemaker.experiments import _api_types
+from sagemaker.experiments._api_types import (
+    TrialComponentSearchResult,
+    Parent,
+    _TrialComponentStatusType,
+)
+from sagemaker.experiments.trial_component import _TrialComponent
+
+
+def test_create(sagemaker_session):
+    client = sagemaker_session.sagemaker_client
+    client.create_trial_component.return_value = {
+        "TrialComponentArn": "bazz",
+    }
+    obj = _TrialComponent.create(
+        trial_component_name="foo", display_name="bar", sagemaker_session=sagemaker_session
+    )
+    client.create_trial_component.assert_called_with(TrialComponentName="foo", DisplayName="bar")
+    assert "foo" == obj.trial_component_name
+    assert "bar" == obj.display_name
+    assert "bazz" == obj.trial_component_arn
+
+
+def test_create_with_tags(sagemaker_session):
+    client = sagemaker_session.sagemaker_client
+    client.create_trial_component.return_value = {
+        "TrialComponentArn": "bazz",
+    }
+    tags = [{"Key": "foo", "Value": "bar"}]
+    _TrialComponent.create(
+        trial_component_name="foo",
+        display_name="bar",
+        sagemaker_session=sagemaker_session,
+        tags=tags,
+    )
+    client.create_trial_component.assert_called_with(
+        TrialComponentName="foo", DisplayName="bar", Tags=tags
+    )
+
+
+def test_load(sagemaker_session):
+    now = datetime.datetime.now(datetime.timezone.utc)
+    client = sagemaker_session.sagemaker_client
+    client.describe_trial_component.return_value = {
+        "TrialComponentArn": "A",
+        "TrialComponentName": "B",
+        "DisplayName": "C",
+        "Status": {"PrimaryStatus": _TrialComponentStatusType.InProgress.value, "Message": "D"},
+        "Parameters": {"E": {"NumberValue": 1.0}, "F": {"StringValue": "G"}},
+        "InputArtifacts": {"H": {"Value": "s3://foo/bar", "MediaType": "text/plain"}},
+        "OutputArtifacts": {"I": {"Value": "s3://whizz/bang", "MediaType": "text/plain"}},
+        "Metrics": [
+            {
+                "MetricName": "J",
+                "Count": 1,
+                "Min": 1.0,
+                "Max": 2.0,
+                "Avg": 3.0,
+                "StdDev": 4.0,
+                "SourceArn": "K",
+                "Timestamp": now,
+            }
+        ],
+    }
+    obj = _TrialComponent.load(trial_component_name="foo", sagemaker_session=sagemaker_session)
+    client.describe_trial_component.assert_called_with(TrialComponentName="foo")
+    assert "A" == obj.trial_component_arn
+    assert "B" == obj.trial_component_name
+    assert "C" == obj.display_name
+    assert (
+        _api_types.TrialComponentStatus(
+            primary_status=_TrialComponentStatusType.InProgress.value, message="D"
+        )
+        == obj.status
+    )
+    assert {"E": 1.0, "F": "G"} == obj.parameters
+    assert {
+        "H": _api_types.TrialComponentArtifact(value="s3://foo/bar", media_type="text/plain")
+    } == obj.input_artifacts
+    assert {
+        "I": _api_types.TrialComponentArtifact(value="s3://whizz/bang", media_type="text/plain")
+    } == obj.output_artifacts
+    assert [
+        _api_types.TrialComponentMetricSummary(
+            metric_name="J",
+            count=1,
+            min=1.0,
+            max=2.0,
+            avg=3.0,
+            std_dev=4.0,
+            source_arn="K",
+            timestamp=now,
+        )
+    ] == obj.metrics
+
+
+def test_save(sagemaker_session):
+    client = sagemaker_session.sagemaker_client
+    obj = _TrialComponent(
+        sagemaker_session,
+        trial_component_name="foo",
+        display_name="bar",
+        parameters_to_remove=["E"],
+        input_artifacts_to_remove=["F"],
+        output_artifacts_to_remove=["G"],
+    )
+    client.update_trial_component.return_value = {}
+    obj.save()
+
+    client.update_trial_component.assert_called_with(
+        TrialComponentName="foo",
+        DisplayName="bar",
+        Parameters={},
+        ParametersToRemove=["E"],
+        InputArtifacts={},
+        InputArtifactsToRemove=["F"],
+        OutputArtifacts={}, 
+ OutputArtifactsToRemove=["G"], + ) + + +def test_delete(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _TrialComponent(sagemaker_session, trial_component_name="foo", display_name="bar") + client.delete_trial_component.return_value = {} + obj.delete() + client.delete_trial_component.assert_called_with(TrialComponentName="foo") + + +def test_delete_with_force_disassociate(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _TrialComponent(sagemaker_session, trial_component_name="foo", display_name="bar") + client.delete_trial_component.return_value = {} + + client.list_trials.side_effect = [ + {"TrialSummaries": [{"TrialName": "trial-1"}, {"TrialName": "trial-2"}], "NextToken": "a"}, + {"TrialSummaries": [{"TrialName": "trial-3"}, {"TrialName": "trial-4"}]}, + ] + + obj.delete(force_disassociate=True) + expected_calls = [ + unittest.mock.call(TrialName="trial-1", TrialComponentName="foo"), + unittest.mock.call(TrialName="trial-2", TrialComponentName="foo"), + unittest.mock.call(TrialName="trial-3", TrialComponentName="foo"), + unittest.mock.call(TrialName="trial-4", TrialComponentName="foo"), + ] + assert expected_calls == client.disassociate_trial_component.mock_calls + client.delete_trial_component.assert_called_with(TrialComponentName="foo") + + +def test_list(sagemaker_session): + start_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1) + end_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=2) + creation_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=3) + last_modified_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=4) + + client = sagemaker_session.sagemaker_client + client.list_trial_components.side_effect = [ + { + "TrialComponentSummaries": [ + { + "TrialComponentName": "A" + str(i), + "TrialComponentArn": "B" + str(i), + "DisplayName": "C" + str(i), + "SourceArn": "D" + str(i), + "Status": { + "PrimaryStatus": _TrialComponentStatusType.InProgress.value, + "Message": "E" + str(i), + }, + "StartTime": start_time + datetime.timedelta(hours=i), + "EndTime": end_time + datetime.timedelta(hours=i), + "CreationTime": creation_time + datetime.timedelta(hours=i), + "LastModifiedTime": last_modified_time + datetime.timedelta(hours=i), + "LastModifiedBy": {}, + } + for i in range(10) + ], + "NextToken": "100", + }, + { + "TrialComponentSummaries": [ + { + "TrialComponentName": "A" + str(i), + "TrialComponentArn": "B" + str(i), + "DisplayName": "C" + str(i), + "SourceArn": "D" + str(i), + "Status": { + "PrimaryStatus": _TrialComponentStatusType.InProgress.value, + "Message": "E" + str(i), + }, + "StartTime": start_time + datetime.timedelta(hours=i), + "EndTime": end_time + datetime.timedelta(hours=i), + "CreationTime": creation_time + datetime.timedelta(hours=i), + "LastModifiedTime": last_modified_time + datetime.timedelta(hours=i), + "LastModifiedBy": {}, + } + for i in range(10, 20) + ] + }, + ] + + expected = [ + _api_types.TrialComponentSummary( + trial_component_name="A" + str(i), + trial_component_arn="B" + str(i), + display_name="C" + str(i), + source_arn="D" + str(i), + status=_api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i) + ), + start_time=start_time + datetime.timedelta(hours=i), + end_time=end_time + datetime.timedelta(hours=i), + creation_time=creation_time + datetime.timedelta(hours=i), + last_modified_time=last_modified_time + 
datetime.timedelta(hours=i), + last_modified_by={}, + ) + for i in range(20) + ] + result = list( + _TrialComponent.list( + sagemaker_session=sagemaker_session, + source_arn="foo", + sort_by="CreationTime", + sort_order="Ascending", + ) + ) + + assert expected == result + expected_calls = [ + unittest.mock.call(SortBy="CreationTime", SortOrder="Ascending", SourceArn="foo"), + unittest.mock.call( + NextToken="100", SortBy="CreationTime", SortOrder="Ascending", SourceArn="foo" + ), + ] + assert expected_calls == client.list_trial_components.mock_calls + + +def test_list_empty(sagemaker_session): + sagemaker_session.sagemaker_client.list_trial_components.return_value = { + "TrialComponentSummaries": [] + } + assert [] == list(_TrialComponent.list(sagemaker_session=sagemaker_session)) + + +def test_list_trial_components_call_args(sagemaker_session): + created_before = datetime.datetime(1999, 10, 12, 0, 0, 0) + created_after = datetime.datetime(1990, 10, 12, 0, 0, 0) + trial_name = "foo-trial" + experiment_name = "foo-experiment" + next_token = "thetoken" + max_results = 99 + + client = sagemaker_session.sagemaker_client + client.list_trial_components.return_value = {} + assert [] == list( + _TrialComponent.list( + sagemaker_session=sagemaker_session, + trial_name=trial_name, + experiment_name=experiment_name, + created_before=created_before, + created_after=created_after, + next_token=next_token, + max_results=max_results, + sort_by="CreationTime", + sort_order="Ascending", + ) + ) + + expected_calls = [ + unittest.mock.call( + TrialName="foo-trial", + ExperimentName="foo-experiment", + CreatedBefore=created_before, + CreatedAfter=created_after, + SortBy="CreationTime", + SortOrder="Ascending", + NextToken="thetoken", + MaxResults=99, + ) + ] + assert expected_calls == client.list_trial_components.mock_calls + + +@patch("sagemaker.experiments.trial_component._TrialComponent.load") +def test_load_or_create_when_exist(mock_load, sagemaker_session): + tc_name = "tc_name" + _, is_existed = _TrialComponent._load_or_create( + trial_component_name=tc_name, sagemaker_session=sagemaker_session + ) + assert is_existed + mock_load.assert_called_once_with( + tc_name, + sagemaker_session, + ) + + +@patch("sagemaker.experiments.trial_component._TrialComponent.load") +@patch("sagemaker.experiments.trial_component._TrialComponent.create") +def test_load_or_create_when_not_exist(mock_create, mock_load): + sagemaker_session = Session() + client = sagemaker_session.sagemaker_client + tc_name = "tc_name" + not_found_err = client.exceptions.ResourceNotFound( + error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}}, + operation_name="foo", + ) + mock_load.side_effect = not_found_err + + _, is_existed = _TrialComponent._load_or_create( + trial_component_name=tc_name, sagemaker_session=sagemaker_session + ) + + assert not is_existed + mock_create.assert_called_once_with( + trial_component_name=tc_name, + display_name=None, + tags=None, + sagemaker_session=sagemaker_session, + ) + + +def test_search(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.search.return_value = { + "Results": [ + { + "TrialComponent": { + "TrialComponentName": "tc-1", + "TrialComponentArn": "arn::tc-1", + "DisplayName": "TC1", + "Parents": [ + { + "ExperimentName": "e-1", + "TrialName": "t-1", + }, + { + "ExperimentName": "e-2", + "TrialName": "t-2", + }, + ], + } + }, + { + "TrialComponent": { + "TrialComponentName": "tc-2", + "TrialComponentArn": "arn::tc-2", + "DisplayName": "TC2", + } + }, 
+ ] + } + expected = [ + TrialComponentSearchResult( + trial_component_name="tc-1", + trial_component_arn="arn::tc-1", + display_name="TC1", + parents=[ + Parent(experiment_name="e-1", trial_name="t-1"), + Parent(experiment_name="e-2", trial_name="t-2"), + ], + ), + TrialComponentSearchResult( + trial_component_name="tc-2", trial_component_arn="arn::tc-2", display_name="TC2" + ), + ] + assert expected == list(_TrialComponent.search(sagemaker_session=sagemaker_session)) diff --git a/tests/unit/sagemaker/experiments/test_utils.py b/tests/unit/sagemaker/experiments/test_utils.py new file mode 100644 index 0000000000..a63c96c0fe --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_utils.py @@ -0,0 +1,36 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from src.sagemaker.experiments._utils import resolve_artifact_name, guess_media_type + + +def test_resolve_artifact_name(): + file_names = { + "a": "a", + "a.txt": "a.txt", + "b.": "b.", + ".c": ".c", + "/x/a/a.txt": "a.txt", + "/a/b/c.": "c.", + "./.a": ".a", + "../b.txt": "b.txt", + "~/a.txt": "a.txt", + "c/d.txt": "d.txt", + } + for file_name, artifact_name in file_names.items(): + assert artifact_name == resolve_artifact_name(file_name) + + +def test_guess_media_type(): + assert "text/plain" == guess_media_type("foo.txt") diff --git a/tests/unit/sagemaker/feature_store/test_dataset_builder.py b/tests/unit/sagemaker/feature_store/test_dataset_builder.py new file mode 100644 index 0000000000..0e55b86bd0 --- /dev/null +++ b/tests/unit/sagemaker/feature_store/test_dataset_builder.py @@ -0,0 +1,612 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
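+
+# These tests exercise DatasetBuilder against a mocked SageMaker session and
+# feature group: its fluent configuration methods, the CSV and DataFrame
+# output paths, and the Athena query strings it constructs.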
+from __future__ import absolute_import + +import datetime + +import pandas as pd +import pytest +import os +from mock import Mock, patch + +from sagemaker.feature_store.dataset_builder import ( + DatasetBuilder, + FeatureGroupToBeMerged, + TableType, +) +from sagemaker.feature_store.feature_group import ( + FeatureDefinition, + FeatureGroup, + FeatureTypeEnum, +) + + +@pytest.fixture +def sagemaker_session_mock(): + return Mock() + + +@pytest.fixture +def feature_group_mock(): + return Mock() + + +@pytest.fixture +def read_csv_mock(): + return Mock() + + +@pytest.fixture +def to_csv_file_mock(): + return Mock() + + +@pytest.fixture +def remove_mock(): + return Mock() + + +BASE = FeatureGroupToBeMerged( + ["target-feature", "other-feature"], + ["target-feature", "other-feature"], + ["target-feature", "other-feature"], + "catalog", + "database", + "base-table", + "target-feature", + FeatureDefinition("other-feature", FeatureTypeEnum.STRING), + None, + TableType.FEATURE_GROUP, +) +FEATURE_GROUP = FeatureGroupToBeMerged( + ["feature-1", "feature-2"], + ["feature-1", "feature-2"], + ["feature-1", "feature-2"], + "catalog", + "database", + "table-name", + "feature-1", + FeatureDefinition("feature-2", FeatureTypeEnum.FRACTIONAL), + "target-feature", + TableType.FEATURE_GROUP, +) + + +def test_with_feature_group_throw_runtime_error(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + ) + sagemaker_session_mock.describe_feature_group.return_value = {"OfflineStoreConfig": {}} + with pytest.raises(RuntimeError) as error: + dataset_builder.with_feature_group( + feature_group, "target-feature", ["feature-1", "feature-2"] + ) + assert "No metastore is configured with FeatureGroup MyFeatureGroup." 
in str(error) + + +def test_with_feature_group(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataframe = pd.DataFrame({"feature-1": [420, 380, 390], "feature-2": [50, 40, 45]}) + feature_group.load_feature_definitions(dataframe) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + record_identifier_feature_name="target-feature", + ) + sagemaker_session_mock.describe_feature_group.return_value = { + "OfflineStoreConfig": {"DataCatalogConfig": {"TableName": "table", "Database": "database"}}, + "RecordIdentifierFeatureName": "feature-1", + "EventTimeFeatureName": "feature-2", + "FeatureDefinitions": [ + {"FeatureName": "feature-1", "FeatureType": "String"}, + {"FeatureName": "feature-2", "FeatureType": "String"}, + ], + } + dataset_builder.with_feature_group(feature_group, "target-feature", ["feature-1", "feature-2"]) + assert len(dataset_builder._feature_groups_to_be_merged) == 1 + assert dataset_builder._feature_groups_to_be_merged[0].features == [ + "feature-1", + "feature-2", + ] + assert dataset_builder._feature_groups_to_be_merged[0].included_feature_names == [ + "feature-1", + "feature-2", + ] + assert dataset_builder._feature_groups_to_be_merged[0].database == "database" + assert dataset_builder._feature_groups_to_be_merged[0].table_name == "table" + assert ( + dataset_builder._feature_groups_to_be_merged[0].record_identifier_feature_name + == "feature-1" + ) + assert ( + dataset_builder._feature_groups_to_be_merged[0].event_time_identifier_feature.feature_name + == "feature-2" + ) + assert ( + dataset_builder._feature_groups_to_be_merged[0].event_time_identifier_feature.feature_type + == FeatureTypeEnum.STRING + ) + assert ( + dataset_builder._feature_groups_to_be_merged[0].target_feature_name_in_base + == "target-feature" + ) + + +def test_point_in_time_accurate_join(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder.point_in_time_accurate_join() + assert dataset_builder._point_in_time_accurate_join + + +def test_include_duplicated_records(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder.include_duplicated_records() + assert dataset_builder._include_duplicated_records + + +def test_include_deleted_records(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder.include_deleted_records() + assert dataset_builder._include_deleted_records + + +def test_with_number_of_recent_records_by_record_identifier( + sagemaker_session_mock, feature_group_mock +): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder.with_number_of_recent_records_by_record_identifier(5) + assert dataset_builder._number_of_recent_records == 5 + + +def test_with_number_of_records_from_query_results(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + 
dataset_builder.with_number_of_records_from_query_results(100) + assert dataset_builder._number_of_records == 100 + + +def test_with_event_time_range(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + start = datetime.datetime.now() + end = start + datetime.timedelta(minutes=1) + dataset_builder.with_event_time_range(start, end) + assert dataset_builder._event_time_starting_timestamp == start + assert dataset_builder._event_time_ending_timestamp == end + + +def test_to_csv_file_not_support_base_type(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + with pytest.raises(ValueError) as error: + dataset_builder.to_csv_file() + assert "Base must be either a FeatureGroup or a DataFrame." in str(error) + + +def test_to_csv_file_with_feature_group(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + ) + sagemaker_session_mock.describe_feature_group.return_value = { + "OfflineStoreConfig": {"DataCatalogConfig": {"TableName": "table", "Database": "database"}}, + "RecordIdentifierFeatureName": "feature-1", + "EventTimeFeatureName": "feature-2", + "FeatureDefinitions": [ + {"FeatureName": "feature-1", "FeatureType": "String"}, + {"FeatureName": "feature-2", "FeatureType": "String"}, + ], + } + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"} + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": { + "Status": {"State": "SUCCEEDED"}, + "ResultConfiguration": {"OutputLocation": "s3-file-path"}, + "Query": "query-string", + } + } + file_path, query_string = dataset_builder.to_csv_file() + assert file_path == "s3-file-path" + assert query_string == "query-string" + + +@patch("pandas.DataFrame.to_csv") +@patch("pandas.read_csv") +@patch("os.remove") +def test_to_dataframe_with_dataframe( + remove_mock, read_csv_mock, to_csv_file_mock, sagemaker_session_mock +): + dataframe = pd.DataFrame({"feature-1": [420, 380.0, 390], "feature-2": [50, 40.0, 45]}) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=dataframe, + output_path="s3://file/to/path", + event_time_identifier_feature_name="feature-2", + ) + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"} + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": { + "Status": {"State": "SUCCEEDED"}, + "ResultConfiguration": {"OutputLocation": "s3://s3-file-path"}, + "Query": "query-string", + } + } + to_csv_file_mock.return_value = None + read_csv_mock.return_value = dataframe + os.remove.return_value = None + df, query_string = dataset_builder.to_dataframe() + assert df.equals(dataframe) + assert query_string == "query-string" + + +def test_construct_where_query_string(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + ) + time = datetime.datetime.now().replace(microsecond=0) + start = time + datetime.timedelta(minutes=1) + 
end = start + datetime.timedelta(minutes=1) + dataset_builder._write_time_ending_timestamp = time + dataset_builder._event_time_starting_timestamp = start + dataset_builder._event_time_ending_timestamp = end + query_string = dataset_builder._construct_where_query_string( + "suffix", + FeatureDefinition("event-time", FeatureTypeEnum.STRING), + ["NOT is_deleted"], + ) + assert ( + query_string + == "WHERE NOT is_deleted\n" + + f"AND table_suffix.\"write_time\" <= to_timestamp('{time}', " + + "'yyyy-mm-dd hh24:mi:ss')\n" + + 'AND from_iso8601_timestamp(table_suffix."event-time") >= ' + + f"from_unixtime({start.timestamp()})\n" + + 'AND from_iso8601_timestamp(table_suffix."event-time") <= ' + + f"from_unixtime({end.timestamp()})" + ) + + +def test_construct_query_string_with_duplicated_records(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder._include_duplicated_records = True + + dataset_builder._feature_groups_to_be_merged = [FEATURE_GROUP] + query_string = dataset_builder._construct_query_string(BASE) + assert ( + query_string + == "WITH fg_base AS (WITH deleted_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."target-feature"\n' + + 'ORDER BY origin_base."other-feature" DESC, origin_base."api_invocation_time" DESC, ' + + 'origin_base."write_time" DESC\n' + + ") AS deleted_row_base\n" + + 'FROM "database"."base-table" origin_base\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_base = 1\n" + + ")\n" + + 'SELECT table_base."target-feature", table_base."other-feature"\n' + + "FROM (\n" + + 'SELECT table_base."target-feature", table_base."other-feature", ' + + 'table_base."write_time"\n' + + 'FROM "database"."base-table" table_base\n' + + "LEFT JOIN deleted_base\n" + + 'ON table_base."target-feature" = deleted_base."target-feature"\n' + + 'WHERE deleted_base."target-feature" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_base."target-feature", table_base."other-feature", ' + + 'table_base."write_time"\n' + + "FROM deleted_base\n" + + 'JOIN "database"."base-table" table_base\n' + + 'ON table_base."target-feature" = deleted_base."target-feature"\n' + + "AND (\n" + + 'table_base."other-feature" > deleted_base."other-feature"\n' + + 'OR (table_base."other-feature" = deleted_base."other-feature" AND ' + + 'table_base."api_invocation_time" > deleted_base."api_invocation_time")\n' + + 'OR (table_base."other-feature" = deleted_base."other-feature" AND ' + + 'table_base."api_invocation_time" = deleted_base."api_invocation_time" AND ' + + 'table_base."write_time" > deleted_base."write_time")\n' + + ")\n" + + ") AS table_base\n" + + "),\n" + + "fg_0 AS (WITH deleted_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."feature-1"\n' + + 'ORDER BY origin_0."feature-2" DESC, origin_0."api_invocation_time" DESC, ' + + 'origin_0."write_time" DESC\n' + + ") AS deleted_row_0\n" + + 'FROM "database"."table-name" origin_0\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_0 = 1\n" + + ")\n" + + 'SELECT table_0."feature-1", table_0."feature-2"\n' + + "FROM (\n" + + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n' + + 'FROM "database"."table-name" table_0\n' + + "LEFT JOIN deleted_0\n" + + 'ON table_0."feature-1" = deleted_0."feature-1"\n' + + 'WHERE deleted_0."feature-1" IS NULL\n' + + "UNION ALL\n" 
+ + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n' + + "FROM deleted_0\n" + + 'JOIN "database"."table-name" table_0\n' + + 'ON table_0."feature-1" = deleted_0."feature-1"\n' + + "AND (\n" + + 'table_0."feature-2" > deleted_0."feature-2"\n' + + 'OR (table_0."feature-2" = deleted_0."feature-2" AND table_0."api_invocation_time" > ' + + 'deleted_0."api_invocation_time")\n' + + 'OR (table_0."feature-2" = deleted_0."feature-2" AND table_0."api_invocation_time" = ' + + 'deleted_0."api_invocation_time" AND table_0."write_time" > deleted_0."write_time")\n' + + ")\n" + + ") AS table_0\n" + + ")\n" + + 'SELECT target-feature, other-feature, "feature-1.1", "feature-2.1"\n' + + "FROM (\n" + + 'SELECT fg_base.target-feature, fg_base.other-feature, fg_0."feature-1" as ' + + '"feature-1.1", fg_0."feature-2" as "feature-2.1", row_number() OVER (\n' + + 'PARTITION BY fg_base."target-feature"\n' + + 'ORDER BY fg_base."other-feature" DESC, fg_0."feature-2" DESC\n' + + ") AS row_recent\n" + + "FROM fg_base\n" + + "JOIN fg_0\n" + + 'ON fg_base."target-feature" = fg_0."feature-1"\n' + + ")\n" + ) + + +def test_construct_query_string(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + ) + dataset_builder._point_in_time_accurate_join = True + dataset_builder._event_time_identifier_feature_name = "target-feature" + dataset_builder._feature_groups_to_be_merged = [FEATURE_GROUP] + query_string = dataset_builder._construct_query_string(BASE) + assert ( + query_string + == "WITH fg_base AS (WITH table_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."target-feature", origin_base."other-feature"\n' + + 'ORDER BY origin_base."api_invocation_time" DESC, origin_base."write_time" DESC\n' + + ") AS dedup_row_base\n" + + 'FROM "database"."base-table" origin_base\n' + + ")\n" + + "WHERE dedup_row_base = 1\n" + + "),\n" + + "deleted_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."target-feature"\n' + + 'ORDER BY origin_base."other-feature" DESC, origin_base."api_invocation_time" ' + + 'DESC, origin_base."write_time" DESC\n' + + ") AS deleted_row_base\n" + + 'FROM "database"."base-table" origin_base\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_base = 1\n" + + ")\n" + + 'SELECT table_base."target-feature", table_base."other-feature"\n' + + "FROM (\n" + + 'SELECT table_base."target-feature", table_base."other-feature", ' + + 'table_base."write_time"\n' + + "FROM table_base\n" + + "LEFT JOIN deleted_base\n" + + 'ON table_base."target-feature" = deleted_base."target-feature"\n' + + 'WHERE deleted_base."target-feature" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_base."target-feature", table_base."other-feature", ' + + 'table_base."write_time"\n' + + "FROM deleted_base\n" + + "JOIN table_base\n" + + 'ON table_base."target-feature" = deleted_base."target-feature"\n' + + "AND (\n" + + 'table_base."other-feature" > deleted_base."other-feature"\n' + + 'OR (table_base."other-feature" = deleted_base."other-feature" AND ' + + 'table_base."api_invocation_time" > deleted_base."api_invocation_time")\n' + + 'OR (table_base."other-feature" = deleted_base."other-feature" AND ' + + 'table_base."api_invocation_time" = deleted_base."api_invocation_time" AND ' + + 
'table_base."write_time" > deleted_base."write_time")\n' + + ")\n" + + ") AS table_base\n" + + "),\n" + + "fg_0 AS (WITH table_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."feature-1", origin_0."feature-2"\n' + + 'ORDER BY origin_0."api_invocation_time" DESC, origin_0."write_time" DESC\n' + + ") AS dedup_row_0\n" + + 'FROM "database"."table-name" origin_0\n' + + ")\n" + + "WHERE dedup_row_0 = 1\n" + + "),\n" + + "deleted_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."feature-1"\n' + + 'ORDER BY origin_0."feature-2" DESC, origin_0."api_invocation_time" DESC, ' + + 'origin_0."write_time" DESC\n' + + ") AS deleted_row_0\n" + + 'FROM "database"."table-name" origin_0\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_0 = 1\n" + + ")\n" + + 'SELECT table_0."feature-1", table_0."feature-2"\n' + + "FROM (\n" + + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n' + + "FROM table_0\n" + + "LEFT JOIN deleted_0\n" + + 'ON table_0."feature-1" = deleted_0."feature-1"\n' + + 'WHERE deleted_0."feature-1" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n' + + "FROM deleted_0\n" + + "JOIN table_0\n" + + 'ON table_0."feature-1" = deleted_0."feature-1"\n' + + "AND (\n" + + 'table_0."feature-2" > deleted_0."feature-2"\n' + + 'OR (table_0."feature-2" = deleted_0."feature-2" AND ' + + 'table_0."api_invocation_time" > deleted_0."api_invocation_time")\n' + + 'OR (table_0."feature-2" = deleted_0."feature-2" AND ' + + 'table_0."api_invocation_time" = deleted_0."api_invocation_time" AND ' + + 'table_0."write_time" > deleted_0."write_time")\n' + + ")\n" + + ") AS table_0\n" + + ")\n" + + 'SELECT target-feature, other-feature, "feature-1.1", "feature-2.1"\n' + + "FROM (\n" + + 'SELECT fg_base.target-feature, fg_base.other-feature, fg_0."feature-1" as ' + + '"feature-1.1", fg_0."feature-2" as "feature-2.1", row_number() OVER (\n' + + 'PARTITION BY fg_base."target-feature"\n' + + 'ORDER BY fg_base."other-feature" DESC, fg_0."feature-2" DESC\n' + + ") AS row_recent\n" + + "FROM fg_base\n" + + "JOIN fg_0\n" + + 'ON fg_base."target-feature" = fg_0."feature-1"\n' + + 'AND from_unixtime(fg_base."target-feature") >= from_unixtime(fg_0."feature-2")\n' + + ")\n" + ) + + +def test_create_temp_table(sagemaker_session_mock): + dataframe = pd.DataFrame({"feature-1": [420, 380, 390], "feature-2": [50, 40, 45]}) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=dataframe, + output_path="file/to/path", + ) + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"} + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "SUCCEEDED"}} + } + dataset_builder._create_temp_table("table-name", "s3-folder") + assert sagemaker_session_mock.start_query_execution.call_count == 1 + sagemaker_session_mock.start_query_execution.assert_called_once_with( + catalog="AwsDataCatalog", + database="sagemaker_featurestore", + query_string="CREATE EXTERNAL TABLE table-name (feature-1 INT, feature-2 INT) " + + "ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' " + + 'WITH SERDEPROPERTIES ("separatorChar" = ",", "quoteChar" = "`", "escapeChar" = "\\\\") ' + + "LOCATION 's3-folder';", + output_location="file/to/path", + kms_key=None, + ) + + +@pytest.mark.parametrize( + "column, expected", + [ + ("feature-1", "feature-1 
STRING"), + ("feature-2", "feature-2 INT"), + ("feature-3", "feature-3 DOUBLE"), + ("feature-4", "feature-4 BOOLEAN"), + ("feature-5", "feature-5 TIMESTAMP"), + ], +) +def test_construct_athena_table_column_string(column, expected, sagemaker_session_mock): + dataframe = pd.DataFrame( + { + "feature-1": ["420"], + "feature-2": [50], + "feature-3": [5.0], + "feature-4": [True], + "feature-5": [pd.Timestamp(1513393355)], + } + ) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=dataframe, + output_path="file/to/path", + ) + query_string = dataset_builder._construct_athena_table_column_string(column) + assert query_string == expected + + +def test_construct_athena_table_column_string_not_support_column_type( + sagemaker_session_mock, +): + dataframe = pd.DataFrame({"feature": pd.Series([1] * 3, dtype="int8")}) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=dataframe, + output_path="file/to/path", + ) + with pytest.raises(RuntimeError) as error: + dataset_builder._construct_athena_table_column_string("feature") + assert "The dataframe type int8 is not supported yet." in str(error) + + +def test_run_query_throw_runtime_error(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"} + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "FAILED"}} + } + with pytest.raises(RuntimeError) as error: + dataset_builder._run_query("query-string", "catalog", "database") + assert "Failed to execute query query-id." in str(error) diff --git a/tests/unit/sagemaker/feature_store/test_feature_group.py b/tests/unit/sagemaker/feature_store/test_feature_group.py new file mode 100644 index 0000000000..dce38fe426 --- /dev/null +++ b/tests/unit/sagemaker/feature_store/test_feature_group.py @@ -0,0 +1,580 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + + +import pandas as pd +import pytest +from mock import Mock, patch, MagicMock +from botocore.exceptions import ProfileNotFound + +from sagemaker.feature_store.feature_definition import ( + FractionalFeatureDefinition, + IntegralFeatureDefinition, + StringFeatureDefinition, + FeatureTypeEnum, +) +from sagemaker.feature_store.feature_group import ( + FeatureGroup, + IngestionManagerPandas, + AthenaQuery, + IngestionError, +) +from sagemaker.feature_store.inputs import FeatureParameter + + +class PicklableMock(Mock): + def __reduce__(self): + return (Mock, ()) + + +@pytest.fixture +def role_arn(): + return "arn:role" + + +@pytest.fixture +def s3_uri(): + return "s3://some/uri" + + +@pytest.fixture +def sagemaker_session_mock(): + return Mock() + + +@pytest.fixture +def fs_runtime_client_config_mock(): + return PicklableMock() + + +@pytest.fixture +def feature_group_dummy_definitions(): + return [ + FractionalFeatureDefinition(feature_name="feature1"), + IntegralFeatureDefinition(feature_name="feature2"), + StringFeatureDefinition(feature_name="feature3"), + ] + + +@pytest.fixture +def create_table_ddl(): + return ( + "CREATE EXTERNAL TABLE IF NOT EXISTS {database}.{table_name} (\n" + " feature1 FLOAT\n" + " feature2 INT\n" + " feature3 STRING\n" + " write_time TIMESTAMP\n" + " event_time TIMESTAMP\n" + " is_deleted BOOLEAN\n" + ")\n" + "ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'\n" + " STORED AS\n" + " INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'\n" + " OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'\n" + "LOCATION 's3://resolved_output_s3_uri'" + ) + + +def test_feature_store_create( + sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri +): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + feature_group.create( + s3_uri=s3_uri, + record_identifier_name="feature1", + event_time_feature_name="feature2", + role_arn=role_arn, + enable_online_store=True, + ) + sagemaker_session_mock.create_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + record_identifier_name="feature1", + event_time_feature_name="feature2", + feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], + role_arn=role_arn, + description=None, + tags=None, + online_store_config={"EnableOnlineStore": True}, + offline_store_config={ + "DisableGlueTableCreation": False, + "S3StorageConfig": {"S3Uri": s3_uri}, + }, + ) + + +def test_feature_store_create_online_only( + sagemaker_session_mock, role_arn, feature_group_dummy_definitions +): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + feature_group.create( + s3_uri=False, + record_identifier_name="feature1", + event_time_feature_name="feature2", + role_arn=role_arn, + enable_online_store=True, + ) + sagemaker_session_mock.create_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + record_identifier_name="feature1", + event_time_feature_name="feature2", + feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], + role_arn=role_arn, + description=None, + tags=None, + online_store_config={"EnableOnlineStore": True}, + ) + + +def test_feature_store_delete(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", 
sagemaker_session=sagemaker_session_mock) + feature_group.delete() + sagemaker_session_mock.delete_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup" + ) + + +def test_feature_store_describe(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.describe() + sagemaker_session_mock.describe_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", next_token=None + ) + + +def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_definitions): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.update(feature_group_dummy_definitions) + sagemaker_session_mock.update_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + feature_additions=[fd.to_dict() for fd in feature_group_dummy_definitions], + ) + + +def test_feature_metadata_update(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + + parameter_additions = [FeatureParameter(key="key1", value="value1")] + parameter_removals = ["key2"] + + feature_group.update_feature_metadata( + feature_name="Feature1", + description="TestDescription", + parameter_additions=parameter_additions, + parameter_removals=parameter_removals, + ) + sagemaker_session_mock.update_feature_metadata.assert_called_with( + feature_group_name="MyFeatureGroup", + feature_name="Feature1", + description="TestDescription", + parameter_additions=[pa.to_dict() for pa in parameter_additions], + parameter_removals=parameter_removals, + ) + feature_group.update_feature_metadata(feature_name="Feature1", description="TestDescription") + sagemaker_session_mock.update_feature_metadata.assert_called_with( + feature_group_name="MyFeatureGroup", + feature_name="Feature1", + description="TestDescription", + parameter_additions=[], + parameter_removals=[], + ) + + +def test_feature_metadata_describe(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.describe_feature_metadata(feature_name="Feature1") + sagemaker_session_mock.describe_feature_metadata.assert_called_with( + feature_group_name="MyFeatureGroup", feature_name="Feature1" + ) + + +def test_get_record(sagemaker_session_mock): + feature_group_name = "MyFeatureGroup" + feature_names = ["MyFeature1", "MyFeature2"] + record_identifier_value_as_string = "1.0" + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=sagemaker_session_mock) + feature_group.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + feature_names=feature_names, + ) + sagemaker_session_mock.get_record.assert_called_with( + feature_group_name=feature_group_name, + record_identifier_value_as_string=record_identifier_value_as_string, + feature_names=feature_names, + ) + + +def test_put_record(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.put_record(record=[]) + sagemaker_session_mock.put_record.assert_called_with( + feature_group_name="MyFeatureGroup", record=[] + ) + + +def test_delete_record(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + record_identifier_value_as_string = "1.0" + event_time = "2022-09-14" + feature_group.delete_record( + 
record_identifier_value_as_string=record_identifier_value_as_string,
+        event_time=event_time,
+    )
+    sagemaker_session_mock.delete_record.assert_called_with(
+        feature_group_name="MyFeatureGroup",
+        record_identifier_value_as_string=record_identifier_value_as_string,
+        event_time=event_time,
+    )
+
+
+def test_load_feature_definition(sagemaker_session_mock):
+    feature_group = FeatureGroup(name="SomeGroup", sagemaker_session=sagemaker_session_mock)
+    df = pd.DataFrame(
+        {
+            "float": pd.Series([2.0], dtype="float64"),
+            "int": pd.Series([2], dtype="int64"),
+            "string": pd.Series(["f1"], dtype="string"),
+        }
+    )
+    feature_definitions = feature_group.load_feature_definitions(data_frame=df)
+    names = [fd.feature_name for fd in feature_definitions]
+    types = [fd.feature_type for fd in feature_definitions]
+    assert names == ["float", "int", "string"]
+    assert types == [
+        FeatureTypeEnum.FRACTIONAL,
+        FeatureTypeEnum.INTEGRAL,
+        FeatureTypeEnum.STRING,
+    ]
+
+
+def test_load_feature_definition_unsupported_types(sagemaker_session_mock):
+    feature_group = FeatureGroup(name="FailedGroup", sagemaker_session=sagemaker_session_mock)
+    df = pd.DataFrame(
+        {
+            "float": pd.Series([2.0], dtype="float64"),
+            "int": pd.Series([2], dtype="int64"),
+            "bool": pd.Series([True], dtype="bool"),
+        }
+    )
+    with pytest.raises(ValueError) as error:
+        feature_group.load_feature_definitions(data_frame=df)
+    assert "Failed to infer Feature type based on dtype bool for column bool." in str(error)
+
+
+# The session fixture must be requested as a test parameter; referencing the
+# fixture function at module scope would hand FeatureGroup the function object
+# instead of a mock session.
+def test_ingest_zero_processes(sagemaker_session_mock):
+    feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
+    df = Mock()
+    with pytest.raises(RuntimeError) as error:
+        feature_group.ingest(data_frame=df, max_workers=1, max_processes=0)
+
+    assert "max_processes must be greater than 0." in str(error)
+
+
+def test_ingest_zero_workers(sagemaker_session_mock):
+    feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
+    df = Mock()
+    with pytest.raises(RuntimeError) as error:
+        feature_group.ingest(data_frame=df, max_workers=0, max_processes=1)
+
+    assert "max_workers must be greater than 0."
in str(error) + + +@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas") +def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock): + sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = ( + fs_runtime_client_config_mock + ) + + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) + + mock_ingestion_manager_instance = Mock() + ingestion_manager_init.return_value = mock_ingestion_manager_instance + feature_group.ingest(data_frame=df, max_workers=10) + + ingestion_manager_init.assert_called_once_with( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=10, + max_processes=1, + profile_name=None, + ) + mock_ingestion_manager_instance.run.assert_called_once_with( + data_frame=df, wait=True, timeout=None + ) + + +@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas") +def test_ingest_with_profile_name( + ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock +): + sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = ( + fs_runtime_client_config_mock + ) + + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) + + mock_ingestion_manager_instance = Mock() + ingestion_manager_init.return_value = mock_ingestion_manager_instance + feature_group.ingest(data_frame=df, max_workers=10, profile_name="profile_name") + + ingestion_manager_init.assert_called_once_with( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=10, + max_processes=1, + profile_name="profile_name", + ) + mock_ingestion_manager_instance.run.assert_called_once_with( + data_frame=df, wait=True, timeout=None + ) + + +def test_as_hive_ddl_with_default_values( + create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock +): + sagemaker_session_mock.describe_feature_group.return_value = { + "OfflineStoreConfig": { + "S3StorageConfig": { + "S3Uri": "s3://some-bucket", + "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri", + } + } + } + sagemaker_session_mock.account_id.return_value = "1234" + sagemaker_session_mock.boto_session.region_name = "us-west-2" + + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + assert ( + create_table_ddl.format( + database="sagemaker_featurestore", + table_name="MyGroup", + account="1234", + region="us-west-2", + feature_group_name="MyGroup", + ) + == feature_group.as_hive_ddl() + ) + + +def test_as_hive_ddl(create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock): + sagemaker_session_mock.describe_feature_group.return_value = { + "OfflineStoreConfig": { + "S3StorageConfig": { + "S3Uri": "s3://some-bucket", + "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri", + } + } + } + sagemaker_session_mock.account_id.return_value = "1234" + sagemaker_session_mock.boto_session.region_name = "us-west-2" + + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + assert create_table_ddl.format( + database="MyDatabase", + table_name="MyTable", + account="1234", + 
region="us-west-2",
+        feature_group_name="MyGroup",
+    ) == feature_group.as_hive_ddl(database="MyDatabase", table_name="MyTable")
+
+
+@patch(
+    "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_process",
+    MagicMock(),
+)
+def test_ingestion_manager_run_success(fs_runtime_client_config_mock):
+    df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
+    manager = IngestionManagerPandas(
+        feature_group_name="MyGroup",
+        sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
+        max_workers=10,
+    )
+    manager.run(df)
+
+    manager._run_multi_process.assert_called_once_with(data_frame=df, wait=True, timeout=None)
+
+
+@patch(
+    "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_threaded",
+    PicklableMock(return_value=[]),
+)
+def test_ingestion_manager_run_multi_process_with_multi_thread_success(
+    fs_runtime_client_config_mock,
+):
+    df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
+    manager = IngestionManagerPandas(
+        feature_group_name="MyGroup",
+        sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
+        max_workers=2,
+        max_processes=2,
+    )
+    manager.run(df)
+
+
+@patch(
+    "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch",
+    MagicMock(return_value=[1]),
+)
+def test_ingestion_manager_run_failure(fs_runtime_client_config_mock):
+    df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
+    manager = IngestionManagerPandas(
+        feature_group_name="MyGroup",
+        sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
+        max_workers=1,
+    )
+
+    with pytest.raises(IngestionError) as error:
+        manager.run(df)
+
+    assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error)
+    assert error.value.failed_rows == [1]
+    assert manager.failed_rows == [1]
+
+
+@patch(
+    "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch",
+    MagicMock(side_effect=ProfileNotFound(profile="non_exist")),
+)
+def test_ingestion_manager_with_profile_name_run_failure(fs_runtime_client_config_mock):
+    df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
+    manager = IngestionManagerPandas(
+        feature_group_name="MyGroup",
+        sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
+        max_workers=1,
+        profile_name="non_exist",
+    )
+
+    # pytest.raises fails the test if nothing is raised; a bare try/except
+    # would silently pass when ingestion unexpectedly succeeds.
+    with pytest.raises(ProfileNotFound) as error:
+        manager.run(df)
+    assert "The config profile (non_exist) could not be found" in str(error)
+
+
+@patch(
+    "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch",
+    PicklableMock(return_value=[1]),
+)
+def test_ingestion_manager_run_multi_process_failure():
+    df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
+    manager = IngestionManagerPandas(
+        feature_group_name="MyGroup",
+        sagemaker_fs_runtime_client_config=None,
+        max_workers=2,
+        max_processes=2,
+    )
+
+    with pytest.raises(IngestionError) as error:
+        manager.run(df)
+
+    assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error)
+    assert error.value.failed_rows == [1, 1, 1, 1]
+    assert manager.failed_rows == [1, 1, 1, 1]
+
+
+@pytest.fixture
+def query(sagemaker_session_mock):
+    return AthenaQuery(
+        catalog="catalog",
+        database="database",
+        table_name="table_name",
+        sagemaker_session=sagemaker_session_mock,
+    )
+
+
+def test_athena_query_run(sagemaker_session_mock, query):
+    sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query_id"}
+    query.run(
+        query_string="query", output_location="s3://some-bucket/some-path", workgroup="workgroup"
+    )
+    sagemaker_session_mock.start_query_execution.assert_called_with(
+        
catalog="catalog", + database="database", + query_string="query", + output_location="s3://some-bucket/some-path", + kms_key=None, + workgroup="workgroup", + ) + assert "some-bucket" == query._result_bucket + assert "some-path" == query._result_file_prefix + assert "query_id" == query._current_query_execution_id + + +def test_athena_query_wait(sagemaker_session_mock, query): + query._current_query_execution_id = "query_id" + query.wait() + sagemaker_session_mock.wait_for_athena_query.assert_called_with(query_execution_id="query_id") + + +def test_athena_query_get_query_execution(sagemaker_session_mock, query): + query._current_query_execution_id = "query_id" + query.get_query_execution() + sagemaker_session_mock.get_query_execution.assert_called_with(query_execution_id="query_id") + + +@patch("tempfile.gettempdir", Mock(return_value="tmp")) +@patch("pandas.read_csv") +def test_athena_query_as_dataframe(read_csv, sagemaker_session_mock, query): + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "SUCCEEDED"}} + } + query._current_query_execution_id = "query_id" + query._result_bucket = "bucket" + query._result_file_prefix = "prefix" + query.as_dataframe() + sagemaker_session_mock.download_athena_query_result.assert_called_with( + bucket="bucket", + prefix="prefix", + query_execution_id="query_id", + filename="tmp/query_id.csv", + ) + read_csv.assert_called_with("tmp/query_id.csv", delimiter=",") + + +@patch("tempfile.gettempdir", Mock(return_value="tmp")) +def test_athena_query_as_dataframe_query_failed(sagemaker_session_mock, query): + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "FAILED"}} + } + query._current_query_execution_id = "query_id" + with pytest.raises(RuntimeError) as error: + query.as_dataframe() + assert "Failed to execute query query_id" in str(error) + + +@patch("tempfile.gettempdir", Mock(return_value="tmp")) +def test_athena_query_as_dataframe_query_queued(sagemaker_session_mock, query): + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "QUEUED"}} + } + query._current_query_execution_id = "query_id" + with pytest.raises(RuntimeError) as error: + query.as_dataframe() + assert "Current query query_id is still being executed" in str(error) + + +@patch("tempfile.gettempdir", Mock(return_value="tmp")) +def test_athena_query_as_dataframe_query_running(sagemaker_session_mock, query): + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "RUNNING"}} + } + query._current_query_execution_id = "query_id" + with pytest.raises(RuntimeError) as error: + query.as_dataframe() + assert "Current query query_id is still being executed" in str(error) diff --git a/tests/unit/sagemaker/feature_store/test_feature_store.py b/tests/unit/sagemaker/feature_store/test_feature_store.py index 92ba35573c..073daca9ea 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_store.py +++ b/tests/unit/sagemaker/feature_store/test_feature_store.py @@ -10,46 +10,17 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -# language governing permissions and limitations under the License. 
from __future__ import absolute_import +import datetime import pandas as pd import pytest -from mock import Mock, patch, MagicMock -from botocore.exceptions import ProfileNotFound - -from sagemaker.feature_store.feature_definition import ( - FractionalFeatureDefinition, - IntegralFeatureDefinition, - StringFeatureDefinition, - FeatureTypeEnum, -) -from sagemaker.feature_store.feature_group import ( - FeatureGroup, - IngestionManagerPandas, - AthenaQuery, - IngestionError, -) -from sagemaker.feature_store.inputs import ( - FeatureParameter, - TableFormatEnum, -) - +from mock import Mock -class PicklableMock(Mock): - def __reduce__(self): - return (Mock, ()) +from sagemaker.feature_store.feature_store import FeatureStore - -@pytest.fixture -def role_arn(): - return "arn:role" - - -@pytest.fixture -def s3_uri(): - return "s3://some/uri" +DATAFRAME = pd.DataFrame({"feature_1": [420, 380, 390], "feature_2": [50, 40, 45]}) @pytest.fixture @@ -58,558 +29,108 @@ def sagemaker_session_mock(): @pytest.fixture -def fs_runtime_client_config_mock(): - return PicklableMock() - - -@pytest.fixture -def feature_group_dummy_definitions(): - return [ - FractionalFeatureDefinition(feature_name="feature1"), - IntegralFeatureDefinition(feature_name="feature2"), - StringFeatureDefinition(feature_name="feature3"), - ] - - -@pytest.fixture -def create_table_ddl(): - return ( - "CREATE EXTERNAL TABLE IF NOT EXISTS {database}.{table_name} (\n" - " feature1 FLOAT\n" - " feature2 INT\n" - " feature3 STRING\n" - " write_time TIMESTAMP\n" - " event_time TIMESTAMP\n" - " is_deleted BOOLEAN\n" - ")\n" - "ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'\n" - " STORED AS\n" - " INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'\n" - " OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'\n" - "LOCATION 's3://resolved_output_s3_uri'" - ) - - -def test_feature_store_create( - sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri -): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - feature_group.create( - s3_uri=s3_uri, - record_identifier_name="feature1", - event_time_feature_name="feature2", - role_arn=role_arn, - enable_online_store=True, - ) - sagemaker_session_mock.create_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - record_identifier_name="feature1", - event_time_feature_name="feature2", - feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn=role_arn, - description=None, - tags=None, - online_store_config={"EnableOnlineStore": True}, - offline_store_config={ - "DisableGlueTableCreation": False, - "S3StorageConfig": {"S3Uri": s3_uri}, - }, - ) - - -def test_feature_store_create_iceberg_table_format( - sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri -): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - feature_group.create( - s3_uri=s3_uri, - record_identifier_name="feature1", - event_time_feature_name="feature2", - role_arn=role_arn, - enable_online_store=True, - disable_glue_table_creation=False, - table_format=TableFormatEnum.ICEBERG, - ) - sagemaker_session_mock.create_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - record_identifier_name="feature1", - event_time_feature_name="feature2", - 
feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn=role_arn, - description=None, - tags=None, - online_store_config={"EnableOnlineStore": True}, - offline_store_config={ - "DisableGlueTableCreation": False, - "TableFormat": "Iceberg", - "S3StorageConfig": {"S3Uri": s3_uri}, - }, - ) - - -def test_feature_store_create_glue_table_format( - sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri -): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - feature_group.create( - s3_uri=s3_uri, - record_identifier_name="feature1", - event_time_feature_name="feature2", - role_arn=role_arn, - enable_online_store=True, - disable_glue_table_creation=False, - table_format=TableFormatEnum.GLUE, - ) - sagemaker_session_mock.create_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - record_identifier_name="feature1", - event_time_feature_name="feature2", - feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn=role_arn, - description=None, - tags=None, - online_store_config={"EnableOnlineStore": True}, - offline_store_config={ - "DisableGlueTableCreation": False, - "TableFormat": "Glue", - "S3StorageConfig": {"S3Uri": s3_uri}, - }, - ) - - -def test_feature_store_create_online_only( - sagemaker_session_mock, role_arn, feature_group_dummy_definitions -): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - feature_group.create( - s3_uri=False, - record_identifier_name="feature1", - event_time_feature_name="feature2", - role_arn=role_arn, - enable_online_store=True, - ) - sagemaker_session_mock.create_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - record_identifier_name="feature1", - event_time_feature_name="feature2", - feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn=role_arn, - description=None, - tags=None, - online_store_config={"EnableOnlineStore": True}, - ) - - -def test_feature_store_delete(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.delete() - sagemaker_session_mock.delete_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup" - ) - - -def test_feature_store_describe(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.describe() - sagemaker_session_mock.describe_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", next_token=None - ) - - -def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_definitions): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.update(feature_group_dummy_definitions) - sagemaker_session_mock.update_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - feature_additions=[fd.to_dict() for fd in feature_group_dummy_definitions], - ) - - -def test_feature_metadata_update(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - - parameter_additions = [FeatureParameter(key="key1", value="value1")] - parameter_removals = ["key2"] - - feature_group.update_feature_metadata( - 
feature_name="Feature1", - description="TestDescription", - parameter_additions=parameter_additions, - parameter_removals=parameter_removals, - ) - sagemaker_session_mock.update_feature_metadata.assert_called_with( - feature_group_name="MyFeatureGroup", - feature_name="Feature1", - description="TestDescription", - parameter_additions=[pa.to_dict() for pa in parameter_additions], - parameter_removals=parameter_removals, - ) - feature_group.update_feature_metadata(feature_name="Feature1", description="TestDescription") - sagemaker_session_mock.update_feature_metadata.assert_called_with( - feature_group_name="MyFeatureGroup", - feature_name="Feature1", - description="TestDescription", - parameter_additions=[], - parameter_removals=[], - ) - - -def test_feature_metadata_describe(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.describe_feature_metadata(feature_name="Feature1") - sagemaker_session_mock.describe_feature_metadata.assert_called_with( - feature_group_name="MyFeatureGroup", feature_name="Feature1" - ) - - -def test_put_record(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.put_record(record=[]) - sagemaker_session_mock.put_record.assert_called_with( - feature_group_name="MyFeatureGroup", record=[] - ) - - -def test_load_feature_definition(sagemaker_session_mock): - feature_group = FeatureGroup(name="SomeGroup", sagemaker_session=sagemaker_session_mock) - df = pd.DataFrame( - { - "float": pd.Series([2.0], dtype="float64"), - "int": pd.Series([2], dtype="int64"), - "string": pd.Series(["f1"], dtype="string"), - } - ) - feature_definitions = feature_group.load_feature_definitions(data_frame=df) - names = [fd.feature_name for fd in feature_definitions] - types = [fd.feature_type for fd in feature_definitions] - assert names == ["float", "int", "string"] - assert types == [ - FeatureTypeEnum.FRACTIONAL, - FeatureTypeEnum.INTEGRAL, - FeatureTypeEnum.STRING, - ] +def feature_group_mock(): + return Mock() -def test_load_feature_definition_unsupported_types(sagemaker_session_mock): - feature_group = FeatureGroup(name="FailedGroup", sagemaker_session=sagemaker_session_mock) - df = pd.DataFrame( - { - "float": pd.Series([2.0], dtype="float64"), - "int": pd.Series([2], dtype="int64"), - "object": pd.Series(["f1"], dtype="object"), - } - ) +def test_minimal_create_dataset(sagemaker_session_mock, feature_group_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) + dataset_builder = feature_store.create_dataset( + base=feature_group_mock, + output_path="file/to/path", + ) + assert dataset_builder._sagemaker_session == sagemaker_session_mock + assert dataset_builder._base == feature_group_mock + assert dataset_builder._output_path == "file/to/path" + + +def test_complete_create_dataset(sagemaker_session_mock, feature_group_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) + dataset_builder = feature_store.create_dataset( + base=feature_group_mock, + included_feature_names=["feature_1", "feature_2"], + output_path="file/to/path", + kms_key_id="kms-key-id", + ) + assert dataset_builder._sagemaker_session == sagemaker_session_mock + assert dataset_builder._base == feature_group_mock + assert dataset_builder._included_feature_names == ["feature_1", "feature_2"] + assert dataset_builder._output_path == "file/to/path" + assert dataset_builder._kms_key_id == "kms-key-id" + + 
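+# A DataFrame base carries no feature definitions, so the builder cannot infer
+# its identifiers: record_identifier_feature_name and
+# event_time_identifier_feature_name must be supplied explicitly. The next two
+# tests cover the valid call and the ValueError raised when they are omitted.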
+def test_create_dataset_with_dataframe(sagemaker_session_mock):
+    feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock)
+    dataset_builder = feature_store.create_dataset(
+        base=DATAFRAME,
+        record_identifier_feature_name="feature_1",
+        event_time_identifier_feature_name="feature_2",
+        included_feature_names=["feature_1", "feature_2"],
+        output_path="file/to/path",
+        kms_key_id="kms-key-id",
+    )
+    assert dataset_builder._sagemaker_session == sagemaker_session_mock
+    assert dataset_builder._base.equals(DATAFRAME)
+    assert dataset_builder._record_identifier_feature_name == "feature_1"
+    assert dataset_builder._event_time_identifier_feature_name == "feature_2"
+    assert dataset_builder._included_feature_names == ["feature_1", "feature_2"]
+    assert dataset_builder._output_path == "file/to/path"
+    assert dataset_builder._kms_key_id == "kms-key-id"
+
+
+def test_create_dataset_with_dataframe_value_error(sagemaker_session_mock):
+    feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock)
     with pytest.raises(ValueError) as error:
-        feature_group.load_feature_definitions(data_frame=df)
-    assert "Failed to infer Feature type based on dtype object for column object." in str(error)
-
-
-def test_ingest_zero_processes():
-    feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
-    df = Mock()
-    with pytest.raises(RuntimeError) as error:
-        feature_group.ingest(data_frame=df, max_workers=1, max_processes=0)
-
-    assert "max_processes must be greater than 0." in str(error)
-
-
-def test_ingest_zero_workers():
-    feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
-    df = Mock()
-    with pytest.raises(RuntimeError) as error:
-        feature_group.ingest(data_frame=df, max_workers=0, max_processes=1)
-
-    assert "max_workers must be greater than 0." in str(error)
-
-
-@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas")
-def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock):
-    sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = (
-        fs_runtime_client_config_mock
-    )
-
-    feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
-    df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300)))
-
-    mock_ingestion_manager_instance = Mock()
-    ingestion_manager_init.return_value = mock_ingestion_manager_instance
-    feature_group.ingest(data_frame=df, max_workers=10)
-
-    ingestion_manager_init.assert_called_once_with(
-        feature_group_name="MyGroup",
-        sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
-        max_workers=10,
-        max_processes=1,
-        profile_name=None,
-    )
-    mock_ingestion_manager_instance.run.assert_called_once_with(
-        data_frame=df, wait=True, timeout=None
-    )
-
-
-@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas")
-def test_ingest_with_profile_name(
-    ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock
-):
-    sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = (
-        fs_runtime_client_config_mock
-    )
-
-    feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
-    df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300)))
-
-    mock_ingestion_manager_instance = Mock()
-    ingestion_manager_init.return_value = mock_ingestion_manager_instance
-    feature_group.ingest(data_frame=df, max_workers=10, profile_name="profile_name")
-
-    ingestion_manager_init.assert_called_once_with(
-        feature_group_name="MyGroup",
-        sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
-        max_workers=10,
-        max_processes=1,
-        profile_name="profile_name",
-    )
-    mock_ingestion_manager_instance.run.assert_called_once_with(
-        data_frame=df, wait=True, timeout=None
-    )
-
-
-def test_as_hive_ddl_with_default_values(
-    create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock
-):
-    sagemaker_session_mock.describe_feature_group.return_value = {
-        "OfflineStoreConfig": {
-            "S3StorageConfig": {
-                "S3Uri": "s3://some-bucket",
-                "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri",
-            }
-        }
-    }
-    sagemaker_session_mock.account_id.return_value = "1234"
-    sagemaker_session_mock.boto_session.region_name = "us-west-2"
-
-    feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
-    feature_group.feature_definitions = feature_group_dummy_definitions
-    assert (
-        create_table_ddl.format(
-            database="sagemaker_featurestore",
-            table_name="MyGroup",
-            account="1234",
-            region="us-west-2",
-            feature_group_name="MyGroup",
+        feature_store.create_dataset(
+            base=DATAFRAME,
+            included_feature_names=["feature_1", "feature_2"],
+            output_path="file/to/path",
+            kms_key_id="kms-key-id",
         )
-        == feature_group.as_hive_ddl()
-    )
-
-
-def test_as_hive_ddl(create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock):
-    sagemaker_session_mock.describe_feature_group.return_value = {
-        "OfflineStoreConfig": {
-            "S3StorageConfig": {
-                "S3Uri": "s3://some-bucket",
-                "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri",
-            }
-        }
-    }
-    sagemaker_session_mock.account_id.return_value = "1234"
-    sagemaker_session_mock.boto_session.region_name = "us-west-2"
-
-    feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
-    feature_group.feature_definitions = feature_group_dummy_definitions
-    assert create_table_ddl.format(
-        database="MyDatabase",
-        table_name="MyTable",
-        account="1234",
-        region="us-west-2",
-        feature_group_name="MyGroup",
-    ) == feature_group.as_hive_ddl(database="MyDatabase", table_name="MyTable")
-
-
-@patch(
-    "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_process",
-    MagicMock(),
-)
-def test_ingestion_manager_run_success():
-    df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
-    manager = IngestionManagerPandas(
-        feature_group_name="MyGroup",
-        sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
-        max_workers=10,
-    )
-    manager.run(df)
-
-    manager._run_multi_process.assert_called_once_with(data_frame=df, wait=True, timeout=None)
-
-
-@patch(
-    "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_threaded",
-    PicklableMock(return_value=[]),
-)
-def test_ingestion_manager_run_multi_process_with_multi_thread_success(
-    fs_runtime_client_config_mock,
-):
-    df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
-    manager = IngestionManagerPandas(
-        feature_group_name="MyGroup",
-        sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
-        max_workers=2,
-        max_processes=2,
-    )
-    manager.run(df)
-
-
-@patch(
-    "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch",
-    MagicMock(return_value=[1]),
-)
-def test_ingestion_manager_run_failure():
-    df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
-    manager = IngestionManagerPandas(
-        feature_group_name="MyGroup",
-        sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
-        max_workers=1,
-    )
-
-    with pytest.raises(IngestionError) as error:
-        manager.run(df)
-
-    assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error)
-    assert error.value.failed_rows == [1]
-    assert manager.failed_rows == [1]
-
-
-@patch(
-    "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch",
-    MagicMock(side_effect=ProfileNotFound(profile="non_exist")),
-)
-def test_ingestion_manager_with_profile_name_run_failure():
-    df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
-    manager = IngestionManagerPandas(
-        feature_group_name="MyGroup",
-        sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
-        max_workers=1,
-        profile_name="non_exist",
-    )
-
-    try:
-        manager.run(df)
-    except Exception as e:
-        assert "The config profile (non_exist) could not be found" in str(e)
-
-
-@patch(
-    "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch",
-    PicklableMock(return_value=[1]),
-)
-def test_ingestion_manager_run_multi_process_failure():
-    df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
-    manager = IngestionManagerPandas(
-        feature_group_name="MyGroup",
-        sagemaker_fs_runtime_client_config=None,
-        max_workers=2,
-        max_processes=2,
-    )
-
-    with pytest.raises(IngestionError) as error:
-        manager.run(df)
-
-    assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error)
-    assert error.value.failed_rows == [1, 1, 1, 1]
-    assert manager.failed_rows == [1, 1, 1, 1]
-
-
-@pytest.fixture
-def query(sagemaker_session_mock):
-    return AthenaQuery(
-        catalog="catalog",
-        database="database",
-        table_name="table_name",
-        sagemaker_session=sagemaker_session_mock,
-    )
-
-
-def test_athena_query_run(sagemaker_session_mock, query):
-    WORKGROUP = "workgroup"
-    sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query_id"}
-    query.run(
-        query_string="query", output_location="s3://some-bucket/some-path", workgroup=WORKGROUP
-    )
-    sagemaker_session_mock.start_query_execution.assert_called_with(
-        catalog="catalog",
-        database="database",
-        query_string="query",
-        output_location="s3://some-bucket/some-path",
-        kms_key=None,
-        workgroup=WORKGROUP,
-    )
-    assert "some-bucket" == query._result_bucket
-    assert "some-path" == query._result_file_prefix
-    assert "query_id" == query._current_query_execution_id
-
-
-def test_athena_query_wait(sagemaker_session_mock, query):
-    query._current_query_execution_id = "query_id"
-    query.wait()
-    sagemaker_session_mock.wait_for_athena_query.assert_called_with(query_execution_id="query_id")
-
-
-def test_athena_query_get_query_execution(sagemaker_session_mock, query):
-    query._current_query_execution_id = "query_id"
-    query.get_query_execution()
-    sagemaker_session_mock.get_query_execution.assert_called_with(query_execution_id="query_id")
-
-
-@patch("tempfile.gettempdir", Mock(return_value="tmp"))
-@patch("pandas.read_csv")
-def test_athena_query_as_dataframe(read_csv, sagemaker_session_mock, query):
-    sagemaker_session_mock.get_query_execution.return_value = {
-        "QueryExecution": {"Status": {"State": "SUCCEEDED"}}
-    }
-    query._current_query_execution_id = "query_id"
-    query._result_bucket = "bucket"
-    query._result_file_prefix = "prefix"
-    query.as_dataframe()
-    sagemaker_session_mock.download_athena_query_result.assert_called_with(
-        bucket="bucket",
-        prefix="prefix",
-        query_execution_id="query_id",
-        filename="tmp/query_id.csv",
+    assert (
+        "You must provide a record identifier feature name and an event time identifier feature "
+        + "name if specify DataFrame as base."
+        in str(error)
+    )
+
+
+def test_list_feature_groups_with_no_filter(sagemaker_session_mock):
+    feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock)
+    feature_store.list_feature_groups()
+    sagemaker_session_mock.list_feature_groups.assert_called_with(
+        name_contains=None,
+        feature_group_status_equals=None,
+        offline_store_status_equals=None,
+        creation_time_after=None,
+        creation_time_before=None,
+        sort_order=None,
+        sort_by=None,
+        max_results=None,
+        next_token=None,
+    )
+
+
+def test_list_feature_groups_with_all_filters(sagemaker_session_mock):
+    feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock)
+    feature_store.list_feature_groups(
+        name_contains="MyFeatureGroup",
+        feature_group_status_equals="Created",
+        offline_store_status_equals="Active",
+        creation_time_after=datetime.datetime(2020, 12, 1),
+        creation_time_before=datetime.datetime(2022, 7, 1),
+        sort_order="Ascending",
+        sort_by="Name",
+        max_results=50,
+        next_token="token",
+    )
+    sagemaker_session_mock.list_feature_groups.assert_called_with(
+        name_contains="MyFeatureGroup",
+        feature_group_status_equals="Created",
+        offline_store_status_equals="Active",
+        creation_time_after=datetime.datetime(2020, 12, 1),
+        creation_time_before=datetime.datetime(2022, 7, 1),
+        sort_order="Ascending",
+        sort_by="Name",
+        max_results=50,
+        next_token="token",
     )
-    read_csv.assert_called_with("tmp/query_id.csv", delimiter=",")
-
-
-@patch("tempfile.gettempdir", Mock(return_value="tmp"))
-def test_athena_query_as_dataframe_query_failed(sagemaker_session_mock, query):
-    sagemaker_session_mock.get_query_execution.return_value = {
-        "QueryExecution": {"Status": {"State": "FAILED"}}
-    }
-    query._current_query_execution_id = "query_id"
-    with pytest.raises(RuntimeError) as error:
-        query.as_dataframe()
-    assert "Failed to execute query query_id" in str(error)
"Failed to execute query query_id" in str(error) - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -def test_athena_query_as_dataframe_query_queued(sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "QUEUED"}} - } - query._current_query_execution_id = "query_id" - with pytest.raises(RuntimeError) as error: - query.as_dataframe() - assert "Current query query_id is still being executed" in str(error) - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -def test_athena_query_as_dataframe_query_running(sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "RUNNING"}} - } - query._current_query_execution_id = "query_id" - with pytest.raises(RuntimeError) as error: - query.as_dataframe() - assert "Current query query_id is still being executed" in str(error) diff --git a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py index c391d45382..072eefeb83 100644 --- a/tests/unit/sagemaker/huggingface/test_estimator.py +++ b/tests/unit/sagemaker/huggingface/test_estimator.py @@ -48,6 +48,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } @@ -142,14 +143,8 @@ def _create_train_job(version, base_framework_version): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/sagemaker/image_uris/test_algos.py b/tests/unit/sagemaker/image_uris/test_algos.py index 454d375b4b..443727094a 100644 --- a/tests/unit/sagemaker/image_uris/test_algos.py +++ b/tests/unit/sagemaker/image_uris/test_algos.py @@ -68,10 +68,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107", @@ -155,10 +157,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032", diff --git a/tests/unit/sagemaker/image_uris/test_sklearn.py b/tests/unit/sagemaker/image_uris/test_sklearn.py index d0fcbdb300..8563753e8c 100644 --- a/tests/unit/sagemaker/image_uris/test_sklearn.py +++ b/tests/unit/sagemaker/image_uris/test_sklearn.py @@ -37,10 +37,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249", diff --git 
a/tests/unit/sagemaker/image_uris/test_xgboost.py b/tests/unit/sagemaker/image_uris/test_xgboost.py index 78ab7e10ee..4d0f9f1dc3 100644 --- a/tests/unit/sagemaker/image_uris/test_xgboost.py +++ b/tests/unit/sagemaker/image_uris/test_xgboost.py @@ -35,10 +35,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032", @@ -67,10 +69,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249", diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index 2e7576421f..771b18b35a 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -56,6 +56,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } @@ -135,14 +136,8 @@ def _create_train_job(tf_version, horovod=False, ps=False, py_version="py2", smd "metric_definitions": None, "environment": None, "experiment_config": None, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index af46cf4360..656730a47c 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -52,6 +52,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } @@ -144,14 +145,8 @@ def _create_train_job( "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py index 5aef9316da..c3684ac649 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py @@ -50,6 +50,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } @@ -142,14 +143,8 @@ def _create_train_job( "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - 
"profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py new file mode 100644 index 0000000000..068bb4e4b9 --- /dev/null +++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py @@ -0,0 +1,612 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import logging + +import json +import os + +import pytest +from mock import MagicMock, Mock, patch, ANY +from packaging.version import Version + +from sagemaker import image_uris +from sagemaker.pytorch import PyTorch, TrainingCompilerConfig +from sagemaker.pytorch.model import PyTorchModel +from sagemaker.instance_group import InstanceGroup + +from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES + + +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "..", "data") +SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") +SERVING_SCRIPT_FILE = "another_dummy_script.py" +MODEL_DATA = "s3://some/data.tar.gz" +ENV = {"DUMMY_ENV_VAR": "dummy_value"} +TIMESTAMP = "2017-11-06-14:14:15.672" +TIME = 1510006209.073025 +BUCKET_NAME = "mybucket" +INSTANCE_COUNT = 1 +INSTANCE_TYPE = "ml.p3.2xlarge" +IMAGE_URI = "pytorch" +JOB_NAME = "{}-{}".format(IMAGE_URI, TIMESTAMP) +ROLE = "Dummy" +REGION = "us-east-1" +GPU = "ml.p3.2xlarge" +SUPPORTED_GPU_INSTANCE_CLASSES = {"p3", "p3dn", "g4dn", "p4d", "g5"} +UNSUPPORTED_GPU_INSTANCE_CLASSES = EC2_GPU_INSTANCE_CLASSES - SUPPORTED_GPU_INSTANCE_CLASSES + +LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]} + +EXPERIMENT_CONFIG = { + "ExperimentName": "exp", + "TrialName": "trial", + "TrialComponentDisplayName": "tc", +} + + +@pytest.fixture(scope="module") +def cpu_instance_type(): + return "ml.m5.xlarge" + + +@pytest.fixture(name="sagemaker_session", scope="function") +def fixture_sagemaker_session(): + boto_mock = Mock(name="boto_session", region_name=REGION) + session = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + s3_resource=None, + s3_client=None, + ) + + describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} + session.sagemaker_client.describe_training_job = Mock(return_value=describe) + session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) + session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + session.expand_role = Mock(name="expand_role", return_value=ROLE) + return session + + +def _get_full_gpu_image_uri(version, instance_type, training_compiler_config): + return image_uris.retrieve( + "pytorch-training-compiler", + 
+
+
+def _create_train_job(version, instance_type, training_compiler_config, instance_count=1):
+    return {
+        "image_uri": _get_full_gpu_image_uri(version, instance_type, training_compiler_config),
+        "input_mode": "File",
+        "input_config": [
+            {
+                "ChannelName": "training",
+                "DataSource": {
+                    "S3DataSource": {
+                        "S3DataDistributionType": "FullyReplicated",
+                        "S3DataType": "S3Prefix",
+                    }
+                },
+            }
+        ],
+        "role": ROLE,
+        "job_name": JOB_NAME,
+        "output_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)},
+        "resource_config": {
+            "InstanceType": instance_type,
+            "InstanceCount": instance_count,
+            "VolumeSizeInGB": 30,
+        },
+        "hyperparameters": {
+            "sagemaker_program": json.dumps("dummy_script.py"),
+            "sagemaker_container_log_level": str(logging.INFO),
+            "sagemaker_job_name": json.dumps(JOB_NAME),
+            "sagemaker_submit_directory": json.dumps(
+                "s3://{}/{}/source/sourcedir.tar.gz".format(BUCKET_NAME, JOB_NAME)
+            ),
+            "sagemaker_region": '"us-east-1"',
+        },
+        "stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
+        "tags": None,
+        "vpc_config": None,
+        "metric_definitions": None,
+        "environment": None,
+        "retry_strategy": None,
+        "experiment_config": EXPERIMENT_CONFIG,
+        "debugger_hook_config": {
+            "CollectionConfigurations": [],
+            "S3OutputPath": "s3://{}/".format(BUCKET_NAME),
+        },
+        "profiler_config": {
+            "DisableProfiler": False,
+            "S3OutputPath": "s3://{}/".format(BUCKET_NAME),
+        },
+    }
+
+
+def test_unsupported_BYOC(
+    pytorch_training_compiler_version,
+):
+    byoc = (
+        "1.dkr.ecr.us-east-1.amazonaws.com/pytorch-trcomp-training:"
+        "1.12.0-"
+        "gpu-"
+        "py38-cu113-ubuntu20.04"
+    )
+    with pytest.raises(ValueError):
+        PyTorch(
+            image_uri=byoc,
+            py_version="py38",
+            entry_point=SCRIPT_PATH,
+            role=ROLE,
+            instance_count=INSTANCE_COUNT,
+            instance_type=INSTANCE_TYPE,
+            framework_version=pytorch_training_compiler_version,
+            enable_sagemaker_metrics=False,
+            compiler_config=TrainingCompilerConfig(),
+        ).fit()
+
+
+def test_unsupported_cpu_instance(cpu_instance_type, pytorch_training_compiler_version):
+    with pytest.raises(ValueError):
+        PyTorch(
+            py_version="py38",
+            entry_point=SCRIPT_PATH,
+            role=ROLE,
+            instance_count=INSTANCE_COUNT,
+            instance_type=cpu_instance_type,
+            framework_version=pytorch_training_compiler_version,
+            enable_sagemaker_metrics=False,
+            compiler_config=TrainingCompilerConfig(),
+        ).fit()
+
+
+@pytest.mark.parametrize("unsupported_gpu_instance_class", UNSUPPORTED_GPU_INSTANCE_CLASSES)
+def test_unsupported_gpu_instance(
+    unsupported_gpu_instance_class, pytorch_training_compiler_version
+):
+    with pytest.raises(ValueError):
+        PyTorch(
+            py_version="py38",
+            entry_point=SCRIPT_PATH,
+            role=ROLE,
+            instance_count=INSTANCE_COUNT,
+            instance_type=f"ml.{unsupported_gpu_instance_class}.xlarge",
+            framework_version=pytorch_training_compiler_version,
+            enable_sagemaker_metrics=False,
+            compiler_config=TrainingCompilerConfig(),
+        ).fit()
+
+
+@pytest.mark.xfail(reason="With only 1 supported version, user input is ignored.")
+def test_unsupported_framework_version():
+    with pytest.raises(ValueError):
+        PyTorch(
+            py_version="py38",
+            entry_point=SCRIPT_PATH,
+            role=ROLE,
+            instance_count=INSTANCE_COUNT,
+            instance_type=INSTANCE_TYPE,
+            framework_version="99.99.99",
+            enable_sagemaker_metrics=False,
+            compiler_config=TrainingCompilerConfig(),
+        ).fit()
+
+
+def test_unsupported_python_2(
+    pytorch_training_compiler_version,
+):
+    with pytest.raises(ValueError):
+        PyTorch(
+            py_version="py27",
+            entry_point=SCRIPT_PATH,
+            role=ROLE,
+            instance_count=INSTANCE_COUNT,
+            instance_type=INSTANCE_TYPE,
+            framework_version=pytorch_training_compiler_version,
+            enable_sagemaker_metrics=False,
+            compiler_config=TrainingCompilerConfig(),
+        ).fit()
+
+
+def test_unsupported_instance_group(
+    pytorch_training_compiler_version,
+):
+    if Version(pytorch_training_compiler_version) < Version("1.12"):
+        pytest.skip("This test is intended for PyTorch 1.12 and above")
+    with pytest.raises(ValueError):
+        PyTorch(
+            py_version="py38",
+            entry_point=SCRIPT_PATH,
+            role=ROLE,
+            instance_groups=[
+                InstanceGroup("ml.p3dn.24xlarge", "ml.p3dn.24xlarge", 16),
+                InstanceGroup("ml.p4d.24xlarge", "ml.p4d.24xlarge", 16),
+            ],
+            framework_version=pytorch_training_compiler_version,
+            enable_sagemaker_metrics=False,
+            compiler_config=TrainingCompilerConfig(),
+        ).fit()
+
+
+def test_unsupported_distribution(
+    pytorch_training_compiler_version,
+):
+    if Version(pytorch_training_compiler_version) < Version("1.12"):
+        pytest.skip("This test is intended for PyTorch 1.12 and above")
+    with pytest.raises(ValueError):
+        PyTorch(
+            py_version="py38",
+            entry_point=SCRIPT_PATH,
+            role=ROLE,
+            instance_count=2,
+            instance_type=INSTANCE_TYPE,
+            framework_version=pytorch_training_compiler_version,
+            enable_sagemaker_metrics=False,
+            compiler_config=TrainingCompilerConfig(),
+            distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
+        ).fit()
+
+    with pytest.raises(ValueError):
+        PyTorch(
+            py_version="py38",
+            entry_point=SCRIPT_PATH,
+            role=ROLE,
+            instance_count=2,
+            instance_type=INSTANCE_TYPE,
+            transformers_version="4.17",
+            pytorch_version="1.10",
+            enable_sagemaker_metrics=False,
+            compiler_config=TrainingCompilerConfig(),
+            distribution={"pytorchxla": {"enabled": True}},
+        ).fit()
+
+    with pytest.raises(ValueError):
+        PyTorch(
+            py_version="py38",
+            entry_point=SCRIPT_PATH,
+            role=ROLE,
+            instance_count=2,
+            instance_type=INSTANCE_TYPE,
+            framework_version=pytorch_training_compiler_version,
+            enable_sagemaker_metrics=False,
+            compiler_config=TrainingCompilerConfig(),
+            distribution={"mpi": {"enabled": True}},
+        ).fit()
+
+
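Note: the happy-path configuration these guards carve out is the one the positive tests below exercise. A minimal sketch of the user-facing call (role, script, and channel are placeholders):

    from sagemaker.pytorch import PyTorch, TrainingCompilerConfig

    estimator = PyTorch(
        entry_point="train.py",                 # hypothetical script
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role
        instance_count=1,
        instance_type="ml.p3.2xlarge",          # one of the supported GPU classes
        framework_version="1.12",
        py_version="py38",
        compiler_config=TrainingCompilerConfig(),
    )
    estimator.fit("s3://my-bucket/train")       # hypothetical input channel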
+@patch("sagemaker.utils.repack_model", MagicMock())
+@patch("sagemaker.utils.create_tar_file", MagicMock())
+@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
+@patch("time.time", return_value=TIME)
+@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES)
+def test_pytorchxla_distribution(
+    time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class
+):
+    if Version(pytorch_training_compiler_version) < Version("1.12"):
+        pytest.skip("This test is intended for PyTorch 1.12 and above")
+    compiler_config = TrainingCompilerConfig()
+    instance_type = f"ml.{instance_class}.xlarge"
+
+    pt = PyTorch(
+        py_version="py38",
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        instance_count=2,
+        instance_type=instance_type,
+        framework_version=pytorch_training_compiler_version,
+        enable_sagemaker_metrics=False,
+        compiler_config=TrainingCompilerConfig(),
+        distribution={"pytorchxla": {"enabled": True}},
+    )
+
+    inputs = "s3://mybucket/train"
+
+    pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG)
+
+    sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
+    assert sagemaker_call_names == ["train", "logs_for_job"]
+    boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
+    assert boto_call_names == ["resource"]
+
+    expected_train_args = _create_train_job(
+        pytorch_training_compiler_version, instance_type, compiler_config, instance_count=2
+    )
+    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
+    expected_train_args["enable_sagemaker_metrics"] = False
+    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps(
+        True
+    )
+    expected_train_args["hyperparameters"][PyTorch.LAUNCH_PT_XLA_ENV_NAME] = json.dumps(True)
+    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps(
+        False
+    )
+
+    actual_train_args = sagemaker_session.method_calls[0][2]
+    assert (
+        actual_train_args == expected_train_args
+    ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}"
+
+
+@patch("sagemaker.utils.repack_model", MagicMock())
+@patch("sagemaker.utils.create_tar_file", MagicMock())
+@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
+@patch("time.time", return_value=TIME)
+@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES)
+def test_default_compiler_config(
+    time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class
+):
+    compiler_config = TrainingCompilerConfig()
+    instance_type = f"ml.{instance_class}.xlarge"
+
+    pt = PyTorch(
+        py_version="py38",
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        instance_count=INSTANCE_COUNT,
+        instance_type=instance_type,
+        framework_version=pytorch_training_compiler_version,
+        enable_sagemaker_metrics=False,
+        compiler_config=compiler_config,
+    )
+
+    inputs = "s3://mybucket/train"
+
+    pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG)
+
+    sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
+    assert sagemaker_call_names == ["train", "logs_for_job"]
+    boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
+    assert boto_call_names == ["resource"]
+
+    expected_train_args = _create_train_job(
+        pytorch_training_compiler_version, instance_type, compiler_config
+    )
+    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
+    expected_train_args["enable_sagemaker_metrics"] = False
+    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps(
+        True
+    )
+    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps(
+        False
+    )
+
+    actual_train_args = sagemaker_session.method_calls[0][2]
+    assert (
+        actual_train_args == expected_train_args
+    ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}"
+
+
+@patch("sagemaker.utils.repack_model", MagicMock())
+@patch("sagemaker.utils.create_tar_file", MagicMock())
+@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
+@patch("time.time", return_value=TIME)
+def test_debug_compiler_config(
+    time, name_from_base, sagemaker_session, pytorch_training_compiler_version
+):
+    compiler_config = TrainingCompilerConfig(debug=True)
+
+    pt = PyTorch(
+        py_version="py38",
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        instance_count=INSTANCE_COUNT,
+        instance_type=INSTANCE_TYPE,
+        framework_version=pytorch_training_compiler_version,
+        enable_sagemaker_metrics=False,
+        compiler_config=compiler_config,
+    )
+
+    inputs = "s3://mybucket/train"
+
+    pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG)
+
+    sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
+    assert sagemaker_call_names == ["train", "logs_for_job"]
+    boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
+    assert boto_call_names == ["resource"]
+
+    expected_train_args = _create_train_job(
+        pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config
+    )
+    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
+    expected_train_args["enable_sagemaker_metrics"] = False
+    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps(
+        True
+    )
+    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps(
+        True
+    )
+
+    actual_train_args = sagemaker_session.method_calls[0][2]
+    assert (
+        actual_train_args == expected_train_args
+    ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}"
+
+
+@patch("sagemaker.utils.repack_model", MagicMock())
+@patch("sagemaker.utils.create_tar_file", MagicMock())
+@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
+@patch("time.time", return_value=TIME)
+def test_disable_compiler_config(
+    time, name_from_base, sagemaker_session, pytorch_training_compiler_version
+):
+    compiler_config = TrainingCompilerConfig(enabled=False)
+
+    pt = PyTorch(
+        py_version="py38",
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        instance_count=INSTANCE_COUNT,
+        instance_type=INSTANCE_TYPE,
+        framework_version=pytorch_training_compiler_version,
+        enable_sagemaker_metrics=False,
+        compiler_config=TrainingCompilerConfig(enabled=False),
+    )
+
+    inputs = "s3://mybucket/train"
+
+    pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG)
+
+    sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
+    assert sagemaker_call_names == ["train", "logs_for_job"]
+    boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
+    assert boto_call_names == ["resource"]
+
+    expected_train_args = _create_train_job(
+        pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config
+    )
+    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
+    expected_train_args["enable_sagemaker_metrics"] = False
+    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps(
+        False
+    )
+    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps(
+        False
+    )
+
+    actual_train_args = sagemaker_session.method_calls[0][2]
+    assert (
+        actual_train_args == expected_train_args
+    ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}"
+
+
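Note: as the expected_train_args dicts above assert, the compiler settings travel to the training job as two JSON-encoded boolean hyperparameters. An illustrative sketch of that mapping (assuming the HP_* constants keep their current string keys):

    import json
    from sagemaker.pytorch import TrainingCompilerConfig

    # enabled=False  -> HP_ENABLE_COMPILER serialized as json.dumps(False)
    # debug=True     -> HP_ENABLE_DEBUG serialized as json.dumps(True)
    expected = {
        TrainingCompilerConfig.HP_ENABLE_COMPILER: json.dumps(False),
        TrainingCompilerConfig.HP_ENABLE_DEBUG: json.dumps(False),
    }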
'"us-east-1"', + TrainingCompilerConfig.HP_ENABLE_COMPILER: json.dumps(compiler_enabled), + TrainingCompilerConfig.HP_ENABLE_DEBUG: json.dumps(debug_enabled), + }, + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.p3.2xlarge", + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "trcomp", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/trcomp", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/trcomp"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + } + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) + + estimator = PyTorch.attach(training_job_name="trcomp", sagemaker_session=sagemaker_session) + assert estimator.latest_training_job.job_name == "trcomp" + assert estimator.py_version == "py38" + assert estimator.framework_version == "1.12.0" + assert estimator.role == "arn:aws:iam::366:role/SageMakerRole" + assert estimator.instance_count == 1 + assert estimator.max_run == 24 * 60 * 60 + assert estimator.input_mode == "File" + assert estimator.base_job_name == "trcomp" + assert estimator.output_path == "s3://place/output/trcomp" + assert estimator.output_kms_key == "" + assert estimator.hyperparameters()["training_steps"] == "100" + assert estimator.hyperparameters()[TrainingCompilerConfig.HP_ENABLE_COMPILER] == json.dumps( + compiler_enabled + ) + assert estimator.hyperparameters()[TrainingCompilerConfig.HP_ENABLE_DEBUG] == json.dumps( + debug_enabled + ) + assert estimator.source_dir == "s3://some/sourcedir.tar.gz" + assert estimator.entry_point == "iris-dnn-classifier.py" + + +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +def test_register_pytorch_model_auto_infer_framework( + sagemaker_session, pytorch_training_compiler_version +): + + model_package_group_name = "test-pt-register-model" + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarge"] + image_uri = "fakeimage" + + pt_model = PyTorchModel( + model_data="s3://some/data.tar.gz", + role=ROLE, + entry_point=SCRIPT_PATH, + framework_version=pytorch_training_compiler_version, + py_version="py38", + sagemaker_session=sagemaker_session, + ) + + pt_model.register( + content_types, + response_types, + inference_instances, + transform_instances, + model_package_group_name=model_package_group_name, + marketplace_cert=True, + image_uri=image_uri, + ) + + expected_create_model_package_request = { + "containers": [ + { + "Image": image_uri, + "Environment": ANY, + "ModelDataUrl": ANY, + "Framework": "PYTORCH", + "FrameworkVersion": pytorch_training_compiler_version, + } + ], + "content_types": content_types, + "response_types": response_types, + "inference_instances": inference_instances, + "transform_instances": transform_instances, + "model_package_group_name": model_package_group_name, + "marketplace_cert": True, + } + + sagemaker_session.create_model_package_from_containers.assert_called_with( + **expected_create_model_package_request + ) diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index 7517f3a641..a5c14b1626 100644 --- 
diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py
index 7517f3a641..a5c14b1626 100644
--- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py
+++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py
@@ -50,6 +50,7 @@
     "ExperimentName": "exp",
     "TrialName": "trial",
     "TrialComponentDisplayName": "tc",
+    "RunName": "rn",
 }
 
 
@@ -144,14 +145,8 @@ def _create_train_job(framework_version, instance_type, training_compiler_config
             "CollectionConfigurations": [],
             "S3OutputPath": "s3://{}/".format(BUCKET_NAME),
         },
-        "profiler_rule_configs": [
-            {
-                "RuleConfigurationName": "ProfilerReport-1510006209",
-                "RuleEvaluatorImage": "503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest",
-                "RuleParameters": {"rule_to_invoke": "ProfilerReport"},
-            }
-        ],
         "profiler_config": {
+            "DisableProfiler": False,
             "S3OutputPath": "s3://{}/".format(BUCKET_NAME),
         },
     }
diff --git a/tests/unit/sagemaker/utilities/test_search_expression.py b/tests/unit/sagemaker/utilities/test_search_expression.py
new file mode 100644
index 0000000000..98a52a992a
--- /dev/null
+++ b/tests/unit/sagemaker/utilities/test_search_expression.py
@@ -0,0 +1,80 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import pytest
+
+from sagemaker.utilities.search_expression import (
+    Filter,
+    Operator,
+    NestedFilter,
+    SearchExpression,
+    BooleanOperator,
+)
+
+
+def test_filters():
+    search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1")
+
+    assert {
+        "Name": "learning_rate",
+        "Operator": "Equals",
+        "Value": "0.1",
+    } == search_filter.to_boto()
+
+
+def test_partial_filters():
+    search_filter = Filter(name="learning_rate")
+
+    assert {"Name": "learning_rate"} == search_filter.to_boto()
+
+
+def test_nested_filters():
+    search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1")
+    filters = [search_filter]
+    nested_filters = NestedFilter(property_name="hyper_param", filters=filters)
+
+    assert {
+        "Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}],
+        "NestedPropertyName": "hyper_param",
+    } == nested_filters.to_boto()
+
+
+def test_search_expression():
+    search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1")
+    nested_filter = NestedFilter(property_name="hyper_param", filters=[search_filter])
+    search_expression = SearchExpression(
+        filters=[search_filter],
+        nested_filters=[nested_filter],
+        sub_expressions=[],
+        boolean_operator=BooleanOperator.AND,
+    )
+
+    assert {
+        "Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}],
+        "NestedFilters": [
+            {
+                "Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}],
+                "NestedPropertyName": "hyper_param",
+            }
+        ],
+        "SubExpressions": [],
+        "Operator": "And",
+    } == search_expression.to_boto()
+
+
+def test_illegal_search_expression():
+    with pytest.raises(
+        ValueError, match="You must specify at least one subexpression, filter, or nested filter"
+    ):
+        SearchExpression()
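Note: composing a search expression for the Search API follows directly from the tests above; to_boto() yields the request-shaped dict. A minimal sketch:

    from sagemaker.utilities.search_expression import (
        BooleanOperator, Filter, NestedFilter, Operator, SearchExpression
    )

    f = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1")
    expr = SearchExpression(
        filters=[f],
        nested_filters=[NestedFilter(property_name="hyper_param", filters=[f])],
        boolean_operator=BooleanOperator.AND,
    )
    print(expr.to_boto())  # {"Filters": [...], "NestedFilters": [...], "Operator": "And"}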
diff --git a/tests/unit/sagemaker/workflow/conftest.py b/tests/unit/sagemaker/workflow/conftest.py
new file mode 100644
index 0000000000..9ea3d0bcac
--- /dev/null
+++ b/tests/unit/sagemaker/workflow/conftest.py
@@ -0,0 +1,75 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from unittest.mock import Mock, PropertyMock
+
+import pytest
+
+from sagemaker import Session
+from sagemaker.workflow.pipeline_context import PipelineSession
+
+REGION = "us-west-2"
+BUCKET = "my-bucket"
+ROLE = "DummyRole"
+IMAGE_URI = "fakeimage"
+
+
+@pytest.fixture(scope="module")
+def client():
+    """Mock client.
+
+    Considerations when appropriate:
+
+        * utilize botocore.stub.Stubber
+        * separate runtime client from client
+    """
+    client_mock = Mock()
+    client_mock._client_config.user_agent = (
+        "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
+    )
+    return client_mock
+
+
+@pytest.fixture(scope="module")
+def boto_session(client):
+    role_mock = Mock()
+    type(role_mock).arn = PropertyMock(return_value=ROLE)
+
+    resource_mock = Mock()
+    resource_mock.Role.return_value = role_mock
+
+    session_mock = Mock(region_name=REGION)
+    session_mock.resource.return_value = resource_mock
+    session_mock.client.return_value = client
+
+    return session_mock
+
+
+@pytest.fixture(scope="module")
+def pipeline_session(boto_session, client):
+    return PipelineSession(
+        boto_session=boto_session,
+        sagemaker_client=client,
+        default_bucket=BUCKET,
+    )
+
+
+@pytest.fixture(scope="module")
+def sagemaker_session(boto_session, client):
+    return Session(
+        boto_session=boto_session,
+        sagemaker_client=client,
+        sagemaker_runtime_client=client,
+        default_bucket=BUCKET,
+    )
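Note: with these module-scoped fixtures centralized in conftest.py, the per-file fixture copies removed in the diffs that follow become unnecessary; individual workflow tests simply declare the fixture as an argument. An illustrative sketch (the test body is hypothetical):

    def test_my_workflow_step(pipeline_session):
        # pipeline_session is injected by pytest from conftest.py;
        # build steps against it exactly as the per-file fixtures allowed before.
        ...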
diff --git a/tests/unit/sagemaker/workflow/test_clarify_check_step.py b/tests/unit/sagemaker/workflow/test_clarify_check_step.py
index feadaa03dc..54b354b71e 100644
--- a/tests/unit/sagemaker/workflow/test_clarify_check_step.py
+++ b/tests/unit/sagemaker/workflow/test_clarify_check_step.py
@@ -16,10 +16,6 @@
 import re
 
 import pytest
-import sagemaker
-
-from mock import Mock, PropertyMock
-
 from sagemaker.clarify import (
     DataConfig,
     BiasConfig,
@@ -50,46 +46,6 @@
 _S3_ANALYSIS_CONFIG_OUTPUT_PATH = "s3://my_bucket/analysis_cfg_output"
 
 
-@pytest.fixture
-def boto_session():
-    role_mock = Mock()
-    type(role_mock).arn = PropertyMock(return_value=_ROLE)
-
-    resource_mock = Mock()
-    resource_mock.Role.return_value = role_mock
-
-    session_mock = Mock(region_name=_REGION)
-    session_mock.resource.return_value = resource_mock
-
-    return session_mock
-
-
-@pytest.fixture
-def client():
-    """Mock client.
-
-    Considerations when appropriate:
-
-        * utilize botocore.stub.Stubber
-        * separate runtime client from client
-    """
-    client_mock = Mock()
-    client_mock._client_config.user_agent = (
-        "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
-    )
-    return client_mock
-
-
-@pytest.fixture
-def sagemaker_session(boto_session, client):
-    return sagemaker.session.Session(
-        boto_session=boto_session,
-        sagemaker_client=client,
-        sagemaker_runtime_client=client,
-        default_bucket=_DEFAULT_BUCKET,
-    )
-
-
 _expected_data_bias_dsl = {
     "Name": "DataBiasCheckStep",
     "Type": "ClarifyCheck",
diff --git a/tests/unit/sagemaker/workflow/test_entities.py b/tests/unit/sagemaker/workflow/test_entities.py
index 6f0be2ccca..a36207b241 100644
--- a/tests/unit/sagemaker/workflow/test_entities.py
+++ b/tests/unit/sagemaker/workflow/test_entities.py
@@ -19,9 +19,6 @@
 from enum import Enum
 
-from mock.mock import Mock, PropertyMock
-
-import sagemaker
 from sagemaker.workflow.condition_step import ConditionStep
 from sagemaker.workflow.conditions import ConditionGreaterThan
 from sagemaker.workflow.entities import (
@@ -58,46 +55,6 @@ def custom_entity_list():
     return [CustomEntity(1), CustomEntity(2)]
 
 
-@pytest.fixture
-def boto_session():
-    role_mock = Mock()
-    type(role_mock).arn = PropertyMock(return_value="role")
-
-    resource_mock = Mock()
-    resource_mock.Role.return_value = role_mock
-
-    session_mock = Mock(region_name="us-west-2")
-    session_mock.resource.return_value = resource_mock
-
-    return session_mock
-
-
-@pytest.fixture
-def client():
-    """Mock client.
-
-    Considerations when appropriate:
-
-        * utilize botocore.stub.Stubber
-        * separate runtime client from client
-    """
-    client_mock = Mock()
-    client_mock._client_config.user_agent = (
-        "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
-    )
-    return client_mock
-
-
-@pytest.fixture
-def sagemaker_session(boto_session, client):
-    return sagemaker.session.Session(
-        boto_session=boto_session,
-        sagemaker_client=client,
-        sagemaker_runtime_client=client,
-        default_bucket="my-bucket",
-    )
-
-
 def test_entity(custom_entity):
     request_struct = {"foo": 1}
     assert custom_entity.to_request() == request_struct
diff --git a/tests/unit/sagemaker/workflow/test_model_step.py b/tests/unit/sagemaker/workflow/test_model_step.py
index 080e70ca62..2216299d3b 100644
--- a/tests/unit/sagemaker/workflow/test_model_step.py
+++ b/tests/unit/sagemaker/workflow/test_model_step.py
@@ -15,7 +15,7 @@
 import json
 import os
 
-from mock import Mock, PropertyMock, patch
+from mock import patch
 
 import pytest
 
@@ -43,7 +43,6 @@
 )
 from sagemaker.workflow.parameters import ParameterString, ParameterInteger
 from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
-from sagemaker.workflow.pipeline_context import PipelineSession
 from sagemaker.workflow.retry import (
     StepRetryPolicy,
     StepExceptionTypeEnum,
@@ -55,11 +54,9 @@
 from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum
 from tests.unit import DATA_DIR
 from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered
+from tests.unit.sagemaker.workflow.conftest import BUCKET, ROLE
 
 _IMAGE_URI = "fakeimage"
-_REGION = "us-west-2"
-_BUCKET = "my-bucket"
-_ROLE = "DummyRole"
 _INSTANCE_TYPE = "ml.m4.xlarge"
 
 _SAGEMAKER_PROGRAM = SCRIPT_PARAM_NAME.upper()
@@ -69,60 +66,10 @@
 _XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone")
 _TENSORFLOW_PATH = os.path.join(DATA_DIR, "tfs/tfs-test-entrypoint-and-dependencies")
 _REPACK_OUTPUT_KEY_PREFIX = "code-output"
-_MODEL_CODE_LOCATION = f"s3://{_BUCKET}/{_REPACK_OUTPUT_KEY_PREFIX}"
+_MODEL_CODE_LOCATION = f"s3://{BUCKET}/{_REPACK_OUTPUT_KEY_PREFIX}"
 _MODEL_CODE_LOCATION_TRAILING_SLASH = _MODEL_CODE_LOCATION + "/"
 
 
-@pytest.fixture
-def client():
-    """Mock client.
-
-    Considerations when appropriate:
-
-        * utilize botocore.stub.Stubber
-        * separate runtime client from client
-    """
-    client_mock = Mock()
-    client_mock._client_config.user_agent = (
-        "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
-    )
-    return client_mock
-
-
-@pytest.fixture
-def boto_session(client):
-    role_mock = Mock()
-    type(role_mock).arn = PropertyMock(return_value=_ROLE)
-
-    resource_mock = Mock()
-    resource_mock.Role.return_value = role_mock
-
-    session_mock = Mock(region_name=_REGION)
-    session_mock.resource.return_value = resource_mock
-    session_mock.client.return_value = client
-
-    return session_mock
-
-
-@pytest.fixture
-def pipeline_session(boto_session, client):
-    return PipelineSession(
-        boto_session=boto_session,
-        sagemaker_client=client,
-        default_bucket=_BUCKET,
-    )
-
-
-@pytest.fixture
-def sagemaker_session(boto_session, client):
-    return Session(
-        boto_session=boto_session,
-        sagemaker_client=client,
-        sagemaker_runtime_client=client,
-        default_bucket=_BUCKET,
-    )
-
-
 @pytest.fixture
 def model_data_param():
     return ParameterString(name="ModelData", default_value="s3://my-bucket/file")
@@ -137,7 +84,7 @@ def model(pipeline_session, model_data_param):
         sagemaker_session=pipeline_session,
         entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
         source_dir=f"{DATA_DIR}",
-        role=_ROLE,
+        role=ROLE,
     )
 
 
@@ -322,13 +269,13 @@ def test_create_pipeline_model_with_runtime_repack(pipeline_session, model_data_
     sparkml_model = SparkMLModel(
         name="MySparkMLModel",
         model_data=model_data_param,
-        role=_ROLE,
+        role=ROLE,
         sagemaker_session=pipeline_session,
         env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
     )
     # The model need to runtime repack
     ppl_model = PipelineModel(
-        models=[sparkml_model, model], role=_ROLE, sagemaker_session=pipeline_session
+        models=[sparkml_model, model], role=ROLE, sagemaker_session=pipeline_session
     )
     step_args = ppl_model.create(
         instance_type="c4.4xlarge",
@@ -417,7 +364,7 @@ def test_register_pipeline_model_with_runtime_repack(pipeline_session, model_dat
     # The model no need to runtime repack, since source_dir is missing
     sparkml_model = SparkMLModel(
         model_data=model_data_param,
-        role=_ROLE,
+        role=ROLE,
         sagemaker_session=pipeline_session,
         env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
         entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
@@ -429,11 +376,11 @@ def test_register_pipeline_model_with_runtime_repack(pipeline_session, model_dat
         sagemaker_session=pipeline_session,
         entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
         source_dir=f"{DATA_DIR}",
-        role=_ROLE,
+        role=ROLE,
         env={"k": "v"},
     )
     model = PipelineModel(
-        models=[sparkml_model, model], role=_ROLE, sagemaker_session=pipeline_session
+        models=[sparkml_model, model], role=ROLE, sagemaker_session=pipeline_session
     )
     step_args = model.register(
         content_types=["text/csv"],
@@ -516,7 +463,7 @@ def test_register_model_without_repack(pipeline_session):
         model_data=model_data,
         entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
         sagemaker_session=pipeline_session,
-        role=_ROLE,
+        role=ROLE,
     )
     step_args = model.register(
         content_types=["text/csv"],
@@ -547,7 +494,7 @@ def test_register_model_without_repack(pipeline_session):
     assert containers[0]["Environment"][_SAGEMAKER_PROGRAM] == _SCRIPT_NAME
     assert (
         containers[0]["Environment"][_SAGEMAKER_SUBMIT_DIRECTORY]
-        == f"s3://{_BUCKET}/{model_name}/sourcedir.tar.gz"
+        == f"s3://{BUCKET}/{model_name}/sourcedir.tar.gz"
     )
     adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list
     assert ordered(adjacency_list) == ordered({"MyModelStep-RegisterModel": []})
@@ -560,11 +507,11 @@ def test_create_model_with_compile_time_repack(mock_repack, pipeline_session):
     model = Model(
         name=model_name,
         image_uri=_IMAGE_URI,
-        model_data=f"s3://{_BUCKET}/model.tar.gz",
+        model_data=f"s3://{BUCKET}/model.tar.gz",
         sagemaker_session=pipeline_session,
         entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
         source_dir=f"{DATA_DIR}",
-        role=_ROLE,
+        role=ROLE,
     )
     step_args = model.create(
         instance_type="c4.4xlarge",
@@ -582,7 +529,7 @@ def test_create_model_with_compile_time_repack(mock_repack, pipeline_session):
     arguments = step_dsl_list[0]["Arguments"]
     assert arguments["PrimaryContainer"]["Image"] == _IMAGE_URI
     assert (
-        arguments["PrimaryContainer"]["ModelDataUrl"] == f"s3://{_BUCKET}/{model_name}/model.tar.gz"
+        arguments["PrimaryContainer"]["ModelDataUrl"] == f"s3://{BUCKET}/{model_name}/model.tar.gz"
     )
     assert arguments["PrimaryContainer"]["Environment"][_SAGEMAKER_PROGRAM] == _SCRIPT_NAME
     assert arguments["PrimaryContainer"]["Environment"][_SAGEMAKER_SUBMIT_DIRECTORY] == _DIR_NAME
@@ -700,7 +647,7 @@ def test_conditional_model_create_and_regis(
                 model_data="dummy_model_data",
                 image_uri=_IMAGE_URI,
                 entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
-                role=_ROLE,
+                role=ROLE,
                 enable_network_isolation=True,
                 code_location=_MODEL_CODE_LOCATION_TRAILING_SLASH,
             ),
@@ -713,7 +660,7 @@
                 framework_version="1.11.0",
                 image_uri=_IMAGE_URI,
                 entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
-                role=_ROLE,
+                role=ROLE,
                 enable_network_isolation=False,
             ),
             1,
@@ -724,7 +671,7 @@
                 model_data="dummy_model_data",
                 image_uri=_IMAGE_URI,
                 entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
-                role=_ROLE,
+                role=ROLE,
                 framework_version="1.5.0",
                 code_location=_MODEL_CODE_LOCATION_TRAILING_SLASH,
             ),
@@ -736,7 +683,7 @@
                 model_data="dummy_model_data",
                 image_uri=_IMAGE_URI,
                 entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
-                role=_ROLE,
+                role=ROLE,
                 framework_version="1.2.0",
             ),
             1,
@@ -747,7 +694,7 @@
                 model_data="dummy_model_data",
                 image_uri=_IMAGE_URI,
                 entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
-                role=_ROLE,
+                role=ROLE,
             ),
             2,
        ),
@@ -757,7 +704,7 @@
                 model_data="dummy_model_data",
                 image_uri=_IMAGE_URI,
                 entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
-                role=_ROLE,
+                role=ROLE,
                 code_location=_MODEL_CODE_LOCATION_TRAILING_SLASH,
             ),
             2,
@@ -768,7 +715,7 @@
                 model_data="dummy_model_data",
                 image_uri=_IMAGE_URI,
                 entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
-                role=_ROLE,
+                role=ROLE,
             ),
             1,
         ),
@@ -789,7 +736,7 @@ def assert_test_result(steps: list):
             )
         else:
             assert steps[0]["Arguments"]["OutputDataConfig"]["S3OutputPath"] == (
-                f"s3://{_BUCKET}/{model.name}"
+                f"s3://{BUCKET}/{model.name}"
             )
 
     model, expected_step_num = test_input
@@ -828,7 +775,7 @@ def assert_test_result(steps: list):
             XGBoostModel(
                 model_data="dummy_model_step",
                 framework_version="1.3-1",
-                role=_ROLE,
+                role=ROLE,
                 entry_point=os.path.join(_XGBOOST_PATH, "inference.py"),
                 enable_network_isolation=True,
             ),
@@ -845,7 +792,7 @@ def assert_test_result(steps: list):
             XGBoostModel(
                 model_data="dummy_model_step",
                 framework_version="1.3-1",
-                role=_ROLE,
+                role=ROLE,
                 entry_point=os.path.join(_XGBOOST_PATH, "inference.py"),
             ),
             {
@@ -861,7 +808,7 @@ def assert_test_result(steps: list):
             XGBoostModel(
                 model_data="dummy_model_step",
                 framework_version="1.3-1",
-                role=_ROLE,
+                role=ROLE,
                 entry_point=None,
             ),
             {
@@ -876,9 +823,8 @@ def assert_test_result(steps: list):
         (
             TensorFlowModel(
                 model_data="dummy_model_step",
-                role=_ROLE,
+                role=ROLE,
                 image_uri=_IMAGE_URI,
-                sagemaker_session=pipeline_session,
                 entry_point=os.path.join(_TENSORFLOW_PATH, "inference.py"),
             ),
             {
@@ -893,9 +839,8 @@ def assert_test_result(steps: list):
         (
             TensorFlowModel(
                 model_data="dummy_model_step",
-                role=_ROLE,
+                role=ROLE,
                 image_uri=_IMAGE_URI,
-                sagemaker_session=pipeline_session,
             ),
             {
                 "expected_step_num": 1,
@@ -941,7 +886,7 @@ def test_request_compare_of_register_model_under_different_sessions(
     _verify_register_model_container_definition(regis_step_arg, expect, dict)
 
     # Get create model package request under Session
-    model.model_data = f"s3://{_BUCKET}"
+    model.model_data = f"s3://{BUCKET}"
     model.sagemaker_session = sagemaker_session
     with patch.object(
         Session, "_intercept_create_request", return_value=dict(ModelPackageArn="arn:aws")
@@ -996,7 +941,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session):
         model_data=lambda_step.properties.Outputs["model_artifact"],
         sagemaker_session=pipeline_session,
         entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
-        role=_ROLE,
+        role=ROLE,
     )
 
     step_create_model = ModelStep(name="mymodelstep", step_args=model.create())
@@ -1031,7 +976,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session):
         (
             Processor(
                 image_uri=_IMAGE_URI,
-                role=_ROLE,
+                role=ROLE,
                 instance_count=1,
                 instance_type=_INSTANCE_TYPE,
             ),
@@ -1052,7 +997,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session):
         (
             HyperparameterTuner(
                 estimator=Estimator(
-                    role=_ROLE,
+                    role=ROLE,
                     instance_count=1,
                     instance_type=_INSTANCE_TYPE,
                     image_uri=_IMAGE_URI,
@@ -1064,7 +1009,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session):
         ),
         (
             Estimator(
-                role=_ROLE,
+                role=ROLE,
                 instance_count=1,
                 instance_type=_INSTANCE_TYPE,
                 image_uri=_IMAGE_URI,
@@ -1128,3 +1073,31 @@ def test_pass_in_wrong_type_of_retry_policies(pipeline_session, model):
         ),
     )
     assert "SageMakerJobStepRetryPolicy is not allowed for a create/registe" in str(error.value)
+
+
+def test_register_model_step_with_model_package_name(pipeline_session):
+    model = Model(
+        name="MyModel",
+        image_uri="my-image",
+        model_data="s3://",
+        sagemaker_session=pipeline_session,
+    )
+    step_args = model.register(
+        content_types=["text/csv"],
+        response_types=["text/csv"],
+        inference_instances=["ml.t2.medium", "ml.m5.xlarge"],
+        transform_instances=["ml.m5.xlarge"],
+        model_package_name="model-pkg-name-will-be-popped-out",
+    )
+    regis_model_step = ModelStep(
+        name="MyModelStep",
+        step_args=step_args,
+    )
+    pipeline = Pipeline(
+        name="MyPipeline",
+        steps=[regis_model_step],
+        sagemaker_session=pipeline_session,
+    )
+    steps = json.loads(pipeline.definition())["Steps"]
+    assert len(steps) == 1
+    assert "ModelPackageName" not in steps[0]["Arguments"]
diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py
index 327443aee7..f0cb2e5234 100644
--- a/tests/unit/sagemaker/workflow/test_pipeline.py
+++ b/tests/unit/sagemaker/workflow/test_pipeline.py
@@ -17,7 +17,7 @@
 
 import pytest
 
-from mock import Mock
+from mock import Mock, patch
 
 from sagemaker import s3
 from sagemaker.workflow.condition_step import ConditionStep
@@ -78,6 +78,7 @@ def test_pipeline_create_with_parallelism_config(sagemaker_session_mock, role_ar
     )
 
 
+@patch("sagemaker.s3.S3Uploader.upload_string_as_file_body")
 def test_large_pipeline_create(sagemaker_session_mock, role_arn):
     parameter = ParameterString("MyStr")
     pipeline = Pipeline(
@@ -87,8 +88,6 @@ def test_large_pipeline_create(sagemaker_session_mock, role_arn):
         sagemaker_session=sagemaker_session_mock,
     )
 
-    s3.S3Uploader.upload_string_as_file_body = Mock()
-
     pipeline.create(role_arn=role_arn)
 
     assert s3.S3Uploader.upload_string_as_file_body.called_with(
@@ -151,6 +150,7 @@ def test_pipeline_update_with_parallelism_config(sagemaker_session_mock, role_ar
     )
 
 
+@patch("sagemaker.s3.S3Uploader.upload_string_as_file_body")
 def test_large_pipeline_update(sagemaker_session_mock, role_arn):
     parameter = ParameterString("MyStr")
     pipeline = Pipeline(
@@ -160,8 +160,6 @@ def test_large_pipeline_update(sagemaker_session_mock, role_arn):
         sagemaker_session=sagemaker_session_mock,
     )
 
-    s3.S3Uploader.upload_string_as_file_body = Mock()
-
     pipeline.create(role_arn=role_arn)
 
     assert s3.S3Uploader.upload_string_as_file_body.called_with(
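Two details of the `test_pipeline.py` change above are worth a sketch. Replacing the bare attribute assignment `s3.S3Uploader.upload_string_as_file_body = Mock()` with a `@patch` decorator means the real function is restored when each test returns, so the mock cannot leak into unrelated tests. Note also that the surviving `assert ...called_with(...)` lines are always-true no-ops: `.called_with` is just an auto-created child Mock, while the strict form is `assert_called_with`. A self-contained illustration (names are hypothetical stand-ins, not this patch's code):

    from unittest import mock

    class Uploader:
        @staticmethod
        def upload(body, s3_uri):  # stand-in for S3Uploader.upload_string_as_file_body
            raise RuntimeError("would hit the network")

    @mock.patch.object(Uploader, "upload")
    def test_upload_called(upload_mock):
        Uploader.upload("body", "s3://bucket/key")
        upload_mock.assert_called_with("body", "s3://bucket/key")  # raises on mismatch
        # By contrast, `assert upload_mock.called_with(...)` would always pass,
        # because .called_with is a truthy child Mock, not an assertion.

    test_upload_called()  # patch.object restores Uploader.upload afterwards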
diff --git a/tests/unit/sagemaker/workflow/test_processing_step.py b/tests/unit/sagemaker/workflow/test_processing_step.py
index fd32fd7c73..9ba242b9b2 100644
--- a/tests/unit/sagemaker/workflow/test_processing_step.py
+++ b/tests/unit/sagemaker/workflow/test_processing_step.py
@@ -13,7 +13,8 @@
 from __future__ import absolute_import
 
 import json
-from mock import Mock, PropertyMock
+import os
+from mock import Mock, PropertyMock, patch
 
 import pytest
 import warnings
@@ -45,9 +46,11 @@
 from sagemaker.workflow.steps import CacheConfig, ProcessingStep
 from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
+from sagemaker.workflow.pipeline_context import _PipelineConfig
 from sagemaker.workflow.properties import PropertyFile
 from sagemaker.workflow.parameters import ParameterString
 from sagemaker.workflow.functions import Join
+from sagemaker.workflow.utilities import hash_files_or_dirs
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.network import NetworkConfig
@@ -63,6 +66,7 @@
     SHAPConfig,
 )
 from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered, get_step_args_helper
+from tests.unit import DATA_DIR
 
 REGION = "us-west-2"
 BUCKET = "my-bucket"
@@ -70,7 +74,20 @@
 IMAGE_URI = "fakeimage"
 MODEL_NAME = "gisele"
 DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
+LOCAL_SCRIPT_PATH = os.path.join(DATA_DIR, "workflow/abalone/preprocessing.py")
+SPARK_APP_JAR_PATH = os.path.join(
+    DATA_DIR, "spark/code/java/hello-java-spark/HelloJavaSparkApp.jar"
+)
+SPARK_DEP_JAR = os.path.join(DATA_DIR, "spark/code/java/TestJarFile.jar")
+SPARK_APP_PY_PATH = os.path.join(DATA_DIR, "spark/code/python/hello_py_spark/hello_py_spark_app.py")
+SPARK_PY_FILE1 = os.path.join(DATA_DIR, "spark/code/python/hello_py_spark/__init__.py")
+SPARK_PY_FILE2 = os.path.join(DATA_DIR, "spark/code/python/hello_py_spark/hello_py_spark_udfs.py")
+SPARK_SUBMIT_FILE1 = os.path.join(DATA_DIR, "spark/files/data.jsonl")
+SPARK_SUBMIT_FILE2 = os.path.join(DATA_DIR, "spark/files/sample_spark_event_logs")
 INSTANCE_TYPE = "ml.m4.xlarge"
+MOCKED_PIPELINE_CONFIG = _PipelineConfig(
+    "MyPipeline", "MyProcessingStep", hash_files_or_dirs([LOCAL_SCRIPT_PATH]), "config-hash-abcdefg"
+)
 
 FRAMEWORK_PROCESSOR = [
     (
@@ -154,6 +171,19 @@
     ),
 ]
 
+FRAMEWORK_PROCESSOR_LOCAL_CODE = [
+    (
+        FrameworkProcessor(
+            framework_version="1.8",
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            role=ROLE,
+            estimator_cls=PyTorch,
+        ),
+        {"code": LOCAL_SCRIPT_PATH},
+    ),
+]
+
 PROCESSING_INPUT = [
     ProcessingInput(source="s3://my-bucket/processing_manifest", destination="processing_manifest"),
     ProcessingInput(
@@ -318,7 +348,8 @@ def test_processing_step_with_processor(
     else:
         expected_step_arguments["ExperimentConfig"] = expected_experiment_config
 
-    assert json.loads(pipeline.definition())["Steps"][0] == {
+    step_def = json.loads(pipeline.definition())["Steps"][0]
+    assert step_def == {
         "Name": "MyProcessingStep",
         "Description": "ProcessingStep description",
         "DisplayName": "MyProcessingStep",
@@ -346,6 +377,10 @@ def test_processing_step_with_processor(
         }
     )
 
+    # test idempotency
+    step_def2 = json.loads(pipeline.definition())["Steps"][0]
+    assert step_def == step_def2
+
 
 @pytest.mark.parametrize(
     "image_uri",
@@ -387,7 +422,11 @@ def test_processing_step_with_processor_and_step_args(
             assert isinstance(e, ValueError)
 
 
-def test_processing_step_with_script_processor(pipeline_session, processing_input, network_config):
+@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
+@pytest.mark.parametrize("code_artifact", [DUMMY_S3_SCRIPT_PATH, LOCAL_SCRIPT_PATH])
+def test_processing_step_with_script_processor(
+    pipeline_session, processing_input, network_config, code_artifact
+):
     processor = ScriptProcessor(
         role=ROLE,
         image_uri=IMAGE_URI,
@@ -406,7 +445,7 @@ def test_processing_step_with_script_processor(
     )
 
     step_args = processor.run(
-        inputs=processing_input, code=DUMMY_S3_SCRIPT_PATH, job_name="my-processing-job"
+        inputs=processing_input, code=code_artifact, job_name="my-processing-job"
     )
 
     step = ProcessingStep(
@@ -420,11 +459,13 @@ def test_processing_step_with_script_processor(
         sagemaker_session=pipeline_session,
     )
 
-    assert json.loads(pipeline.definition())["Steps"][0] == {
-        "Name": "MyProcessingStep",
-        "Type": "Processing",
-        "Arguments": get_step_args_helper(step_args, "Processing"),
-    }
+    step_args = get_step_args_helper(step_args, "Processing")
+    step_def = json.loads(pipeline.definition())["Steps"][0]
+    assert step_def == {"Name": "MyProcessingStep", "Type": "Processing", "Arguments": step_args}
+
+    # test idempotency
+    step_def2 = json.loads(pipeline.definition())["Steps"][0]
+    assert step_def == step_def2
 
 
 @pytest.mark.parametrize("framework_processor", FRAMEWORK_PROCESSOR)
@@ -477,6 +518,66 @@ def test_processing_step_with_framework_processor(
         "Arguments": step_args,
     }
 
+    # test idempotency
+    step_def2 = json.loads(pipeline.definition())["Steps"][0]
+    del step_def2["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+    del step_def2["Arguments"]["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+    assert step_def == step_def2
+
+
+@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
+@pytest.mark.parametrize("framework_processor", FRAMEWORK_PROCESSOR_LOCAL_CODE)
+def test_processing_step_with_framework_processor_local_code(
+    framework_processor, pipeline_session, network_config
+):
+    processor, run_inputs = framework_processor
+    processor.sagemaker_session = pipeline_session
+    processor.role = ROLE
+
+    processor.volume_kms_key = "volume-kms-key"
+    processor.network_config = network_config
+
+    processing_input = ProcessingInput(
+        source="s3://my-bucket/processing_manifest",
+        destination="processing_manifest",
+        input_name="manifest",
+    )
+    processing_output = ProcessingOutput(
+        output_name="framework_output", source="/opt/ml/processing/framework_output"
+    )
+
+    run_inputs["inputs"] = [processing_input]
+    run_inputs["outputs"] = [processing_output]
+
+    step_args = processor.run(**run_inputs)
+
+    step = ProcessingStep(
+        name="MyProcessingStep",
+        step_args=step_args,
+    )
+    pipeline = Pipeline(
+        name="MyPipeline",
+        steps=[step],
+        sagemaker_session=pipeline_session,
+    )
+
+    step_args = get_step_args_helper(step_args, "Processing")
+    step_def = json.loads(pipeline.definition())["Steps"][0]
+
+    del step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+    del step_def["Arguments"]["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+
+    assert step_def == {
+        "Name": "MyProcessingStep",
+        "Type": "Processing",
+        "Arguments": step_args,
+    }
+
+    # test idempotency
+    step_def2 = json.loads(pipeline.definition())["Steps"][0]
+    del step_def2["Arguments"]["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+    assert step_def == step_def2
+
 
 def test_processing_step_with_clarify_processor(pipeline_session):
     def headers():
@@ -530,12 +631,17 @@ def verify(step_args):
         steps=[step],
         sagemaker_session=pipeline_session,
     )
-    assert json.loads(pipeline.definition())["Steps"][0] == {
+    step_def = json.loads(pipeline.definition())["Steps"][0]
+    assert step_def == {
         "Name": "MyProcessingStep",
         "Type": "Processing",
         "Arguments": get_step_args_helper(step_args, "Processing"),
     }
 
+    # test idempotency
+    step_def2 = json.loads(pipeline.definition())["Steps"][0]
+    assert step_def == step_def2
+
     test_run = utils.unique_name_from_base("test_run")
     output_path = "s3://{}/{}/{}".format(
         pipeline_session.default_bucket(), "linear_learner_analysis_result", test_run
@@ -852,4 +958,153 @@ def test_spark_processor(spark_processor, processing_input, pipeline_session):
         steps=[step],
         sagemaker_session=pipeline_session,
     )
-    pipeline.definition()
+
+    # test for idempotency
+    step_def = json.loads(pipeline.definition())["Steps"][0]
+    step_def_2 = json.loads(pipeline.definition())["Steps"][0]
+    assert step_def == step_def_2
+
+
+@pytest.mark.parametrize(
+    "spark_processor",
+    [
+        (
+            SparkJarProcessor(
+                role=ROLE,
+                framework_version="2.4",
+                instance_count=1,
+                instance_type=INSTANCE_TYPE,
+            ),
+            {
+                "submit_app": SPARK_APP_JAR_PATH,
+                "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+                "arguments": [
+                    "--input",
+                    "input-data-uri",
+                    "--output",
+                    ParameterString("MyArgOutput"),
+                ],
+                "submit_jars": [
+                    SPARK_DEP_JAR,
+                ],
+                "submit_files": [
+                    SPARK_SUBMIT_FILE1,
+                    SPARK_SUBMIT_FILE2,
+                ],
+                "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri"),
+                "configuration": {
+                    "Classification": "core-site",
+                    "Properties": {"hadoop.security.groups.cache.secs": "250"},
+                },
+            },
+        ),
+        (
+            PySparkProcessor(
+                role=ROLE,
+                framework_version="2.4",
+                instance_count=1,
+                instance_type=INSTANCE_TYPE,
+            ),
+            {
+                "submit_app": SPARK_APP_PY_PATH,
+                "arguments": [
+                    "--input",
+                    "input-data-uri",
+                    "--output",
+                    ParameterString("MyArgOutput"),
+                ],
+                "submit_py_files": [
+                    SPARK_PY_FILE1,
+                    SPARK_PY_FILE2,
+                ],
+                "submit_jars": [SPARK_DEP_JAR],
+                "submit_files": [SPARK_SUBMIT_FILE1, SPARK_SUBMIT_FILE2],
+                "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri"),
+                "configuration": {
+                    "Classification": "core-site",
+                    "Properties": {"hadoop.security.groups.cache.secs": "250"},
+                },
+            },
+        ),
+    ],
+)
+def test_spark_processor_local_code(spark_processor, processing_input, pipeline_session):
+    processor, run_inputs = spark_processor
+    processor.sagemaker_session = pipeline_session
+    processor.role = ROLE
+
+    run_inputs["inputs"] = processing_input
+
+    step_args = processor.run(**run_inputs)
+    step = ProcessingStep(
+        name="MyProcessingStep",
+        step_args=step_args,
+    )
+
+    step_args = get_step_args_helper(step_args, "Processing")
+
+    assert step_args["AppSpecification"]["ContainerArguments"] == run_inputs["arguments"]
+
+    entry_points = step_args["AppSpecification"]["ContainerEntrypoint"]
+    entry_points_expr = []
+    for entry_point in entry_points:
+        if is_pipeline_variable(entry_point):
+            entry_points_expr.append(entry_point.expr)
+        else:
+            entry_points_expr.append(entry_point)
+
+    if "submit_py_files" in run_inputs:
+        # PySpark app: submitted via --py-files with a .py entrypoint
+        expected = [
+            "smspark-submit",
+            "--py-files",
+            "/opt/ml/processing/input/py-files",
+            "--jars",
+            "/opt/ml/processing/input/jars",
+            "--files",
+            "/opt/ml/processing/input/files",
+            "--local-spark-event-logs-dir",
+            "/opt/ml/processing/spark-events/",
+            "/opt/ml/processing/input/code/hello_py_spark_app.py",
+        ]
+    else:
+        # Spark JAR app: submitted via --class with a .jar entrypoint
+        expected = [
+            "smspark-submit",
+            "--class",
+            "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+            "--jars",
+            "/opt/ml/processing/input/jars",
+            "--files",
+            "/opt/ml/processing/input/files",
+            "--local-spark-event-logs-dir",
+            "/opt/ml/processing/spark-events/",
+            "/opt/ml/processing/input/code/HelloJavaSparkApp.jar",
+        ]
+
+    assert entry_points_expr == expected
+    for output in step_args["ProcessingOutputConfig"]["Outputs"]:
+        if is_pipeline_variable(output["S3Output"]["S3Uri"]):
+            output["S3Output"]["S3Uri"] = output["S3Output"]["S3Uri"].expr
+
+    assert step_args["ProcessingOutputConfig"]["Outputs"] == [
+        {
+            "OutputName": "output-1",
+            "AppManaged": False,
+            "S3Output": {
+                "S3Uri": {"Get": "Parameters.MySparkEventLogS3Uri"},
+                "LocalPath": "/opt/ml/processing/spark-events/",
+                "S3UploadMode": "Continuous",
+            },
+        }
+    ]
+
+    pipeline = Pipeline(
+        name="MyPipeline",
+        steps=[step],
+        sagemaker_session=pipeline_session,
+    )
+
+    # test for idempotency
+    step_def = json.loads(pipeline.definition())["Steps"][0]
+    step_def2 = json.loads(pipeline.definition())["Steps"][0]
+    assert step_def == step_def2
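Every `# test idempotency` block added in this file follows the same recipe: render `pipeline.definition()` twice and require identical step dicts, deleting only the fields that still embed a timestamped job name. A condensed sketch of that recipe (the helper name is illustrative, not part of this patch):

    import json

    def assert_step_def_idempotent(pipeline, volatile_paths=()):
        # Render the definition twice; equality means the generated S3 code
        # keys come from stable content hashes, not per-call timestamps.
        def render():
            step = json.loads(pipeline.definition())["Steps"][0]
            for path in volatile_paths:
                node = step["Arguments"]
                for key in path[:-1]:
                    node = node[key]
                node.pop(path[-1], None)  # drop a still-timestamped field
            return step

        assert render() == render()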
diff --git a/tests/unit/sagemaker/workflow/test_quality_check_step.py b/tests/unit/sagemaker/workflow/test_quality_check_step.py
index b60e2de8fa..dc104d71df 100644
--- a/tests/unit/sagemaker/workflow/test_quality_check_step.py
+++ b/tests/unit/sagemaker/workflow/test_quality_check_step.py
@@ -15,10 +15,6 @@
 import json
 
 import pytest
-import sagemaker
-
-from mock import Mock, PropertyMock
-
 from sagemaker.model_monitor import DatasetFormat
 from sagemaker.workflow.parameters import ParameterString
 from sagemaker.workflow.pipeline import Pipeline
@@ -31,49 +27,7 @@
 from sagemaker.workflow.steps import CacheConfig
 from sagemaker.workflow.check_job_config import CheckJobConfig
 
-_REGION = "us-west-2"
 _ROLE = "DummyRole"
-_BUCKET = "my-bucket"
-
-
-@pytest.fixture
-def boto_session():
-    role_mock = Mock()
-    type(role_mock).arn = PropertyMock(return_value=_ROLE)
-
-    resource_mock = Mock()
-    resource_mock.Role.return_value = role_mock
-
-    session_mock = Mock(region_name=_REGION)
-    session_mock.resource.return_value = resource_mock
-
-    return session_mock
-
-
-@pytest.fixture
-def client():
-    """Mock client.
-
-    Considerations when appropriate:
-
-        * utilize botocore.stub.Stubber
-        * separate runtime client from client
-    """
-    client_mock = Mock()
-    client_mock._client_config.user_agent = (
-        "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
-    )
-    return client_mock
-
-
-@pytest.fixture
-def sagemaker_session(boto_session, client):
-    return sagemaker.session.Session(
-        boto_session=boto_session,
-        sagemaker_client=client,
-        sagemaker_runtime_client=client,
-        default_bucket=_BUCKET,
-    )
 
 
 _expected_data_quality_dsl = {
diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py
index 2bf47a79d0..95738c99ca 100644
--- a/tests/unit/sagemaker/workflow/test_step_collections.py
+++ b/tests/unit/sagemaker/workflow/test_step_collections.py
@@ -796,6 +796,7 @@ def test_register_model_with_model_repack_with_estimator(
                     "CollectionConfigurations": [],
                     "S3OutputPath": f"s3://{BUCKET}/",
                 },
+                "ProfilerConfig": {"DisableProfiler": True},
                 "HyperParameters": {
                     "inference_script": '"dummy_script.py"',
                     "dependencies": f'"{dummy_requirements}"',
@@ -923,6 +924,7 @@ def test_register_model_with_model_repack_with_model(model, model_metrics, drift
                     "CollectionConfigurations": [],
                     "S3OutputPath": f"s3://{BUCKET}/",
                 },
+                "ProfilerConfig": {"DisableProfiler": True},
                 "HyperParameters": {
                     "inference_script": '"dummy_script.py"',
                     "model_archive": '"s3://my-bucket/model.tar.gz"',
@@ -1052,6 +1054,7 @@ def test_register_model_with_model_repack_with_pipeline_model(
                     "CollectionConfigurations": [],
                     "S3OutputPath": f"s3://{BUCKET}/",
                 },
+                "ProfilerConfig": {"DisableProfiler": True},
                 "HyperParameters": {
                     "dependencies": "null",
                     "inference_script": '"dummy_script.py"',
@@ -1243,6 +1246,7 @@ def test_estimator_transformer_with_model_repack_with_estimator(estimator):
                     "TrainingImage": "246618743249.dkr.ecr.us-west-2.amazonaws.com/"
                     + "sagemaker-scikit-learn:0.23-1-cpu-py3",
                 },
+                "ProfilerConfig": {"DisableProfiler": True},
                 "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"},
                 "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
                 "ResourceConfig": {
diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py
index 9887d43078..f2046cc00f 100644
--- a/tests/unit/sagemaker/workflow/test_steps.py
+++ b/tests/unit/sagemaker/workflow/test_steps.py
@@ -16,15 +16,10 @@
 import json
 
 import pytest
-import sagemaker
 import os
 import warnings
 
-from mock import (
-    Mock,
-    PropertyMock,
-    patch,
-)
+from mock import patch
 
 from sagemaker.debugger import ProfilerConfig
 from sagemaker.estimator import Estimator
@@ -94,46 +89,6 @@ def create_predictor(self, endpoint_name):
         return Predictor(endpoint_name, self.sagemaker_session)
 
 
-@pytest.fixture
-def boto_session():
-    role_mock = Mock()
-    type(role_mock).arn = PropertyMock(return_value=ROLE)
-
-    resource_mock = Mock()
-    resource_mock.Role.return_value = role_mock
-
-    session_mock = Mock(region_name=REGION)
-    session_mock.resource.return_value = resource_mock
-
-    return session_mock
-
-
-@pytest.fixture
-def client():
-    """Mock client.
-
-    Considerations when appropriate:
-
-        * utilize botocore.stub.Stubber
-        * separate runtime client from client
-    """
-    client_mock = Mock()
-    client_mock._client_config.user_agent = (
-        "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
-    )
-    return client_mock
-
-
-@pytest.fixture
-def sagemaker_session(boto_session, client):
-    return sagemaker.session.Session(
-        boto_session=boto_session,
-        sagemaker_client=client,
-        sagemaker_runtime_client=client,
-        default_bucket=BUCKET,
-    )
-
-
 @pytest.fixture
 def script_processor(sagemaker_session):
     return ScriptProcessor(
@@ -374,6 +329,7 @@ def test_training_step_base_estimator(sagemaker_session):
             "CollectionConfigurations": [],
         },
         "ProfilerConfig": {
+            "DisableProfiler": False,
             "ProfilingIntervalInMilliseconds": 500,
             "S3OutputPath": {"Std:Join": {"On": "/", "Values": ["s3:/", "a", "b"]}},
         },
@@ -483,7 +439,7 @@ def test_training_step_tensorflow(sagemaker_session):
                 "sagemaker_instance_type": {"Get": "Parameters.InstanceType"},
                 "sagemaker_distributed_dataparallel_custom_mpi_options": '""',
             },
-            "ProfilerConfig": {"S3OutputPath": "s3://my-bucket/"},
+            "ProfilerConfig": {"DisableProfiler": False, "S3OutputPath": "s3://my-bucket/"},
         },
         "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
     }
diff --git a/tests/unit/sagemaker/workflow/test_training_step.py b/tests/unit/sagemaker/workflow/test_training_step.py
index 4133343c93..7f8e6b0c62 100644
--- a/tests/unit/sagemaker/workflow/test_training_step.py
+++ b/tests/unit/sagemaker/workflow/test_training_step.py
@@ -14,7 +14,7 @@
 import os
 import json
 
-from mock import Mock, PropertyMock
+from mock import Mock, PropertyMock, patch
 
 import pytest
 import warnings
@@ -25,11 +25,12 @@
 from sagemaker.parameter import IntegerParameter
 from sagemaker.transformer import Transformer
 from sagemaker.tuner import HyperparameterTuner
-from sagemaker.workflow.pipeline_context import PipelineSession
+from sagemaker.workflow.pipeline_context import PipelineSession, _PipelineConfig
 
 from sagemaker.workflow.parameters import ParameterString, ParameterBoolean
 from sagemaker.workflow.steps import TrainingStep
 from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
+from sagemaker.workflow.utilities import hash_files_or_dirs
 from sagemaker.workflow.functions import Join
 
 from sagemaker.estimator import Estimator
@@ -66,9 +67,19 @@
 ROLE = "DummyRole"
 IMAGE_URI = "fakeimage"
 MODEL_NAME = "gisele"
-DUMMY_LOCAL_SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
+LOCAL_ENTRY_POINT = os.path.join(DATA_DIR, "tfs/tfs-test-entrypoint-with-handler/training.py")
+LOCAL_SOURCE_DIR = os.path.join(DATA_DIR, "tfs/tfs-test-entrypoint-with-handler")
+LOCAL_DEPS = [
+    os.path.join(DATA_DIR, "tfs/tfs-test-entrypoint-and-dependencies"),
+]
 DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
 INSTANCE_TYPE = "ml.m4.xlarge"
+MOCKED_PIPELINE_CONFIG = _PipelineConfig(
+    "MyPipeline",
+    "MyTrainingStep",
+    hash_files_or_dirs([LOCAL_SOURCE_DIR] + LOCAL_DEPS),
+    "config-hash-abcdefg",
+)
 
 ESTIMATOR_LISTS = [
     SKLearn(
@@ -152,6 +163,88 @@
     ),
 ]
 
+ESTIMATOR_LISTS_LOCAL_CODE = [
+    SKLearn(
+        framework_version="0.23-1",
+        py_version="py3",
+        instance_type=INSTANCE_TYPE,
+        instance_count=1,
+        role=ROLE,
+        entry_point=LOCAL_ENTRY_POINT,
+        source_dir=LOCAL_SOURCE_DIR,
+    ),
+    PyTorch(
+        role=ROLE,
+        instance_type=INSTANCE_TYPE,
+        instance_count=1,
+        framework_version="1.8.0",
+        py_version="py36",
+        entry_point=LOCAL_ENTRY_POINT,
+        source_dir=LOCAL_SOURCE_DIR,
+    ),
+    TensorFlow(
+        role=ROLE,
+        entry_point=LOCAL_ENTRY_POINT,
+        source_dir=LOCAL_SOURCE_DIR,
+        instance_type=INSTANCE_TYPE,
+        instance_count=1,
+        framework_version="2.0",
+        py_version="py3",
+    ),
+    HuggingFace(
+        transformers_version="4.6",
+        pytorch_version="1.7",
+        role=ROLE,
+        instance_type="ml.p3.2xlarge",
+        instance_count=1,
+        py_version="py36",
+        entry_point=LOCAL_ENTRY_POINT,
+        source_dir=LOCAL_SOURCE_DIR,
+    ),
+    XGBoost(
+        framework_version="1.3-1",
+        py_version="py3",
+        role=ROLE,
+        instance_type=INSTANCE_TYPE,
+        instance_count=1,
+        entry_point=LOCAL_ENTRY_POINT,
+        source_dir=LOCAL_SOURCE_DIR,
+    ),
+    MXNet(
+        framework_version="1.4.1",
+        py_version="py3",
+        role=ROLE,
+        instance_type=INSTANCE_TYPE,
+        instance_count=1,
+        entry_point=LOCAL_ENTRY_POINT,
+        source_dir=LOCAL_SOURCE_DIR,
+        toolkit=RLToolkit.RAY,
+        framework=RLFramework.TENSORFLOW,
+        toolkit_version="0.8.5",
+    ),
+    RLEstimator(
+        entry_point=LOCAL_ENTRY_POINT,
+        source_dir=LOCAL_SOURCE_DIR,
+        toolkit=RLToolkit.RAY,
+        framework=RLFramework.TENSORFLOW,
+        toolkit_version="0.8.5",
+        role=ROLE,
+        instance_type=INSTANCE_TYPE,
+        instance_count=1,
+    ),
+    Chainer(
+        role=ROLE,
+        entry_point=LOCAL_ENTRY_POINT,
+        source_dir=LOCAL_SOURCE_DIR,
+        use_mpi=True,
+        num_processes=4,
+        framework_version="5.0.0",
+        instance_type=INSTANCE_TYPE,
+        instance_count=1,
+        py_version="py3",
+    ),
+]
+
 
 INPUT_PARAM_LISTS = [
     "s3://my-bucket/my-training-input",
@@ -209,6 +302,7 @@ def hyperparameters():
     return {"test-key": "test-val"}
 
 
+@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
 @pytest.mark.parametrize(
     "experiment_config, expected_experiment_config",
     [
@@ -250,6 +344,9 @@ def test_training_step_with_estimator(
         hyperparameters=hyperparameters,
         enable_network_isolation=enable_network_isolation,
         encrypt_inter_container_traffic=encrypt_container_traffic,
+        entry_point=LOCAL_ENTRY_POINT,
+        source_dir=LOCAL_SOURCE_DIR,
+        dependencies=LOCAL_DEPS,
     )
 
     with warnings.catch_warnings(record=True) as w:
@@ -289,35 +386,28 @@ def test_training_step_with_estimator(
         sagemaker_session=pipeline_session,
     )
     step_args = get_step_args_helper(step_args, "Training")
-    expected_step_arguments = deepcopy(step_args)
-    expected_step_arguments["EnableInterContainerTrafficEncryption"] = {
+    step_args["EnableInterContainerTrafficEncryption"] = {
         "Get": "Parameters.encrypt_container_traffic"
     }
-    expected_step_arguments["EnableNetworkIsolation"] = {
-        "Get": "Parameters.enable_network_isolation"
-    }
+    step_args["EnableNetworkIsolation"] = {"Get": "Parameters.enable_network_isolation"}
 
     if expected_experiment_config is None:
-        expected_step_arguments.pop("ExperimentConfig", None)
+        step_args.pop("ExperimentConfig", None)
     else:
-        expected_step_arguments["ExperimentConfig"] = expected_experiment_config
+        step_args["ExperimentConfig"] = expected_experiment_config
 
     assert step_condition.conditions[0].left.expr == {
         "Get": "Steps.MyTrainingStep.FinalMetricDataList['val:acc'].Value"
    }
 
     step_definition = json.loads(pipeline.definition())["Steps"][0]
 
-    # delete profiler rule configurations because of timestamp collision
-    del step_definition["Arguments"]["ProfilerRuleConfigurations"]
-    del expected_step_arguments["ProfilerRuleConfigurations"]
-
     assert step_definition == {
         "Name": "MyTrainingStep",
         "Description": "TrainingStep description",
         "DisplayName": "MyTrainingStep",
         "Type": "Training",
         "DependsOn": ["TestStep"],
-        "Arguments": expected_step_arguments,
+        "Arguments": step_args,
     }
     assert step_train.properties.TrainingJobName.expr == {
         "Get": "Steps.MyTrainingStep.TrainingJobName"
     }
@@ -332,6 +422,10 @@ def test_training_step_with_estimator(
         }
     )
 
+    # test idempotency
+    step_def2 = json.loads(pipeline.definition())["Steps"][0]
+    assert step_definition == step_def2
+
 
 def test_training_step_estimator_with_param_code_input(
     pipeline_session, training_input, hyperparameters
@@ -374,7 +468,8 @@ def test_training_step_estimator_with_param_code_input(
     step_args = get_step_args_helper(step_args, "Training")
     step_args["HyperParameters"]["sagemaker_program"] = {"Get": "Parameters.EntryPoint"}
     step_args["HyperParameters"]["sagemaker_submit_directory"] = {"Get": "Parameters.SourceDir"}
-    assert json.loads(pipeline.definition())["Steps"][0] == {
+    step_def = json.loads(pipeline.definition())["Steps"][0]
+    assert step_def == {
         "Name": "MyTrainingStep",
         "Description": "TrainingStep description",
         "DisplayName": "MyTrainingStep",
@@ -382,6 +477,10 @@ def test_training_step_estimator_with_param_code_input(
         "Arguments": step_args,
     }
 
+    # test idempotency
+    step_def2 = json.loads(pipeline.definition())["Steps"][0]
+    assert step_def == step_def2
+
 
 @pytest.mark.parametrize("estimator", ESTIMATOR_LISTS)
 @pytest.mark.parametrize("training_input", INPUT_PARAM_LISTS)
@@ -414,41 +513,109 @@ def test_training_step_with_framework_estimator(
     )
 
     step_args = get_step_args_helper(step_args, "Training")
+    expected_step_args = deepcopy(step_args)
     step_def = json.loads(pipeline.definition())["Steps"][0]
 
-    assert step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] == training_input
-    assert step_args["OutputDataConfig"]["S3OutputPath"] == output_path
-    step_args["HyperParameters"]["sagemaker_program"] = {"Get": "Parameters.EntryPoint"}
-    step_args["HyperParameters"]["sagemaker_submit_directory"] = {"Get": "Parameters.SourceDir"}
+    assert (
+        expected_step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+        == training_input
+    )
+    assert expected_step_args["OutputDataConfig"]["S3OutputPath"] == output_path
+    expected_step_args["HyperParameters"]["sagemaker_program"] = {"Get": "Parameters.EntryPoint"}
+    expected_step_args["HyperParameters"]["sagemaker_submit_directory"] = {
+        "Get": "Parameters.SourceDir"
+    }
 
-    del step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+    del expected_step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
     del step_def["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
 
-    del step_args["OutputDataConfig"]["S3OutputPath"]
+    del expected_step_args["OutputDataConfig"]["S3OutputPath"]
     del step_def["Arguments"]["OutputDataConfig"]["S3OutputPath"]
 
-    # trim timestamp so RuleConfigurationName will match
-    rule_config_name_step_args = step_args["ProfilerRuleConfigurations"][0]["RuleConfigurationName"]
-    step_args["ProfilerRuleConfigurations"][0][
-        "RuleConfigurationName"
-    ] = rule_config_name_step_args[:-11]
-    rule_config_name_step_def = step_def["Arguments"]["ProfilerRuleConfigurations"][0][
-        "RuleConfigurationName"
-    ]
-    step_def["Arguments"]["ProfilerRuleConfigurations"][0][
-        "RuleConfigurationName"
-    ] = rule_config_name_step_def[:-11]
+    if "sagemaker_s3_output" in step_args["HyperParameters"]:
+        del expected_step_args["HyperParameters"]["sagemaker_s3_output"]
+        del step_def["Arguments"]["HyperParameters"]["sagemaker_s3_output"]
+
+    assert step_def == {
+        "Name": "MyTrainingStep",
+        "Type": "Training",
+        "Arguments": expected_step_args,
+    }
+
+    # test idempotency
+    step_def2 = json.loads(pipeline.definition())["Steps"][0]
+    del step_def2["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+    del step_def2["Arguments"]["OutputDataConfig"]["S3OutputPath"]
+    if "sagemaker_s3_output" in step_def2["Arguments"]["HyperParameters"]:
+        del step_def2["Arguments"]["HyperParameters"]["sagemaker_s3_output"]
+    assert step_def == step_def2
+
+
+@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
+@pytest.mark.parametrize("estimator", ESTIMATOR_LISTS_LOCAL_CODE)
+@pytest.mark.parametrize("training_input", INPUT_PARAM_LISTS)
+@pytest.mark.parametrize(
+    "output_path", ["s3://my-bucket/my-output-path", ParameterString(name="OutputPath")]
+)
+def test_training_step_with_framework_estimator_local_code(
+    estimator, pipeline_session, training_input, output_path, hyperparameters
+):
+    estimator.set_hyperparameters(**hyperparameters)
+    estimator.volume_kms_key = "volume-kms-key"
+    estimator.output_kms_key = "output-kms-key"
+    estimator.dependencies = LOCAL_DEPS
+    estimator.output_path = output_path
+    # TODO: remove job_name once we merge
+    # https://github.com/aws/sagemaker-python-sdk/pull/3158/files
+    estimator.base_job_name = "TestJob"
+
+    estimator.sagemaker_session = pipeline_session
+    step_args = estimator.fit(inputs=TrainingInput(s3_data=training_input))
+
+    step = TrainingStep(
+        name="MyTrainingStep",
+        step_args=step_args,
+    )
+    pipeline = Pipeline(
+        name="MyPipeline",
+        steps=[step],
+        sagemaker_session=pipeline_session,
+    )
+
+    step_args = get_step_args_helper(step_args, "Training")
+    expected_step_args = deepcopy(step_args)
+    step_def = json.loads(pipeline.definition())["Steps"][0]
+
+    assert (
+        expected_step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+        == training_input
+    )
+    assert expected_step_args["OutputDataConfig"]["S3OutputPath"] == output_path
+
+    del expected_step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+    del step_def["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+
+    del expected_step_args["OutputDataConfig"]["S3OutputPath"]
+    del step_def["Arguments"]["OutputDataConfig"]["S3OutputPath"]
 
     if "sagemaker_s3_output" in step_args["HyperParameters"]:
-        del step_args["HyperParameters"]["sagemaker_s3_output"]
+        del expected_step_args["HyperParameters"]["sagemaker_s3_output"]
         del step_def["Arguments"]["HyperParameters"]["sagemaker_s3_output"]
 
     assert step_def == {
         "Name": "MyTrainingStep",
         "Type": "Training",
-        "Arguments": step_args,
+        "Arguments": expected_step_args,
     }
 
+    # test idempotency
+    step_def2 = json.loads(pipeline.definition())["Steps"][0]
+    del step_def2["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+    del step_def2["Arguments"]["OutputDataConfig"]["S3OutputPath"]
+    if "sagemaker_s3_output" in step_def2["Arguments"]["HyperParameters"]:
+        del step_def2["Arguments"]["HyperParameters"]["sagemaker_s3_output"]
+    assert step_def == step_def2
+
 
 @pytest.mark.parametrize(
     "algo_estimator",
@@ -519,17 +686,88 @@ def test_training_step_with_algorithm_base(algo_estimator, training_input, pipel
     del step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
     del step_def["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
 
-    # trim timestamp so RuleConfigurationName will match
-    rule_config_name_step_args = step_args["ProfilerRuleConfigurations"][0]["RuleConfigurationName"]
-    step_args["ProfilerRuleConfigurations"][0][
-        "RuleConfigurationName"
-    ] = rule_config_name_step_args[:-11]
-    rule_config_name_step_def = step_def["Arguments"]["ProfilerRuleConfigurations"][0][
-        "RuleConfigurationName"
-    ]
-    step_def["Arguments"]["ProfilerRuleConfigurations"][0][
-        "RuleConfigurationName"
-    ] = rule_config_name_step_def[:-11]
+    assert step_def == {
+        "Name": "MyTrainingStep",
+        "Type": "Training",
+        "Arguments": step_args,
+    }
+
+    # test idempotency
+    step_def2 = json.loads(pipeline.definition())["Steps"][0]
+    del step_def2["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+    assert step_def == step_def2
+
+
+@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
+@pytest.mark.parametrize(
+    "algo_estimator",
+    [
+        KNN,
+        KMeans,
+        LinearLearner,
+        RandomCutForest,
+        LDA,
+        Object2Vec,
+        NTM,
+        PCA,
+        FactorizationMachines,
+        IPInsights,
+    ],
+)
+@pytest.mark.parametrize(
+    "training_input",
+    INPUT_PARAM_LISTS,
+)
+def test_training_step_with_algorithm_base_local_code(
+    algo_estimator, training_input, pipeline_session
+):
+    estimator = algo_estimator(
+        role=ROLE,
+        instance_type=INSTANCE_TYPE,
+        instance_count=1,
+        sagemaker_session=pipeline_session,
+        entry_point=LOCAL_ENTRY_POINT,
+        source_dir=LOCAL_SOURCE_DIR,
+        dependencies=LOCAL_DEPS,
+        # TODO: remove job_name once we merge
+        # https://github.com/aws/sagemaker-python-sdk/pull/3158/files
+        base_job_name="TestJob",
+    )
+    data = RecordSet(
+        s3_data=training_input,
+        num_records=1000,
+        feature_dim=128,
+        channel="train",
+    )
+
+    with warnings.catch_warnings(record=True) as w:
+        step_args = estimator.fit(
+            records=data,
+            mini_batch_size=1000,
+        )
+        assert len(w) == 1
+        assert issubclass(w[-1].category, UserWarning)
+        assert "Running within a PipelineSession" in str(w[-1].message)
+
+    with warnings.catch_warnings(record=True) as w:
+        step = TrainingStep(
+            name="MyTrainingStep",
+            step_args=step_args,
+        )
+        assert len(w) == 0
+
+    pipeline = Pipeline(
+        name="MyPipeline",
+        steps=[step],
+        sagemaker_session=pipeline_session,
+    )
+
+    step_args = get_step_args_helper(step_args, "Training")
+
+    step_def = json.loads(pipeline.definition())["Steps"][0]
+    assert step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] == training_input
+    del step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+    del step_def["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
 
     assert step_def == {
         "Name": "MyTrainingStep",
@@ -537,6 +775,11 @@ def test_training_step_with_algorithm_base(algo_estimator, training_input, pipel
         "Arguments": step_args,
     }
 
+    # test idempotency
+    step_def2 = json.loads(pipeline.definition())["Steps"][0]
+    del step_def2["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+    assert step_def == step_def2
+
 
 @pytest.mark.parametrize(
     "inputs",
diff --git a/tests/unit/sagemaker/workflow/test_transform_step.py b/tests/unit/sagemaker/workflow/test_transform_step.py
index 5699f13538..ffc901bf5c 100644
--- a/tests/unit/sagemaker/workflow/test_transform_step.py
+++ b/tests/unit/sagemaker/workflow/test_transform_step.py
@@ -176,6 +176,10 @@ def test_transform_step_with_transformer(model_name, data, output_path, pipeline
         "Arguments": expected_step_arguments,
     }
 
+    # test idempotency
+    step_def2 = json.loads(pipeline.definition())["Steps"][0]
+    assert step_def == step_def2
+
 
 @pytest.mark.parametrize(
     "experiment_config, expected_experiment_config",
@@ -261,6 +265,10 @@ def test_transform_step_with_transformer_experiment_config(
     adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list
     assert adjacency_list == {"MyTransformStep": []}
 
+    # test idempotency
+    step_def2 = json.loads(pipeline.definition())["Steps"][0]
+    assert step_def == step_def2
+
 
 @pytest.mark.parametrize(
     "inputs",
diff --git a/tests/unit/sagemaker/workflow/test_tuning_step.py b/tests/unit/sagemaker/workflow/test_tuning_step.py
index 9c7b764c3b..6c022bb255 100644
--- a/tests/unit/sagemaker/workflow/test_tuning_step.py
+++ b/tests/unit/sagemaker/workflow/test_tuning_step.py
@@ -159,22 +159,14 @@ def test_tuning_step_with_single_algo_tuner(pipeline_session, training_input, en
         "S3DataSource"
     ]["S3Uri"]
 
-    # trim timestamp so sagemaker_job_name will still match
-    step_args_sm_job_name = step_args["TrainingJobDefinition"]["StaticHyperParameters"][
-        "sagemaker_job_name"
-    ]
-    step_args["TrainingJobDefinition"]["StaticHyperParameters"][
-        "sagemaker_job_name"
-    ] = step_args_sm_job_name[:-24]
-    step_def_sm_job_name = step_def["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][
+    # delete sagemaker_job_name b/c of timestamp collision
+    del step_args["TrainingJobDefinition"]["StaticHyperParameters"]["sagemaker_job_name"]
+    del step_def["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][
         "sagemaker_job_name"
     ]
-    step_def["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][
-        "sagemaker_job_name"
-    ] = step_def_sm_job_name[:-24]
 
-    # delete S3 path assertions for now because job name is included with timestamp. These will be re-enabled once
-    # next PRs are submitted with s3 path updates, removing the job name.
+    # delete S3 path assertions for now because job name is included with timestamp. These will be re-enabled after
+    # caching improvements phase 2.
     del step_args["TrainingJobDefinition"]["StaticHyperParameters"]["sagemaker_submit_directory"]
     del step_def["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][
         "sagemaker_submit_directory"
@@ -188,6 +180,16 @@ def test_tuning_step_with_single_algo_tuner(pipeline_session, training_input, en
     adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list
     assert adjacency_list == {"MyTuningStep": []}
 
+    # test idempotency
+    step_def2 = json.loads(pipeline.definition())["Steps"][0]
+    del step_def2["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][
+        "sagemaker_job_name"
+    ]
+    del step_def2["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][
+        "sagemaker_submit_directory"
+    ]
+    assert step_def == step_def2
+
 
 def test_tuning_step_with_multi_algo_tuner(pipeline_session, entry_point):
     pytorch_estimator = PyTorch(
@@ -249,20 +251,14 @@ def test_tuning_step_with_multi_algo_tuner(pipeline_session, entry_point):
     step_def = json.loads(pipeline.definition())["Steps"][0]
 
     for i, step in enumerate(step_args["TrainingJobDefinitions"]):
-        # trim timestamp so sagemaker_job_name will still match
-        step_args_sm_job_name = step["StaticHyperParameters"]["sagemaker_job_name"]
-        step_args["TrainingJobDefinitions"][i]["StaticHyperParameters"][
-            "sagemaker_job_name"
-        ] = step_args_sm_job_name[:-24]
-        step_def_sm_job_name = step_def["Arguments"]["TrainingJobDefinitions"][i][
-            "StaticHyperParameters"
-        ]["sagemaker_job_name"]
-        step_def["Arguments"]["TrainingJobDefinitions"][i]["StaticHyperParameters"][
+        # delete sagemaker_job_name b/c of timestamp collision
+        del step_args["TrainingJobDefinitions"][i]["StaticHyperParameters"]["sagemaker_job_name"]
+        del step_def["Arguments"]["TrainingJobDefinitions"][i]["StaticHyperParameters"][
             "sagemaker_job_name"
-        ] = step_def_sm_job_name[:-24]
+        ]
 
-        # delete S3 path assertions for now because job name is included with timestamp. These will be re-enabled once
-        # next PRs are submitted with s3 path updates, removing the job name.
+        # delete S3 path assertions for now because job name is included with timestamp. These will be re-enabled after
+        # caching improvements phase 2.
         del step_args["TrainingJobDefinitions"][i]["StaticHyperParameters"][
             "sagemaker_submit_directory"
         ]
@@ -278,6 +274,18 @@ def test_tuning_step_with_multi_algo_tuner(pipeline_session, entry_point):
     adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list
     assert adjacency_list == {"MyTuningStep": []}
 
+    # test idempotency
+    step_def2 = json.loads(pipeline.definition())["Steps"][0]
+    for i, step in enumerate(step_def2["Arguments"]["TrainingJobDefinitions"]):
+        del step_def2["Arguments"]["TrainingJobDefinitions"][i]["StaticHyperParameters"][
+            "sagemaker_job_name"
+        ]
+
+        del step_def2["Arguments"]["TrainingJobDefinitions"][i]["StaticHyperParameters"][
+            "sagemaker_submit_directory"
+        ]
+    assert step_def == step_def2
+
 
 @pytest.mark.parametrize(
     "inputs",
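The tuning-step tests above stop trimming a fixed-width suffix (`[:-24]`) and instead delete `sagemaker_job_name` outright, because the job name is produced by `name_from_base` and carries a millisecond timestamp, so two renders can never agree on it. A small illustration (`sagemaker.utils.name_from_base` is the real helper; the assertion is a sketch):

    import time
    from sagemaker.utils import name_from_base

    first = name_from_base("TestJob")   # e.g. "TestJob-2022-12-19-12-00-01-123"
    time.sleep(0.002)
    second = name_from_base("TestJob")
    assert first != second  # the timestamp suffix differs on every call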
diff --git a/tests/unit/sagemaker/workflow/test_utils.py b/tests/unit/sagemaker/workflow/test_utils.py
index dcbf5a6421..d1b81f3148 100644
--- a/tests/unit/sagemaker/workflow/test_utils.py
+++ b/tests/unit/sagemaker/workflow/test_utils.py
@@ -18,12 +18,6 @@
 import tempfile
 
 import pytest
-import sagemaker
-
-from mock import (
-    Mock,
-    PropertyMock,
-)
 
 from sagemaker.estimator import Estimator
 from sagemaker.workflow._utils import (
@@ -35,51 +29,7 @@
 from sagemaker.workflow.properties import Properties
 from tests.unit.test_utils import FakeS3, list_tar_files
 from tests.unit import DATA_DIR
-
-REGION = "us-west-2"
-BUCKET = "my-bucket"
-IMAGE_URI = "fakeimage"
-ROLE = "DummyRole"
-
-
-@pytest.fixture
-def boto_session():
-    role_mock = Mock()
-    type(role_mock).arn = PropertyMock(return_value=ROLE)
-
-    resource_mock = Mock()
-    resource_mock.Role.return_value = role_mock
-
-    session_mock = Mock(region_name=REGION)
-    session_mock.resource.return_value = resource_mock
-
-    return session_mock
-
-
-@pytest.fixture
-def client():
-    """Mock client.
-
-    Considerations when appropriate:
-
-        * utilize botocore.stub.Stubber
-        * separate runtime client from client
-    """
-    client_mock = Mock()
-    client_mock._client_config.user_agent = (
-        "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
-    )
-    return client_mock
-
-
-@pytest.fixture
-def sagemaker_session(boto_session, client):
-    return sagemaker.session.Session(
-        boto_session=boto_session,
-        sagemaker_client=client,
-        sagemaker_runtime_client=client,
-        default_bucket=BUCKET,
-    )
+from tests.unit.sagemaker.workflow.conftest import ROLE, IMAGE_URI, BUCKET
 
 
 @pytest.fixture
@@ -157,6 +107,7 @@ def test_repack_model_step(estimator):
             }
         ],
         "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
+        "ProfilerConfig": {"DisableProfiler": True},
         "ResourceConfig": {
             "InstanceCount": 1,
             "InstanceType": "ml.m5.large",
@@ -171,7 +122,7 @@ def test_repack_model_step(estimator):
     }
 
 
-def test_repack_model_step_with_invalid_input():
+def test_register_model_step_with_invalid_input():
     # without both step_args and any of the old required arguments
     with pytest.raises(ValueError) as error:
         _RegisterModelStep(
@@ -238,6 +189,7 @@ def test_repack_model_step_with_source_dir(estimator, source_dir):
             }
         ],
         "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
+        "ProfilerConfig": {"DisableProfiler": True},
         "ResourceConfig": {
             "InstanceCount": 1,
             "InstanceType": "ml.m5.large",
diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py
index 82b154317d..44b5818fc8 100644
--- a/tests/unit/test_amazon_estimator.py
+++ b/tests/unit/test_amazon_estimator.py
@@ -225,6 +225,9 @@ def test_fit_ndarray(time, sagemaker_session):
 
     assert mock_object.put.call_count == 4
 
+    called_args = sagemaker_session.train.call_args
+    assert not called_args[1]["experiment_config"]
+
 
 def test_fit_pass_experiment_config(sagemaker_session):
     kwargs = dict(COMMON_ARGS)
@@ -239,12 +242,18 @@ def test_fit_pass_experiment_config(sagemaker_session):
     labels = [99, 85, 87, 2]
     pca.fit(
         pca.record_set(np.array(train), np.array(labels)),
-        experiment_config={"ExperimentName": "exp"},
+        experiment_config={
+            "ExperimentName": "exp",
+            "RunName": "rn",
+        },
     )
 
     called_args = sagemaker_session.train.call_args
 
-    assert called_args[1]["experiment_config"] == {"ExperimentName": "exp"}
+    assert called_args[1]["experiment_config"] == {
+        "ExperimentName": "exp",
+        "RunName": "rn",
+    }
 
 
 def test_build_shards():
diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py
index 7cc973440f..eca4a9bf80 100644
--- a/tests/unit/test_chainer.py
+++ b/tests/unit/test_chainer.py
@@ -150,14 +150,8 @@ def _create_train_job(version, py_version):
             "CollectionConfigurations": [],
             "S3OutputPath": "s3://{}/".format(BUCKET_NAME),
         },
-        "profiler_rule_configs": [
-            {
-                "RuleConfigurationName": "ProfilerReport-1510006209",
-                "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest",
-                "RuleParameters": {"rule_to_invoke": "ProfilerReport"},
-            }
-        ],
         "profiler_config": {
+            "DisableProfiler": False,
            "S3OutputPath": "s3://{}/".format(BUCKET_NAME),
         },
     }
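The common thread in `test_chainer.py` above and the estimator tests that follow is that the SDK no longer injects a default `ProfilerReport-<timestamp>` rule, and `profiler_config` in the expected CreateTrainingJob request now always carries an explicit `DisableProfiler` flag. The two shapes these tests assert, using the test constants:

    # Profiling left enabled (regular training jobs):
    profiler_enabled = {
        "DisableProfiler": False,  # flag is now always present
        "S3OutputPath": "s3://my-bucket/",
    }

    # Profiling disabled (repack/register helper jobs):
    profiler_disabled = {"DisableProfiler": True}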
JUMPSTART_RESOURCE_BASE_NAME +from sagemaker.jumpstart.constants import ( + JUMPSTART_BUCKET_NAME_SET, + JUMPSTART_RESOURCE_BASE_NAME, +) from sagemaker.jumpstart.enums import JumpStartTag import sagemaker.local @@ -106,7 +109,11 @@ "training_steps": "100", }, "RoleArn": "arn:aws:iam::366:role/SageMakerRole", - "ResourceConfig": {"VolumeSizeInGB": 30, "InstanceCount": 1, "InstanceType": "ml.c4.xlarge"}, + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", + }, "EnableNetworkIsolation": False, "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, "TrainingJobName": "neo", @@ -143,7 +150,10 @@ } MOCKED_S3_URI = "s3://mocked_s3_uri_from_source_dir" MOCKED_PIPELINE_CONFIG = _PipelineConfig( - "test-pipeline", "test-training-step", "code-hash-0123456789", "config-hash-0123456789" + "test-pipeline", + "test-training-step", + "code-hash-0123456789", + "config-hash-0123456789", ) @@ -247,7 +257,9 @@ def pipeline_session(): session_mock.resource.return_value = resource_mock session_mock.client.return_value = client_mock return PipelineSession( - boto_session=session_mock, sagemaker_client=client_mock, default_bucket=BUCKET_NAME + boto_session=session_mock, + sagemaker_client=client_mock, + default_bucket=BUCKET_NAME, ) @@ -322,7 +334,11 @@ def test_framework_all_init_args(sagemaker_session): }, "metric_definitions": [{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}], "encrypt_inter_container_traffic": True, - "environment": {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}, + "environment": { + "env_key1": "env_val1", + "env_key2": "env_val2", + "env_key3": "env_val3", + }, "experiment_config": None, "checkpoint_s3_uri": "s3://bucket/checkpoint", "checkpoint_local_path": "file://local/checkpoint", @@ -379,7 +395,8 @@ def test_framework_with_debugger_and_built_in_rule(sagemaker_session): rule_parameters={"threshold": "120", "stop_training_on_fire": "True"}, collections_to_save=[ CollectionConfig( - name="losses", parameters={"train.save_interval": "50", "eval.save_interval": "10"} + name="losses", + parameters={"train.save_interval": "50", "eval.save_interval": "10"}, ) ], ) @@ -405,18 +422,23 @@ def test_framework_with_debugger_and_built_in_rule(sagemaker_session): "CollectionConfigurations": [ { "CollectionName": "losses", - "CollectionParameters": {"train.save_interval": "50", "eval.save_interval": "10"}, + "CollectionParameters": { + "train.save_interval": "50", + "eval.save_interval": "10", + }, } ], } assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), } def test_framework_with_debugger_and_custom_rule(sagemaker_session): hook_config = DebuggerHookConfig( - s3_output_path="s3://output", collection_configs=[CollectionConfig(name="weights")] + s3_output_path="s3://output", + collection_configs=[CollectionConfig(name="weights")], ) debugger_custom_rule = Rule.custom( name="CustomRule", @@ -536,7 +558,8 @@ def test_framework_with_debugger_rule_and_multiple_actions(sagemaker_session): def test_framework_with_only_debugger_hook_config(sagemaker_session): hook_config = DebuggerHookConfig( - s3_output_path="s3://output", collection_configs=[CollectionConfig(name="weights")] + s3_output_path="s3://output", + collection_configs=[CollectionConfig(name="weights")], ) f = DummyFramework( entry_point=SCRIPT_PATH, @@ -574,15 +597,9 @@ def test_framework_without_debugger_and_profiler(time, sagemaker_session): } assert "debugger_rule_configs" not in args assert 
args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), } - assert args["profiler_rule_configs"] == [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ] def test_framework_with_debugger_and_profiler_rules(sagemaker_session): @@ -591,7 +608,8 @@ def test_framework_with_debugger_and_profiler_rules(sagemaker_session): rule_parameters={"threshold": "120", "stop_training_on_fire": "True"}, collections_to_save=[ CollectionConfig( - name="losses", parameters={"train.save_interval": "50", "eval.save_interval": "10"} + name="losses", + parameters={"train.save_interval": "50", "eval.save_interval": "10"}, ) ], ) @@ -639,18 +657,25 @@ def test_framework_with_debugger_and_profiler_rules(sagemaker_session): "CollectionConfigurations": [ { "CollectionName": "losses", - "CollectionParameters": {"train.save_interval": "50", "eval.save_interval": "10"}, + "CollectionParameters": { + "train.save_interval": "50", + "eval.save_interval": "10", + }, } ], } assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), } assert args["profiler_rule_configs"] == [ { "RuleConfigurationName": "CustomProfilerReportRule", "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport", "CPUBottleneck_threshold": "90"}, + "RuleParameters": { + "rule_to_invoke": "ProfilerReport", + "CPUBottleneck_threshold": "90", + }, }, { "InstanceType": "c4.4xlarge", @@ -679,6 +704,7 @@ def test_framework_with_only_profiler_rule_specified(sagemaker_session): sagemaker_session.train.assert_called_once() _, args = sagemaker_session.train.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), } assert args["profiler_rule_configs"] == [ @@ -711,16 +737,10 @@ def test_framework_with_profiler_config_without_s3_output_path(time, sagemaker_s sagemaker_session.train.assert_called_once() _, args = sagemaker_session.train.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), "ProfilingIntervalInMilliseconds": 1000, } - assert args["profiler_rule_configs"] == [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ] @pytest.mark.parametrize("region", PROFILER_UNSUPPORTED_REGIONS) @@ -745,7 +765,9 @@ def test_framework_with_no_default_profiler_in_unsupported_region(region): f.fit("s3://mydata") sms.train.assert_called_once() _, args = sms.train.call_args - assert args.get("profiler_config") is None + # assert args.get("profiler_config") == {"DisableProfiler": True} + # temporarily check if "DisableProfiler" flag is true until s3_output is changed to optional in service + assert args.get("profiler_config")["DisableProfiler"] is True assert args.get("profiler_rule_configs") is None @@ -865,7 +887,10 @@ def test_framework_with_profiler_config_and_profiler_disabled(sagemaker_session) disable_profiler=True, ) f.fit("s3://mydata") - assert "profiler_config cannot be set when disable_profiler is True." 
in str(error) + # assert "profiler_config cannot be set when disable_profiler is True." in str(error) + assert "profiler_config.disable_profiler cannot be False when disable_profiler is True." in str( + error + ) def test_framework_with_profiler_rule_and_profiler_disabled(sagemaker_session): @@ -927,15 +952,9 @@ def test_framework_with_enabling_default_profiling( sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), } - assert args["profiler_rule_configs"] == [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ] @patch("time.time", return_value=TIME) @@ -960,15 +979,9 @@ def test_framework_with_enabling_default_profiling_with_existed_s3_output_path( sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://custom/", } - assert args["profiler_rule_configs"] == [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ] def test_framework_with_disabling_profiling_when_profiler_is_already_disabled( @@ -1001,7 +1014,9 @@ def test_framework_with_disabling_profiling(sagemaker_session, training_job_desc f.disable_profiling() sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args - assert args["profiler_config"] == {"DisableProfiler": True} + # assert args["profiler_config"] == {"DisableProfiler": True} + # temporarily check if "DisableProfiler" flag is true until s3_output is changed to optional in service + assert args.get("profiler_config")["DisableProfiler"] is True def test_framework_with_update_profiler_when_no_training_job(sagemaker_session): @@ -1058,6 +1073,7 @@ def test_framework_with_update_profiler_config(sagemaker_session): sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "ProfilingIntervalInMilliseconds": 1000, } assert "profiler_rule_configs" not in args @@ -1086,7 +1102,7 @@ def test_framework_with_update_profiler_report_rule(sagemaker_session): "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, } ] - assert "profiler_config" not in args + assert args["profiler_config"]["DisableProfiler"] is False def test_framework_with_disable_framework_metrics(sagemaker_session): @@ -1101,11 +1117,16 @@ def test_framework_with_disable_framework_metrics(sagemaker_session): f.update_profiler(disable_framework_metrics=True) sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args - assert args["profiler_config"] == {"ProfilingParameters": {}} + assert args["profiler_config"] == { + "DisableProfiler": False, + "ProfilingParameters": {}, + } assert "profiler_rule_configs" not in args -def test_framework_with_disable_framework_metrics_and_update_system_metrics(sagemaker_session): +def test_framework_with_disable_framework_metrics_and_update_system_metrics( + sagemaker_session, +): f = 
-def test_framework_with_disable_framework_metrics_and_update_system_metrics(sagemaker_session): +def test_framework_with_disable_framework_metrics_and_update_system_metrics( + sagemaker_session, +): f = DummyFramework( entry_point=SCRIPT_PATH, role=ROLE, @@ -1118,13 +1139,16 @@ def test_framework_with_disable_framework_metrics_and_update_system_metrics(sage sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "ProfilingIntervalInMilliseconds": 1000, "ProfilingParameters": {}, } assert "profiler_rule_configs" not in args -def test_framework_with_disable_framework_metrics_and_update_framework_params(sagemaker_session): +def test_framework_with_disable_framework_metrics_and_update_framework_params( + sagemaker_session, +): with pytest.raises(ValueError) as error: f = DummyFramework( entry_point=SCRIPT_PATH, @@ -1160,7 +1184,10 @@ def test_framework_with_update_profiler_config_and_profiler_rule(sagemaker_sessi f.update_profiler(rules=[profiler_custom_rule], system_monitor_interval_millis=1000) sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args - assert args["profiler_config"] == {"ProfilingIntervalInMilliseconds": 1000} + assert args["profiler_config"] == { + "DisableProfiler": False, + "ProfilingIntervalInMilliseconds": 1000, + } assert args["profiler_rule_configs"] == [ { "InstanceType": "c4.4xlarge", @@ -1659,7 +1686,10 @@ def test_start_new_wait_called(strftime, sagemaker_session): def test_attach_framework(sagemaker_session, training_job_description): - training_job_description["VpcConfig"] = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} + training_job_description["VpcConfig"] = { + "Subnets": ["foo"], + "SecurityGroupIds": ["bar"], + } training_job_description["EnableNetworkIsolation"] = True framework_estimator = DummyFramework.attach( @@ -1753,7 +1783,8 @@ def test_attach_framework_with_inter_container_traffic_encryption_flag( def test_attach_framework_base_from_generated_name(sagemaker_session, training_job_description): base_job_name = "neo" framework_estimator = DummyFramework.attach( - training_job_name=utils.name_from_base("neo"), sagemaker_session=sagemaker_session + training_job_name=utils.name_from_base("neo"), + sagemaker_session=sagemaker_session, ) assert framework_estimator.base_job_name == base_job_name @@ -1948,7 +1979,8 @@ def test_git_support_bad_repo_url_format(sagemaker_session): @patch( "sagemaker.git_utils.git_clone_repo", side_effect=subprocess.CalledProcessError( - returncode=1, cmd="git clone https://github.com/aws/no-such-repo.git /tmp/repo_dir" + returncode=1, + cmd="git clone https://github.com/aws/no-such-repo.git /tmp/repo_dir", ), ) def test_git_support_git_clone_fail(git_clone_repo, sagemaker_session): @@ -1973,7 +2005,11 @@ def test_git_support_git_clone_fail(git_clone_repo, sagemaker_session): ), ) def test_git_support_branch_not_exist(git_clone_repo, sagemaker_session): - git_config = {"repo": GIT_REPO, "branch": "branch-that-does-not-exist", "commit": COMMIT} + git_config = { + "repo": GIT_REPO, + "branch": "branch-that-does-not-exist", + "commit": COMMIT, + } fw = DummyFramework( entry_point="entry_point", git_config=git_config, @@ -1994,7 +2030,11 @@ def test_git_support_commit_not_exist(git_clone_repo, sagemaker_session): - git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": "commit-sha-that-does-not-exist"} + git_config = { + "repo": GIT_REPO, + "branch": BRANCH, + "commit": "commit-sha-that-does-not-exist", + } fw = DummyFramework( entry_point="entry_point",
git_config=git_config, @@ -2137,7 +2177,11 @@ def test_git_support_with_token_2fa(git_clone_repo, sagemaker_session): }, ) def test_git_support_ssh_no_passphrase_needed(git_clone_repo, sagemaker_session): - git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} + git_config = { + "repo": PRIVATE_GIT_REPO_SSH, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + } entry_point = "entry_point" fw = DummyFramework( entry_point=entry_point, @@ -2159,7 +2203,11 @@ def test_git_support_ssh_no_passphrase_needed(git_clone_repo, sagemaker_session) ), ) def test_git_support_ssh_passphrase_required(git_clone_repo, sagemaker_session): - git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} + git_config = { + "repo": PRIVATE_GIT_REPO_SSH, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + } entry_point = "entry_point" fw = DummyFramework( entry_point=entry_point, @@ -2457,7 +2505,9 @@ def test_estimator_transformer_creation_with_optional_params(create_model, sagem ) create_model.assert_called_with( - vpc_config_override=new_vpc_config, model_kms_key=kms_key, enable_network_isolation=True + vpc_config_override=new_vpc_config, + model_kms_key=kms_key, + enable_network_isolation=True, ) assert transformer.strategy == strategy @@ -2489,7 +2539,12 @@ def test_start_new(sagemaker_session): hyperparameters=hyperparameters, ) - exp_config = {"ExperimentName": "exp", "TrialName": "t", "TrialComponentDisplayName": "tc"} + exp_config = { + "ExperimentName": "exp", + "TrialName": "t", + "TrialComponentDisplayName": "tc", + "RunName": "rn", + } started_training_job = training_job.start_new(estimator, inputs, experiment_config=exp_config) called_args = sagemaker_session.train.call_args @@ -2630,14 +2685,7 @@ def test_unsupported_type_in_dict(): "input_config": None, "input_mode": "File", "output_config": {"S3OutputPath": OUTPUT_PATH}, - "profiler_config": {"S3OutputPath": OUTPUT_PATH}, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], + "profiler_config": {"DisableProfiler": False, "S3OutputPath": OUTPUT_PATH}, "resource_config": { "InstanceCount": INSTANCE_COUNT, "InstanceType": INSTANCE_TYPE, @@ -2680,6 +2728,7 @@ def test_unsupported_type_in_dict(): "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } } ) @@ -2743,7 +2792,11 @@ def test_fit_deploy_tags_in_estimator(name_from_base, sagemaker_session): @patch("sagemaker.estimator.name_from_base") def test_fit_deploy_tags(name_from_base, sagemaker_session): estimator = Estimator( - IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + sagemaker_session=sagemaker_session, ) estimator.fit() @@ -2884,6 +2937,7 @@ def test_generic_to_fit_with_experiment_config(time, sagemaker_session): "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", }, ) @@ -3190,7 +3244,10 @@ def test_generic_training_job_analytics(sagemaker_session): "TrainingInputMode": "File", "MetricDefinitions": [ {"Name": "train:loss", "Regex": "train_loss=([0-9]+\\.[0-9]+)"}, - {"Name": "validation:loss", "Regex": "valid_loss=([0-9]+\\.[0-9]+)"}, + { + "Name": "validation:loss", + "Regex": "valid_loss=([0-9]+\\.[0-9]+)", + }, 
], }, }, @@ -3221,7 +3278,11 @@ def test_generic_create_model_vpc_config_override(sagemaker_session): vpc_config_b = {"Subnets": ["foo", "bar"], "SecurityGroupIds": ["baz"]} e = Estimator( - IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + sagemaker_session=sagemaker_session, ) e.fit({"train": "s3://bucket/training-prefix"}) assert e.get_vpc_config() is None @@ -3247,7 +3308,11 @@ def test_generic_deploy_vpc_config_override(sagemaker_session): vpc_config_b = {"Subnets": ["foo", "bar"], "SecurityGroupIds": ["baz"]} e = Estimator( - IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + sagemaker_session=sagemaker_session, ) e.fit({"train": "s3://bucket/training-prefix"}) e.deploy(INSTANCE_COUNT, INSTANCE_TYPE) @@ -3267,7 +3332,11 @@ def test_generic_deploy_vpc_config_override(sagemaker_session): def test_generic_deploy_accelerator_type(sagemaker_session): e = Estimator( - IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + sagemaker_session=sagemaker_session, ) e.fit({"train": "s3://bucket/training-prefix"}) e.deploy(INSTANCE_COUNT, INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE) @@ -3610,7 +3679,13 @@ def test_file_output_path_not_supported_outside_local_mode(session_class): session_class.return_value = session with pytest.raises(RuntimeError): - Estimator(IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path="file:///tmp/model") + Estimator( + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + output_path="file:///tmp/model", + ) def test_prepare_init_params_from_job_description_with_image_training_job(): @@ -3719,7 +3794,10 @@ def test_prepare_for_training_with_name_based_on_image(sagemaker_session): @patch("sagemaker.algorithm.AlgorithmEstimator.validate_train_spec", Mock()) -@patch("sagemaker.algorithm.AlgorithmEstimator._parse_hyperparameters", Mock(return_value={})) +@patch( + "sagemaker.algorithm.AlgorithmEstimator._parse_hyperparameters", + Mock(return_value={}), +) def test_prepare_for_training_with_name_based_on_algorithm(sagemaker_session): estimator = AlgorithmEstimator( algorithm_arn="arn:aws:sagemaker:us-west-2:1234:algorithm/scikit-decision-trees-1542410022", @@ -3734,7 +3812,9 @@ def test_prepare_for_training_with_name_based_on_algorithm(sagemaker_session): @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) -def test_prepare_for_training_with_pipeline_name_in_s3_path_no_source_dir(pipeline_session): +def test_prepare_for_training_with_pipeline_name_in_s3_path_no_source_dir( + pipeline_session, +): # script_uri is NOT provided -> use new cache key behavior that builds path using pipeline name + code_hash image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38" model_uri = "s3://someprefix2/models/model.tar.gz" @@ -4204,7 +4284,10 @@ def test_script_mode_estimator_tags_jumpstart_models_with_no_estimator_js_tags( @patch("sagemaker.model.Model._upload_code") @patch("sagemaker.utils.repack_model") def test_all_framework_estimators_add_jumpstart_tags( - patched_repack_model, patched_upload_code, patched_tar_and_upload_dir, sagemaker_session + patched_repack_model, + patched_upload_code, + patched_tar_and_upload_dir, + sagemaker_session, ): sagemaker_session.boto_region_name = REGION @@ -4233,13 +4316,20 @@ def 
test_all_framework_estimators_add_jumpstart_tags( "transformers_version": "4.6.1", "instance_type": "ml.p2.xlarge", }, - MXNet: {"framework_version": "1.7.0", "py_version": "py3", "instance_type": "ml.p2.xlarge"}, + MXNet: { + "framework_version": "1.7.0", + "py_version": "py3", + "instance_type": "ml.p2.xlarge", + }, SKLearn: {"framework_version": "0.23-1", "instance_type": "ml.m2.xlarge"}, XGBoost: {"framework_version": "1.3-1", "instance_type": "ml.m2.xlarge"}, } jumpstart_model_uri = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz" jumpstart_model_uri_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/model_dirs/model.tar.gz" - for framework_estimator_class, kwargs in framework_estimator_classes_to_kwargs.items(): + for ( + framework_estimator_class, + kwargs, + ) in framework_estimator_classes_to_kwargs.items(): estimator = framework_estimator_class( entry_point=ENTRY_POINT, role=ROLE, @@ -4355,7 +4445,10 @@ def test_script_mode_estimator_uses_jumpstart_base_name_with_js_models( @patch("sagemaker.model.Model._upload_code") @patch("sagemaker.utils.repack_model") def test_all_framework_estimators_add_jumpstart_base_name( - patched_repack_model, patched_upload_code, patched_tar_and_upload_dir, sagemaker_session + patched_repack_model, + patched_upload_code, + patched_tar_and_upload_dir, + sagemaker_session, ): sagemaker_session.boto_region_name = REGION @@ -4384,13 +4477,20 @@ def test_all_framework_estimators_add_jumpstart_base_name( "transformers_version": "4.6.1", "instance_type": "ml.p2.xlarge", }, - MXNet: {"framework_version": "1.7.0", "py_version": "py3", "instance_type": "ml.p2.xlarge"}, + MXNet: { + "framework_version": "1.7.0", + "py_version": "py3", + "instance_type": "ml.p2.xlarge", + }, SKLearn: {"framework_version": "0.23-1", "instance_type": "ml.m2.xlarge"}, XGBoost: {"framework_version": "1.3-1", "instance_type": "ml.m2.xlarge"}, } jumpstart_model_uri = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz" jumpstart_model_uri_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/model_dirs/model.tar.gz" - for framework_estimator_class, kwargs in framework_estimator_classes_to_kwargs.items(): + for ( + framework_estimator_class, + kwargs, + ) in framework_estimator_classes_to_kwargs.items(): estimator = framework_estimator_class( entry_point=ENTRY_POINT, role=ROLE, diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 667d115d58..4654abb928 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -883,6 +883,7 @@ def test_validate_smdataparallel_args_not_raises(): ("ml.p3.16xlarge", "tensorflow", "2.7", "py38", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.8.0", "py39", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.8", "py39", smdataparallel_enabled), + ("ml.p3.16xlarge", "tensorflow", "2.9.2", "py39", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.9.1", "py39", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.9", "py39", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.10.0", "py39", smdataparallel_enabled), @@ -915,6 +916,7 @@ def test_validate_smdataparallel_args_not_raises(): ("ml.p3.16xlarge", "tensorflow", "2.7.1", "py38", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "tensorflow", "2.8.0", "py39", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "tensorflow", "2.9.1", "py39", smdataparallel_enabled_custom_mpi), + ("ml.p3.16xlarge", "tensorflow", "2.9.2", "py39", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", 
"tensorflow", "2.10.0", "py39", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "pytorch", "1.9.1", "py38", smdataparallel_enabled_custom_mpi), diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 99b0e839b7..f12d8e160f 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -62,6 +62,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } MODEL_PKG_RESPONSE = {"ModelPackageArn": "arn:model-pkg-arn"} @@ -159,14 +160,8 @@ def _get_train_args(job_name): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.4.0-cpu-py3", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 082f699d63..5691834c3a 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -54,6 +54,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } DISTRIBUTION_PYTORCH_DDP_ENABLED = {"pytorchddp": {"enabled": True}} @@ -157,14 +158,8 @@ def _create_train_job(version, py_version): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index 4efc2e5bf8..0c0a9c6d64 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -49,6 +49,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } @@ -152,14 +153,8 @@ def _create_train_job(toolkit, toolkit_version, framework): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, "retry_strategy": None, diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8958210092..119d08cef4 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -588,11 +588,16 @@ def test_user_agent_injected(boto_session): assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_metrics_client._client_config.user_agent assert "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_client._client_config.user_agent assert ( "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_runtime_client._client_config.user_agent ) + assert ( + "AWS-SageMaker-Notebook-Instance" + not in 
sess.sagemaker_metrics_client._client_config.user_agent + ) def test_user_agent_injected_with_nbi(boto_session): @@ -607,10 +612,14 @@ def test_user_agent_injected_with_nbi(boto_session): assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_metrics_client._client_config.user_agent assert "AWS-SageMaker-Notebook-Instance" in sess.sagemaker_client._client_config.user_agent assert ( "AWS-SageMaker-Notebook-Instance" in sess.sagemaker_runtime_client._client_config.user_agent ) + assert ( + "AWS-SageMaker-Notebook-Instance" in sess.sagemaker_metrics_client._client_config.user_agent + ) def test_user_agent_injected_with_nbi_ioerror(boto_session): @@ -625,11 +634,16 @@ def test_user_agent_injected_with_nbi_ioerror(boto_session): assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_metrics_client._client_config.user_agent assert "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_client._client_config.user_agent assert ( "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_runtime_client._client_config.user_agent ) + assert ( + "AWS-SageMaker-Notebook-Instance" + not in sess.sagemaker_metrics_client._client_config.user_agent + ) def test_training_input_all_defaults(): @@ -700,6 +714,7 @@ def test_training_input_all_arguments(): "ExperimentName": "dummyExp", "TrialName": "dummyT", "TrialComponentDisplayName": "dummyTC", + "RunName": "dummyRN", } MODEL_CLIENT_CONFIG = {"InvocationsMaxRetries": 2, "InvocationsTimeoutInSeconds": 60} @@ -882,6 +897,7 @@ def test_train_pack_to_request(sagemaker_session): "ResourceLimits": {"MaxNumberOfTrainingJobs": 100, "MaxParallelTrainingJobs": 5}, "ParameterRanges": SAMPLE_PARAM_RANGES, "TrainingJobEarlyStoppingType": "Off", + "RandomSeed": 0, }, "TrainingJobDefinition": { "StaticHyperParameters": STATIC_HPs, @@ -941,6 +957,13 @@ def test_train_pack_to_request(sagemaker_session): ], } +SAMPLE_HYPERBAND_STRATEGY_CONFIG = { + "HyperbandStrategyConfig": { + "MinResource": 1, + "MaxResource": 10, + } +} + @pytest.mark.parametrize( "warm_start_type, parents", @@ -967,6 +990,7 @@ def assert_create_tuning_job_request(**kwrags): sagemaker_session.tune( job_name="dummy-tuning-1", strategy="Bayesian", + random_seed=0, objective_type="Maximize", objective_metric_name="val-score", max_jobs=100, @@ -1058,6 +1082,7 @@ def assert_create_tuning_job_request(**kwrags): "max_jobs": 100, "max_parallel_jobs": 5, "parameter_ranges": SAMPLE_PARAM_RANGES, + "random_seed": 0, }, training_config={ "static_hyperparameters": STATIC_HPs, @@ -1142,6 +1167,47 @@ def assert_create_tuning_job_request(**kwrags): assert kwrags["TrainingJobDefinition"] == SAMPLE_TUNING_JOB_REQUEST["TrainingJobDefinition"] assert kwrags.get("WarmStartConfig", None) is None + sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job.side_effect = ( + assert_create_tuning_job_request + ) + sagemaker_session.tune( + job_name="dummy-tuning-1", + strategy="Bayesian", + random_seed=0, + objective_type="Maximize", + objective_metric_name="val-score", + max_jobs=100, + max_parallel_jobs=5, + parameter_ranges=SAMPLE_PARAM_RANGES, + static_hyperparameters=STATIC_HPs, + image_uri="dummy-image-1", + input_mode="File", + metric_definitions=SAMPLE_METRIC_DEF, + 
role=EXPANDED_ROLE, + input_config=SAMPLE_INPUT, + output_config=SAMPLE_OUTPUT, + resource_config=RESOURCE_CONFIG, + stop_condition=SAMPLE_STOPPING_CONDITION, + tags=None, + warm_start_config=None, + ) + + +def test_tune_with_strategy_config(sagemaker_session): + def assert_create_tuning_job_request(**kwargs): + assert ( + kwargs["HyperParameterTuningJobConfig"]["StrategyConfig"]["HyperbandStrategyConfig"][ + "MinResource" + ] + == SAMPLE_HYPERBAND_STRATEGY_CONFIG["HyperbandStrategyConfig"]["MinResource"] + ) + assert ( + kwargs["HyperParameterTuningJobConfig"]["StrategyConfig"]["HyperbandStrategyConfig"][ + "MaxResource" + ] + == SAMPLE_HYPERBAND_STRATEGY_CONFIG["HyperbandStrategyConfig"]["MaxResource"] + ) + sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job.side_effect = ( assert_create_tuning_job_request ) @@ -1164,6 +1230,7 @@ def assert_create_tuning_job_request(**kwrags): sagemaker_session.tune( job_name="dummy-tuning-1", strategy="Bayesian", + random_seed=0, objective_type="Maximize", objective_metric_name="val-score", max_jobs=100, @@ -1226,6 +1294,7 @@ def assert_create_tuning_job_request(**kwargs): sagemaker_session.tune( job_name="dummy-tuning-1", strategy="Bayesian", + random_seed=0, objective_type="Maximize", objective_metric_name="val-score", max_jobs=100, @@ -2739,6 +2808,35 @@ def test_feature_metadata_describe(sagemaker_session): +def test_list_feature_groups(sagemaker_session): + expected_list_feature_groups_args = { + "NameContains": "MyFeatureGroup", + "FeatureGroupStatusEquals": "Created", + "OfflineStoreStatusEquals": "Active", + "CreationTimeAfter": datetime.datetime(2020, 12, 1), + "CreationTimeBefore": datetime.datetime(2022, 7, 1), + "SortOrder": "Ascending", + "SortBy": "Name", + "MaxResults": 50, + "NextToken": "token", + } + sagemaker_session.list_feature_groups( + name_contains="MyFeatureGroup", + feature_group_status_equals="Created", + offline_store_status_equals="Active", + creation_time_after=datetime.datetime(2020, 12, 1), + creation_time_before=datetime.datetime(2022, 7, 1), + sort_order="Ascending", + sort_by="Name", + max_results=50, + next_token="token", + ) + sagemaker_session.sagemaker_client.list_feature_groups.assert_called_once() + sagemaker_session.sagemaker_client.list_feature_groups.assert_called_with( + **expected_list_feature_groups_args + ) + + def test_start_query_execution(sagemaker_session): athena_mock = Mock() sagemaker_session.boto_session.client( diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index 13cc755336..430cb484b4 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -51,6 +51,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } @@ -139,14 +140,8 @@ def _create_train_job(version): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, }
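The tuner changes above thread two new fields into the CreateHyperParameterTuningJob request. As a reading aid, here is a sketch of the HyperParameterTuningJobConfig shape the fixtures encode; the field names come from the API, only the fields exercised by the new tests are shown, and pairing the Hyperband strategy with its strategy config is an assumption of this sketch:

    # Values mirror SAMPLE_HYPERBAND_STRATEGY_CONFIG and random_seed=0 above.
    tuning_job_config = {
        "Strategy": "Hyperband",
        "StrategyConfig": {
            "HyperbandStrategyConfig": {"MinResource": 1, "MaxResource": 10},
        },
        "RandomSeed": 0,  # new in v2.125.0: reproducible HPO searches
        "TrainingJobEarlyStoppingType": "Off",
    }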
diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index 9bbc882dfa..7e556c7d23 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -545,6 +545,7 @@ def test_attach_tuning_job_with_estimator_from_hyperparameters(sagemaker_session assert tuner.strategy == "Bayesian" assert tuner.objective_type == "Minimize" assert tuner.early_stopping_type == "Off" + assert tuner.random_seed == 0 assert isinstance(tuner.estimator, PCA) assert tuner.estimator.role == ROLE diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 0eb81be584..8bcbed41c2 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -25,10 +25,12 @@ from boto3 import exceptions import botocore import pytest -from mock import call, patch, Mock, MagicMock +from mock import call, patch, Mock, MagicMock, PropertyMock import sagemaker +from sagemaker.experiments._run_context import _RunContext from sagemaker.session_settings import SessionSettings +from sagemaker.utils import retry_with_backoff, check_and_get_run_experiment_config from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -795,3 +797,63 @@ def test_start_waiting(capfd): out, _ = capfd.readouterr() assert "." * sagemaker.utils.WAITING_DOT_NUMBER in out + + +def test_retry_with_backoff(): + callable_func = Mock() + + # Invalid input + with pytest.raises(ValueError) as value_err: + retry_with_backoff(callable_func, 0) + assert "The num_attempts must be >= 1" in str(value_err) + callable_func.assert_not_called() + + # All retries fail + run_err_msg = "Test Retry Error" + callable_func.side_effect = RuntimeError(run_err_msg) + with pytest.raises(RuntimeError) as run_err: + retry_with_backoff(callable_func, 2) + assert run_err_msg in str(run_err) + + # One retry passes + func_return_val = "Test Return" + callable_func.side_effect = [RuntimeError(run_err_msg), func_return_val] + assert retry_with_backoff(callable_func, 2) == func_return_val + + # No retry + callable_func.side_effect = None + callable_func.return_value = func_return_val + assert retry_with_backoff(callable_func, 2) == func_return_val + + +def test_check_and_get_run_experiment_config(): + supplied_exp_cfg = {"ExperimentName": "my-supplied-exp-name", "RunName": "my-supplied-run-name"} + run_exp_cfg = {"ExperimentName": "my-run-exp-name", "RunName": "my-run-run-name"} + + # No user supplied exp config and no current Run + assert not _RunContext.get_current_run() + exp_cfg1 = check_and_get_run_experiment_config(None) + assert exp_cfg1 is None + + # With user supplied exp config and no current Run + assert not _RunContext.get_current_run() + exp_cfg2 = check_and_get_run_experiment_config(supplied_exp_cfg) + assert exp_cfg2 == supplied_exp_cfg + + run = Mock() + type(run).experiment_config = PropertyMock(return_value=run_exp_cfg) + _RunContext.add_run_object(run) + + try: + # No user supplied exp config and with current Run + assert _RunContext.get_current_run().experiment_config == run_exp_cfg + exp_cfg3 = check_and_get_run_experiment_config(None) + assert exp_cfg3 == run_exp_cfg + + # With user supplied exp config and current Run + assert _RunContext.get_current_run().experiment_config == run_exp_cfg + exp_cfg4 = check_and_get_run_experiment_config(supplied_exp_cfg) + assert exp_cfg4 == supplied_exp_cfg + finally: + # Clean up the global static variable in case it affects other tests + _RunContext.drop_current_run()
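The new test_retry_with_backoff above pins down the contract of sagemaker.utils.retry_with_backoff: num_attempts below 1 raises ValueError, the last exception propagates when every attempt fails, and the first successful return value wins. A hedged usage sketch consistent with that contract (transient_call is a hypothetical helper, not part of the SDK):

    from sagemaker.utils import retry_with_backoff

    attempts = {"count": 0}

    def transient_call():
        # Fails twice, then succeeds, to exercise the retry path.
        attempts["count"] += 1
        if attempts["count"] < 3:
            raise RuntimeError("transient failure")
        return "ok"

    assert retry_with_backoff(transient_call, 3) == "ok"  # succeeds on attempt 3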
diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index 82f27c19ae..87a853d5d0 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -54,6 +54,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } @@ -153,14 +154,8 @@ def _create_train_job(version, instance_count=1, instance_type="ml.c4.4xlarge"): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/tuner_test_utils.py b/tests/unit/tuner_test_utils.py index be0dba2ccd..5cf7ba2fc2 100644 --- a/tests/unit/tuner_test_utils.py +++ b/tests/unit/tuner_test_utils.py @@ -112,6 +112,7 @@ ], }, "TrainingJobEarlyStoppingType": "Off", + "RandomSeed": 0, }, "HyperParameterTuningJobName": JOB_NAME, "TrainingJobDefinition": { diff --git a/tox.ini b/tox.ini index 2d5fdf0b40..3a398ca51d 100644 --- a/tox.ini +++ b/tox.ini @@ -73,6 +73,8 @@ passenv = # Can be used to specify which tests to run, e.g.: tox -- -s commands = python -c "import os; os.system('install-custom-pkgs --install-boto-wheels')" + pip install 'apache-airflow==2.4.1' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.4.1/constraints-3.10.txt" + pytest --cov=sagemaker --cov-append {posargs} {env:IGNORE_COVERAGE:} coverage report -i --fail-under=86 deps = .[test]
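A closing note on the "RunName" additions that recur through these fixtures: they are selected by sagemaker.utils.check_and_get_run_experiment_config, whose precedence the utils test above encodes. A short sketch of that behavior outside of any active experiment Run (the dict values are the placeholder fixtures used throughout this change):

    from sagemaker.utils import check_and_get_run_experiment_config

    supplied = {
        "ExperimentName": "exp",
        "TrialName": "trial",
        "TrialComponentDisplayName": "tc",
        "RunName": "rn",  # newly propagated into train/tune requests
    }
    # An explicitly supplied config always wins; with neither a supplied config
    # nor a current Run, the result is None.
    assert check_and_get_run_experiment_config(supplied) == supplied
    assert check_and_get_run_experiment_config(None) is None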