diff --git a/.githooks/pre-push b/.githooks/pre-push index 995ab70108..f73fa492b3 100755 --- a/.githooks/pre-push +++ b/.githooks/pre-push @@ -12,5 +12,5 @@ start_time=`date +%s` tox -e sphinx,doc8 --parallel all ./ci-scripts/displaytime.sh 'sphinx,doc8' $start_time start_time=`date +%s` -tox -e py38,py39,py310 --parallel all -- tests/unit -./ci-scripts/displaytime.sh 'py38,py39,py310 unit' $start_time +tox -e py39,py310,py311,py312 --parallel all -- tests/unit +./ci-scripts/displaytime.sh 'py39,py310,py311,py312 unit' $start_time diff --git a/.github/workflows/codebuild-canaries.yml b/.github/workflows/codebuild-canaries.yml new file mode 100644 index 0000000000..a6b5a978ef --- /dev/null +++ b/.github/workflows/codebuild-canaries.yml @@ -0,0 +1,24 @@ +name: Canaries +on: + schedule: + - cron: "0 */3 * * *" + workflow_dispatch: + +permissions: + id-token: write # This is required for requesting the JWT + +jobs: + tests: + runs-on: ubuntu-latest + steps: + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }} + aws-region: us-west-2 + role-duration-seconds: 10800 + - name: Run Integ Tests + uses: aws-actions/aws-codebuild-run-build@v1 + id: codebuild + with: + project-name: sagemaker-python-sdk-canaries diff --git a/.github/workflows/codebuild-ci-health.yml b/.github/workflows/codebuild-ci-health.yml index 7ecefd310f..119b9dbe9c 100644 --- a/.github/workflows/codebuild-ci-health.yml +++ b/.github/workflows/codebuild-ci-health.yml @@ -26,7 +26,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["py38", "py39", "py310", "py311"] + python-version: ["py39", "py310", "py311","py312"] steps: - name: Configure AWS Credentials uses: aws-actions/configure-aws-credentials@v4 diff --git a/.github/workflows/codebuild-ci.yml b/.github/workflows/codebuild-ci.yml index 8c6bd6b337..eef53ff06c 100644 --- a/.github/workflows/codebuild-ci.yml +++ b/.github/workflows/codebuild-ci.yml @@ -63,7 +63,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["py38","py39","py310","py311"] + python-version: ["py39","py310","py311","py312"] steps: - name: Configure AWS Credentials uses: aws-actions/configure-aws-credentials@v4 diff --git a/.pylintrc b/.pylintrc index 5428b86be0..223580f4d3 100644 --- a/.pylintrc +++ b/.pylintrc @@ -94,7 +94,24 @@ disable= useless-object-inheritance, # TODO: Enable this check and fix code once Python 2 is no longer supported. super-with-arguments, raise-missing-from, - E1136, + C0116, # Missing function or method docstring + C0209, # Use f-string instead of format + E0015, # Unrecognized option found in config + E0702, # Raising a string instead of an exception + E1101, # Module has no member (likely dynamic attr) + E1136, # Value assigned to something inferred as None + R0022, # Useless option value in config + R1710, # Inconsistent return statements + R1714, # Consider using `in` with comparisons + R1729, # Use a generator + R1732, + R1735, # Consider using a dict or list literal + W0237, # Argument renamed in override + W0613, # Unused argument + W0621, # Redefining name from outer scope + W0719 + W1404, # Implicit string concatenation + W1514, # `open()` used without encoding [REPORTS] # Set the output format. Available formats are text, parseable, colorized, msvs @@ -436,4 +453,4 @@ analyse-fallback-blocks=no # Exceptions that will emit a warning when being caught. Defaults to # "Exception" -overgeneral-exceptions=Exception +overgeneral-exceptions=builtins.Exception diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 0a6e3928b5..0dcc70b9c3 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -5,9 +5,9 @@ version: 2 build: - os: ubuntu-20.04 + os: ubuntu-22.04 tools: - python: "3.9" + python: "3.12" python: diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c2b1851ae..e59d964bd1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,196 @@ # Changelog +## v2.243.2 (2025-04-16) + +### Bug Fixes and Other Changes + + * tgi image uri unit tests + * Fix deepdiff dependencies + +## v2.243.1 (2025-04-11) + +### Bug Fixes and Other Changes + + * Added handler for pipeline variable while creating process job + * Fix issue #4856 by copying environment variables + * remove historical job_name caching which causes long job name + * Update instance gpu info + * Master + * Add mlflow tracking arn telemetry + * chore: fix semantic versioning for wildcard identifier + * flaky test + +### Documentation Changes + + * update pipelines step caching examples to include more steps + * update ModelStep data dependency info + +## v2.243.0 (2025-03-27) + +### Features + + * Enabled update_endpoint through model_builder + +### Bug Fixes and Other Changes + + * Update for PT 2.5.1, SMP 2.8.0 + * chore: move jumpstart region definitions to json file + * fix flaky clarify model monitor test + * fix flaky spark processor integ + * use temp file in unit tests + * Update transformers version + * Aligned disable_output_compression for @remote with Estimator + * Update Jinja version + * update image_uri_configs 03-26-2025 07:18:16 PST + * chore: fix integ tests to use latest version of model + * update image_uri_configs 03-25-2025 07:18:13 PST + * Skip tests failed due to deprecated instance type + * update image_uri_configs 03-21-2025 07:17:55 PST + * factor in set instance type when building JumpStart models in ModelBuilder. + * ADD Documentation to ReadtheDocs for Upgrading torch versions + * add new regions to JUMPSTART_LAUNCHED_REGIONS + +## v2.242.0 (2025-03-14) + +### Features + + * add integ tests for training JumpStart models in private hub + +### Bug Fixes and Other Changes + + * Torch upgrade + * Prevent RunContext overlap between test_run tests + * remove s3 output location requirement from hub class init + * Fixing Pytorch training python version in tests + * update image_uri_configs 03-11-2025 07:18:09 PST + * resolve infinite loop in _find_config on Windows systems + * pipeline definition function doc update + +## v2.241.0 (2025-03-06) + +### Features + + * Make DistributedConfig Extensible + * support training for JumpStart model references as part of Curated Hub Phase 2 + * Allow ModelTrainer to accept hyperparameters file + +### Bug Fixes and Other Changes + + * Skip tests with deprecated instance type + * Ensure Model.is_repack() returns a boolean + * Fix error when there is no session to call _create_model_request() + * Use sagemaker session's s3_resource in download_folder + * Added check for the presence of model package group before creating one + * Fix key error in _send_metrics() + +## v2.240.0 (2025-02-25) + +### Features + + * Add support for TGI Neuronx 0.0.27 and HF PT 2.3.0 image in PySDK + +### Bug Fixes and Other Changes + + * Remove main function entrypoint in ModelBuilder dependency manager. + * forbid extras in Configs + * altconfig hubcontent and reenable integ test + * Merge branch 'master-rba' into local_merge + * py_version doc fixes + * Add backward compatbility for RecordSerializer and RecordDeserializer + * update image_uri_configs 02-21-2025 06:18:10 PST + * update image_uri_configs 02-20-2025 06:18:08 PST + +### Documentation Changes + + * Removed a line about python version requirements of training script which can misguide users. + +## v2.239.3 (2025-02-19) + +### Bug Fixes and Other Changes + + * added ap-southeast-7 and mx-central-1 for Jumpstart + * update image_uri_configs 02-19-2025 06:18:15 PST + +## v2.239.2 (2025-02-18) + +### Bug Fixes and Other Changes + + * Add warning about not supporting torch.nn.SyncBatchNorm + * pass in inference_ami_version to model_based endpoint type + * Fix hyperparameter strategy docs + * Add framework_version to all TensorFlowModel examples + * Move RecordSerializer and RecordDeserializer to sagemaker.serializers and sagemaker.deserialzers + +## v2.239.1 (2025-02-14) + +### Bug Fixes and Other Changes + + * keep sagemaker_session from being overridden to None + * Fix all type hint and docstrings for callable + * Fix the workshop link for Step Functions + * Fix Tensorflow doc link + * Fix FeatureGroup docstring + * Add type hint for ProcessingOutput + * Fix sourcedir.tar.gz filenames in docstrings + * Fix documentation for local mode + * bug in get latest version was getting the max sorted alphabetically + * Add cleanup logic to model builder integ tests for endpoints + * Fixed pagination failing while listing collections + * fix ValueError when updating a data quality monitoring schedule + * Add docstring for image_uris.retrieve + * Create GitHub action to trigger canaries + * update image_uri_configs 02-04-2025 06:18:00 PST + +## v2.239.0 (2025-02-01) + +### Features + + * Add support for deepseek recipes + +### Bug Fixes and Other Changes + + * mpirun protocol - distributed training with @remote decorator + * Allow telemetry only in supported regions + * Fix ssh host policy + +## v2.238.0 (2025-01-29) + +### Features + + * use jumpstart deployment config image as default optimization image + +### Bug Fixes and Other Changes + + * chore: add new images for HF TGI + * update image_uri_configs 01-29-2025 06:18:08 PST + * skip TF tests for unsupported versions + * Merge branch 'master-rba' into local_merge + * Add missing attributes to local resourceconfig + * update image_uri_configs 01-27-2025 06:18:13 PST + * update image_uri_configs 01-24-2025 06:18:11 PST + * add missing schema definition in docs + * Omegaconf upgrade + * SageMaker @remote function: Added multi-node functionality + * remove option + * fix typo + * fix tests + * Add an option for user to remove inputs and container artifacts when using local model trainer + +## v2.237.3 (2025-01-09) + +### Bug Fixes and Other Changes + + * pin metadata-version to 2.3 + * model server might have already done a serialization. honor that by not decoding the request again if it is not already bytes or bytestream + * Disable jumpstart tests missing clean up logic + * Jumpstart ap southeast 5 + * add autogluon 1.2 + * updated inference script to cover context + * security update -> use sha256 instead of md5 for file hashing + * Fix Flake8 Violations + * Added parsing string support for situations where custom code might be used (ie. mlflow) + * Updating Inference Optimization Validations + ## v2.237.2 (2024-12-17) ### Bug Fixes and Other Changes diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 24226af4ee..65b7c0ee0c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -61,6 +61,10 @@ Before sending us a pull request, please ensure that: 1. Follow the instructions at [Modifying an EBS Volume Using Elastic Volumes (Console)](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/requesting-ebs-volume-modifications.html#modify-ebs-volume) to increase the EBS volume size associated with the newly created EC2 instance. 1. Wait 5-10min for the new EBS volume increase to finalize. 1. Allow EC2 to claim the additional space by stopping and then starting your EC2 host. +2. Set up a venv to manage dependencies: + 1. `python -m venv ~/.venv/myproject-env` to create the venv + 2. `source ~/.venv/myproject-env/bin/activate` to activate the venv + 3. `deactivate` to exit the venv ### Pull Down the Code @@ -74,8 +78,8 @@ Before sending us a pull request, please ensure that: ### Run the Unit Tests 1. Install tox using `pip install tox` -1. Install coverage using `pip install .[test]` -1. cd into the sagemaker-python-sdk folder: `cd sagemaker-python-sdk` or `cd /environment/sagemaker-python-sdk` +1. cd into the github project sagemaker-python-sdk folder: `cd sagemaker-python-sdk` or `cd /environment/sagemaker-python-sdk` +1. Install coverage using `pip install '.[test]'` 1. Run the following tox command and verify that all code checks and unit tests pass: `tox tests/unit` 1. You can also run a single test with the following command: `tox -e py310 -- -s -vv ::` 1. You can run coverage via runcvoerage env : `tox -e runcoverage -- tests/unit` or `tox -e py310 -- tests/unit --cov=sagemaker --cov-append --cov-report xml` diff --git a/VERSION b/VERSION index 4e32236881..4e55ec1ee4 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.237.2 +2.243.3.dev0 diff --git a/doc/amazon_sagemaker_model_building_pipeline.rst b/doc/amazon_sagemaker_model_building_pipeline.rst index e3548f80f2..1645302d52 100644 --- a/doc/amazon_sagemaker_model_building_pipeline.rst +++ b/doc/amazon_sagemaker_model_building_pipeline.rst @@ -408,21 +408,39 @@ Example: step_args=step_args_register_model, ) -CreateModelStep +ModelStep ```````````````` Referable Property List: - `DescribeModel`_ + OR +- `DescribeModelPackage`_ + .. _DescribeModel: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribeModel.html#API_DescribeModel_ResponseSyntax +.. _DescribeModelPackage: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribeModelPackage.html#API_DescribeModelPackage_ResponseSyntax Example: +For model creation usecase: + .. code-block:: python - step_model = CreateModelStep(...) - model_data = step_model.PrimaryContainer.ModelDataUrl + create_model_step = ModelStep( + name="MyModelCreationStep", + step_args = model.create(...) + ) + model_data = create_model_step.properties.PrimaryContainer.ModelDataUrl +For model registration usercase: + +.. code-block:: python + + register_model_step = ModelStep( + name="MyModelRegistrationStep", + step_args=model.register(...) + ) + approval_status=register_model_step.properties.ModelApprovalStatus LambdaStep ````````````` @@ -912,7 +930,7 @@ Caching is supported for the following step types: - :class:`sagemaker.workflow.clarify_check_step.ClarifyCheckStep` - :class:`sagemaker.workflow.emr_step.EMRStep` -In order to create pipeline steps and eventually construct a SageMaker pipeline, you provide parameters within a Python script or notebook. The SageMaker Python SDK creates a pipeline definition by translating these parameters into SageMaker job attributes. Some of these attributes, when changed, cause the step to re-run (See `Caching Pipeline Steps `__ for a detailed list). Therefore, if you update a SDK parameter that is used to create such an attribute, the step will rerun. See the following discussion for examples of this in processing and training steps, which are commonly used steps in Pipelines. +In order to create pipeline steps and eventually construct a SageMaker pipeline, you provide parameters within a Python script or notebook. The SageMaker Python SDK creates a pipeline definition by translating these parameters into SageMaker job attributes. Some of these attributes, when changed, cause the step to re-run (See `Caching Pipeline Steps `__ for a detailed list). Therefore, if you update a SDK parameter that is used to create such an attribute, the step will rerun. See the following discussion for examples of this in commonly used step types in Pipelines. The following example creates a processing step: @@ -1037,6 +1055,218 @@ The following parameters from the example cause additional training step iterati - :code:`entry_point`: The entry point file is included in the training job’s `InputDataConfig Channel `__ array. A unique hash is created from the file (and any other dependencies), and then the file is uploaded to S3 with the hash included in the path. When a different entry point file is used, a new hash is created and the S3 path for that `InputDataConfig Channel `__ object changes, initiating a new step run. For examples of what the S3 paths look like, see the **S3 Artifact Folder Structure** section. - :code:`inputs`: The inputs are also included in the training job’s `InputDataConfig `__. Local inputs are uploaded to S3. If the S3 path changes, a new training job is initiated. For examples of S3 paths, see the **S3 Artifact Folder Structure** section. +The following example creates a tuning step: + +.. code-block:: python + + from sagemaker.workflow.steps import TuningStep + from sagemaker.tuner import HyperparameterTuner + from sagemaker.estimator import Estimator + from sagemaker.inputs import TrainingInput + + model_path = f"s3://{default_bucket}/{base_job_prefix}/AbaloneTrain" + + xgb_train = Estimator( + image_uri=image_uri, + instance_type=training_instance_type, + instance_count=1, + output_path=model_path, + base_job_name=f"{base_job_prefix}/abalone-train", + sagemaker_session=pipeline_session, + role=role, + ) + + xgb_train.set_hyperparameters( + eval_metric="rmse", + objective="reg:squarederror", # Define the object metric for the training job + num_round=50, + max_depth=5, + eta=0.2, + gamma=4, + min_child_weight=6, + subsample=0.7, + silent=0, + ) + + objective_metric_name = "validation:rmse" + + hyperparameter_ranges = { + "alpha": ContinuousParameter(0.01, 10, scaling_type="Logarithmic"), + "lambda": ContinuousParameter(0.01, 10, scaling_type="Logarithmic"), + } + + tuner = HyperparameterTuner( + xgb_train, + objective_metric_name, + hyperparameter_ranges, + max_jobs=3, + max_parallel_jobs=3, + strategy="Random", + objective_type="Minimize", + ) + + hpo_args = tuner.fit( + inputs={ + "train": TrainingInput( + s3_data=step_process.properties.ProcessingOutputConfig.Outputs["train"].S3Output.S3Uri, + content_type="text/csv", + ), + "validation": TrainingInput( + s3_data=step_process.properties.ProcessingOutputConfig.Outputs[ + "validation" + ].S3Output.S3Uri, + content_type="text/csv", + ), + } + ) + + step_tuning = TuningStep( + name="HPTuning", + step_args=hpo_args, + cache_config=cache_config, + ) + +The following parameters from the example cause additional tuning (or training) step iterations when you change them: + +- :code:`image_uri`: The :code:`image_uri` parameter defines the image used for training, and is used directly in the `AlgorithmSpecification `__ attribute of the training job(s) that are created from the tuning job. +- :code:`hyperparameters`: All of the hyperparameters passed in the :code:`xgb_train.set_hyperparameters()` method are used directly in the `StaticHyperParameters `__ attribute for the tuning job. +- The following parameters are all included in the `HyperParameterTuningJobConfig `__ and if any one of them changes, a new tuning job is initiated: + - :code:`hyperparameter_ranges` + - :code:`objective_metric_name` + - :code:`max_jobs` + - :code:`max_parallel_jobs` + - :code:`strategy` + - :code:`objective_type` +- :code:`inputs`: The inputs are included in any training job’s `InputDataConfig `__ that get created from the tuning job. Local inputs are uploaded to S3. If the S3 path changes, a new tuning job is initiated. For examples of S3 paths, see the S3 Artifact Folder Structure section. + +The following examples creates a transform step: + +.. code-block:: python + + from sagemaker.transformer import Transformer + from sagemaker.inputs import TransformInput + from sagemaker.workflow.steps import TransformStep + + base_uri = f"s3://{default_bucket}/abalone" + batch_data_uri = sagemaker.s3.S3Uploader.upload( + local_path=local_path, + desired_s3_uri=base_uri, + ) + + batch_data = ParameterString( + name="BatchData", + default_value=batch_data_uri, + ) + + transformer = Transformer( + model_name=step_create_model.properties.ModelName, + instance_type="ml.m5.xlarge", + instance_count=1, + output_path=f"s3://{default_bucket}/AbaloneTransform", + env={ + 'class': 'Transformer' + } + ) + + step_transform = TransformStep( + name="AbaloneTransform", + step_args=transformer.transform( + data=batch_data, + data_type="S3Prefix" + ) + ) + +The following parameters from the example cause additional batch transform step iterations when you change them: + +- :code:`model_name`: The name of the SageMaker model being used for the transform job. +- :code:`env`: Environment variables to be set for use during the transform job. +- :code:`batch_data`: The input data will be included in the transform job’s `TransformInputfield `__. If the S3 path changes, a new transform job is initiated. + +The following example creates an automl step: + +.. code-block:: python + + from sagemaker.workflow.pipeline_context import PipelineSession + from sagemaker.workflow.automl_step import AutoMLStep + + pipeline_session = PipelineSession() + + auto_ml = AutoML(..., + role=role, + target_attribute_name="my_target_attribute_name", + mode="ENSEMBLING", + sagemaker_session=pipeline_session) + + input_training = AutoMLInput( + inputs="s3://amzn-s3-demo-bucket/my-training-data", + target_attribute_name="my_target_attribute_name", + channel_type="training", + ) + input_validation = AutoMLInput( + inputs="s3://amzn-s3-demo-bucket/my-validation-data", + target_attribute_name="my_target_attribute_name", + channel_type="validation", + ) + + step_args = auto_ml.fit( + inputs=[input_training, input_validation] + ) + + step_automl = AutoMLStep( + name="AutoMLStep", + step_args=step_args, + ) + + best_model = step_automl.get_best_auto_ml_model(role=) + +The following parameters from the example cause additional automl step iterations when you change them: + +- :code:`target_attribute_name`: The name of the target variable in supervised learning. +- :code:`mode`: The method that AutoML job uses to train the model - either AUTO, ENSEMBLING or HYPERPARAMETER_TUNING. +- :code:`inputs`: The inputs passed to the auto_ml.fit() method are included in the automl job’s `InputDataConfig `__. If the included S3 path(s) change, a new automl job is initiated. + +The following example creates an EMR step: + +.. code-block:: python + + from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig + + emr_config = EMRStepConfig( + jar="jar-location", # required, path to jar file used + args=["--verbose", "--force"], # optional list of arguments to pass to the jar + main_class="com.my.Main1", # optional main class, this can be omitted if jar above has a manifest + properties=[ # optional list of Java properties that are set when the step runs + { + "key": "mapred.tasktracker.map.tasks.maximum", + "value": "2" + }, + { + "key": "mapreduce.map.sort.spill.percent", + "value": "0.90" + }, + { + "key": "mapreduce.tasktracker.reduce.tasks.maximum", + "value": "5" + } + ] + ) + + step_emr = EMRStep( + name="EMRSampleStep", # required + cluster_id="j-1ABCDEFG2HIJK", # include cluster_id to use a running cluster + step_config=emr_config, # required + display_name="My EMR Step", + description="Pipeline step to execute EMR job" + ) + +The following parameters from the example cause additional EMR step iterations when you change them: + +- :code:`cluster_id`: The id of a running cluster to leverage for the EMR job. +- :code:`emr_config`: Configuration regarding the code that will run on the EMR cluster during the job. + +:class:`Note`: A :code:`cluster_config` parameter may also be passed into :code:`EMRStep` in order to spin up a new cluster. This parameter will also trigger additional step iterations if changed. + + S3 Artifact Folder Structure ---------------------------- diff --git a/doc/api/inference/model_builder.rst b/doc/api/inference/model_builder.rst index 3099441850..3cfbcbc2c7 100644 --- a/doc/api/inference/model_builder.rst +++ b/doc/api/inference/model_builder.rst @@ -3,14 +3,14 @@ Model Builder This module contains classes related to Amazon Sagemaker Model Builder -.. autoclass:: sagemaker.serve.builder.model_builder.ModelBuilder +.. autoclass:: sagemaker.serve.ModelBuilder -.. automethod:: sagemaker.serve.builder.model_builder.ModelBuilder.build +.. automethod:: sagemaker.serve.ModelBuilder.build -.. automethod:: sagemaker.serve.builder.model_builder.ModelBuilder.save +.. automethod:: sagemaker.serve.ModelBuilder.save -.. autoclass:: sagemaker.serve.spec.inference_spec.InferenceSpec +.. autoclass:: sagemaker.serve.InferenceSpec -.. autoclass:: sagemaker.serve.builder.schema_builder.SchemaBuilder +.. autoclass:: sagemaker.serve.SchemaBuilder -.. autoclass:: sagemaker.serve.marshalling.custom_payload_translator.CustomPayloadTranslator +.. autoclass:: sagemaker.serve.CustomPayloadTranslator diff --git a/doc/api/training/index.rst b/doc/api/training/index.rst index 0f61cd1931..285d9f266d 100644 --- a/doc/api/training/index.rst +++ b/doc/api/training/index.rst @@ -3,7 +3,7 @@ Training APIs ############# .. toctree:: - :maxdepth: 4 + :maxdepth: 1 model_trainer algorithm diff --git a/doc/conf.py b/doc/conf.py index 94a5c4d9c6..6c88ddd0e7 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -83,16 +83,11 @@ html_css_files = [ "https://cdn.datatables.net/1.10.23/css/jquery.dataTables.min.css", + "theme_overrides.css", + "pagination.css", + "search_accessories.css", ] -html_context = { - "css_files": [ - "_static/theme_overrides.css", - "_static/pagination.css", - "_static/search_accessories.css", - ] -} - # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = {"python": ("http://docs.python.org/", None)} diff --git a/doc/frameworks/pytorch/using_pytorch.rst b/doc/frameworks/pytorch/using_pytorch.rst index d415f38c27..4141dd84db 100644 --- a/doc/frameworks/pytorch/using_pytorch.rst +++ b/doc/frameworks/pytorch/using_pytorch.rst @@ -28,8 +28,6 @@ To train a PyTorch model by using the SageMaker Python SDK: Prepare a PyTorch Training Script ================================= -Your PyTorch training script must be a Python 3.6 compatible source file. - Prepare your script in a separate source file than the notebook, terminal session, or source file you're using to submit the script to SageMaker via a ``PyTorch`` Estimator. This will be discussed in further detail below. @@ -375,6 +373,9 @@ To initialize distributed training in your script, call `torch.distributed.init_process_group `_ with the desired backend and the rank of the current host. +Warning: Some torch features, such as (and likely not limited to) ``torch.nn.SyncBatchNorm`` +is not supported and its existence in ``init_process_group`` will cause an exception during +distributed training. .. code:: python diff --git a/doc/frameworks/tensorflow/deploying_tensorflow_serving.rst b/doc/frameworks/tensorflow/deploying_tensorflow_serving.rst index 1d7344fbbb..a645cd5a62 100644 --- a/doc/frameworks/tensorflow/deploying_tensorflow_serving.rst +++ b/doc/frameworks/tensorflow/deploying_tensorflow_serving.rst @@ -64,7 +64,7 @@ If you already have existing model artifacts in S3, you can skip training and de from sagemaker.tensorflow import TensorFlowModel - model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole') + model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole', framework_version='x.x.x') predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge') @@ -74,7 +74,7 @@ Python-based TensorFlow serving on SageMaker has support for `Elastic Inference from sagemaker.tensorflow import TensorFlowModel - model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole') + model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole', framework_version='x.x.x') predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge', accelerator_type='ml.eia1.medium') diff --git a/doc/frameworks/tensorflow/using_tf.rst b/doc/frameworks/tensorflow/using_tf.rst index 1e51b5f43a..5b888f95be 100644 --- a/doc/frameworks/tensorflow/using_tf.rst +++ b/doc/frameworks/tensorflow/using_tf.rst @@ -246,7 +246,7 @@ Training with parameter servers If you specify parameter_server as the value of the distribution parameter, the container launches a parameter server thread on each instance in the training cluster, and then executes your training code. You can find more information on -TensorFlow distributed training at `TensorFlow docs `__. +TensorFlow distributed training at `TensorFlow docs `__. To enable parameter server training: .. code:: python @@ -468,7 +468,7 @@ If you already have existing model artifacts in S3, you can skip training and de from sagemaker.tensorflow import TensorFlowModel - model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole') + model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole', framework_version='x.x.x') predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge') @@ -478,7 +478,7 @@ Python-based TensorFlow serving on SageMaker has support for `Elastic Inference from sagemaker.tensorflow import TensorFlowModel - model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole') + model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole', framework_version='x.x.x') predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge', accelerator_type='ml.eia1.medium') @@ -767,7 +767,8 @@ This customized Python code must be named ``inference.py`` and is specified thro model = TensorFlowModel(entry_point='inference.py', model_data='s3://mybucket/model.tar.gz', - role='MySageMakerRole') + role='MySageMakerRole', + framework_version='x.x.x') In the example above, ``inference.py`` is assumed to be a file inside ``model.tar.gz``. If you want to use a local file instead, you must add the ``source_dir`` argument. See the documentation on `TensorFlowModel `_. @@ -923,7 +924,8 @@ processing. There are 2 ways to do this: model = TensorFlowModel(entry_point='inference.py', dependencies=['requirements.txt'], model_data='s3://mybucket/model.tar.gz', - role='MySageMakerRole') + role='MySageMakerRole', + framework_version='x.x.x') 2. If you are working in a network-isolation situation or if you don't @@ -941,7 +943,8 @@ processing. There are 2 ways to do this: model = TensorFlowModel(entry_point='inference.py', dependencies=['/path/to/folder/named/lib'], model_data='s3://mybucket/model.tar.gz', - role='MySageMakerRole') + role='MySageMakerRole', + framework_version='x.x.x') For more information, see: https://github.com/aws/sagemaker-tensorflow-serving-container#prepost-processing diff --git a/doc/overview.rst b/doc/overview.rst index a1dc5c6918..26601900bd 100644 --- a/doc/overview.rst +++ b/doc/overview.rst @@ -30,6 +30,11 @@ To train a model by using the SageMaker Python SDK, you: After you train a model, you can save it, and then serve the model as an endpoint to get real-time inferences or get inferences for an entire dataset by using batch transform. + +Important Note: + +* When using torch to load Models, it is recommended to use version torch>=2.6.0 and torchvision>=0.17.0 + Prepare a Training script ========================= @@ -1958,7 +1963,7 @@ Make sure to have a Compose Version compatible with your Docker Engine installat Local mode configuration ======================== -The local mode uses a YAML configuration file located at ``~/.sagemaker/config.yaml`` to define the default values that are automatically passed to the ``config`` attribute of ``LocalSession``. This is an example of the configuration, for the full schema, see `sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA `_. +The local mode uses a YAML configuration file located at ``${user_config_directory}/sagemaker/config.yaml`` to define the default values that are automatically passed to the ``config`` attribute of ``LocalSession``. This is an example of the configuration, for the full schema, see `sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA `_. .. code:: yaml @@ -1966,7 +1971,7 @@ The local mode uses a YAML configuration file located at ``~/.sagemaker/config.y local_code: true # Using everything locally region_name: "us-west-2" # Name of the region container_config: # Additional docker container config - shm_size: "128M + shm_size: "128M" If you want to keep everything local, and not use Amazon S3 either, you can enable "local code" in one of two ways: @@ -2565,6 +2570,9 @@ set default values for. For the full schema, see `sagemaker.config.config_schema       KmsKeyId: 'kmskeyid10'     TransformResources:       VolumeKmsKeyId: 'volumekmskeyid4' + Tags: +     - Key: 'tag_key' +       Value: 'tag_value   CompilationJob:   # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateCompilationJob.html     OutputConfig: diff --git a/doc/requirements.txt b/doc/requirements.txt index 9bef9392a8..11098e2bc1 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1,8 +1,8 @@ -sphinx==5.1.1 -sphinx-rtd-theme==0.5.0 -docutils==0.15.2 -packaging==20.9 -jinja2==3.1.4 +sphinx==7.2.6 +sphinx-rtd-theme==3.0.0 +docutils>=0.18.1,<0.21 +packaging>=23.0,<25 +jinja2==3.1.6 schema==0.7.5 accelerate>=0.24.1,<=0.27.0 graphene<4.0 diff --git a/doc/v2.rst b/doc/v2.rst index 0677594b31..bca663af33 100644 --- a/doc/v2.rst +++ b/doc/v2.rst @@ -324,9 +324,9 @@ The follow serializer/deserializer classes have been renamed and/or moved: +--------------------------------------------------------+-------------------------------------------------------+ | ``sagemaker.predictor._NPYSerializer`` | ``sagemaker.serializers.NumpySerializer`` | +--------------------------------------------------------+-------------------------------------------------------+ -| ``sagemaker.amazon.common.numpy_to_record_serializer`` | ``sagemaker.amazon.common.RecordSerializer`` | +| ``sagemaker.amazon.common.numpy_to_record_serializer`` | ``sagemaker.serializers.RecordSerializer`` | +--------------------------------------------------------+-------------------------------------------------------+ -| ``sagemaker.amazon.common.record_deserializer`` | ``sagemaker.amazon.common.RecordDeserializer`` | +| ``sagemaker.amazon.common.record_deserializer`` | ``sagemaker.deserializers.RecordDeserializer`` | +--------------------------------------------------------+-------------------------------------------------------+ | ``sagemaker.predictor._JsonDeserializer`` | ``sagemaker.deserializers.JSONDeserializer`` | +--------------------------------------------------------+-------------------------------------------------------+ diff --git a/doc/workflows/step_functions/index.rst b/doc/workflows/step_functions/index.rst index a327d376a0..bfe9582341 100644 --- a/doc/workflows/step_functions/index.rst +++ b/doc/workflows/step_functions/index.rst @@ -11,5 +11,5 @@ without having to provision and integrate the AWS services separately. The AWS Step Functions Python SDK uses the SageMaker Python SDK as a dependency. To get started with step functions, try the workshop or visit the SDK's website: -* `Workshop on using AWS Step Functions with SageMaker `__ +* `Create and manage Amazon SageMaker AI jobs with Step Functions `__ * `AWS Step Functions Python SDK website `__ diff --git a/pyproject.toml b/pyproject.toml index 21eb185090..c5c9bf9874 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "sagemaker" dynamic = ["version", "optional-dependencies"] description = "Open source library for training and deploying models on Amazon SageMaker." readme = "README.rst" -requires-python = ">=3.8" +requires-python = ">=3.9" authors = [ { name = "Amazon Web Services" }, ] @@ -25,10 +25,10 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Natural Language :: English", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] dependencies = [ "attrs>=23.1.0,<24", @@ -39,15 +39,15 @@ dependencies = [ "google-pasta", "importlib-metadata>=1.4.0,<7.0", "jsonschema", - "numpy>=1.9.0,<2.0", - "omegaconf>=2.2,<2.3", - "packaging>=20.0", + "numpy==1.26.4", + "omegaconf>=2.2,<=2.3", + "packaging>=23.0,<25", "pandas", "pathos", "platformdirs", "protobuf>=3.12,<6.0", "psutil", - "PyYAML~=6.0", + "PyYAML>=6.0.1", "requests", "sagemaker-core>=1.0.17,<2.0.0", "schema", @@ -73,6 +73,7 @@ pattern = "(?P.+)" [tool.hatch.metadata.hooks.custom] [tool.hatch.build.targets.wheel] +core-metadata-version = "2.3" packages = ["src/sagemaker"] exclude = ["src/sagemaker/serve/model_server/triton/pack_conda_env.sh"] @@ -80,6 +81,7 @@ exclude = ["src/sagemaker/serve/model_server/triton/pack_conda_env.sh"] "src/sagemaker/serve/model_server/triton/pack_conda_env.sh" = "pack_conda_env.sh" [tool.hatch.build.targets.sdist] +core-metadata-version = "2.3" only-include = [ "/requirements/extras", "/src", diff --git a/requirements/extras/local_requirements.txt b/requirements/extras/local_requirements.txt index 68b9a1bcb3..ea57b82e9a 100644 --- a/requirements/extras/local_requirements.txt +++ b/requirements/extras/local_requirements.txt @@ -1,3 +1,3 @@ urllib3>=1.26.8,<3.0.0 docker>=5.0.2,<8.0.0 -PyYAML>=5.4.1,<7 +PyYAML>=6.0.1,<7 diff --git a/requirements/extras/scipy_requirements.txt b/requirements/extras/scipy_requirements.txt index 0e99587e6e..44ce1d9331 100644 --- a/requirements/extras/scipy_requirements.txt +++ b/requirements/extras/scipy_requirements.txt @@ -1 +1 @@ -scipy==1.10.1 +scipy==1.11.3 diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index fe31300c22..3e6200ee3e 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -1,7 +1,7 @@ tox==3.24.5 -numpy>=1.24.0 +numpy==1.26.4 build[virtualenv]==1.2.1 -flake8==4.0.1 +flake8==7.1.2 pytest==6.2.5 pytest-cov==3.0.0 pytest-rerunfailures==10.2 @@ -14,26 +14,26 @@ awslogs==0.14.0 black==24.3.0 stopit==1.1.2 # Update tox.ini to have correct version of airflow constraints file -apache-airflow==2.9.3 +apache-airflow==2.10.4 apache-airflow-providers-amazon==7.2.1 attrs>=23.1.0,<24 -fabric==2.6.0 +fabric==3.2.2 requests==2.32.2 sagemaker-experiments==0.1.35 -Jinja2==3.1.4 +Jinja2==3.1.6 pyvis==0.2.1 pandas==1.4.4 scikit-learn==1.3.0 cloudpickle==2.2.1 jsonpickle<4.0.0 -PyYAML==6.0 +PyYAML>=6.0.1 # TODO find workaround xgboost>=1.6.2,<=1.7.6 pillow>=10.0.1,<=11 opentelemetry-proto==1.27.0 protobuf==4.25.5 -tensorboard>=2.9.0,<=2.15.2 -transformers==4.46.1 +tensorboard>=2.16.2,<=2.18.0 +transformers==4.48.0 sentencepiece==0.1.99 # https://github.com/triton-inference-server/server/issues/6246 tritonclient[http]<2.37.0 @@ -42,7 +42,7 @@ onnx==1.17.0 nbformat>=5.9,<6 accelerate>=0.24.1,<=0.27.0 schema==0.7.5 -tensorflow>=2.9.0,<=2.15.1 +tensorflow>=2.16.2,<=2.18.0 mlflow>=2.12.2,<2.13 huggingface_hub==0.26.2 uvicorn>=0.30.1 @@ -50,3 +50,5 @@ fastapi==0.115.4 nest-asyncio sagemaker-mlflow>=0.1.0 deepdiff>=8.0.0 +orderly-set<5.4.0 +lexicon diff --git a/requirements/tox/doc8_requirements.txt b/requirements/tox/doc8_requirements.txt index e4a040dd4d..8707c06621 100644 --- a/requirements/tox/doc8_requirements.txt +++ b/requirements/tox/doc8_requirements.txt @@ -1,2 +1,2 @@ -doc8==0.10.1 -Pygments==2.15.0 +doc8==1.1.2 +Pygments==2.18.0 diff --git a/requirements/tox/flake8_requirements.txt b/requirements/tox/flake8_requirements.txt index b3ccfca84f..63a79da444 100644 --- a/requirements/tox/flake8_requirements.txt +++ b/requirements/tox/flake8_requirements.txt @@ -1,2 +1,2 @@ -flake8==4.0.1 -flake8-future-import==0.4.6 +flake8==7.1.2 +flake8-future-import==0.4.7 diff --git a/requirements/tox/pylint_requirements.txt b/requirements/tox/pylint_requirements.txt index b307f21762..0e5db209fe 100644 --- a/requirements/tox/pylint_requirements.txt +++ b/requirements/tox/pylint_requirements.txt @@ -1,2 +1,2 @@ -pylint==2.6.2 -astroid==2.4.2 +pylint==3.0.3 +astroid==3.0.2 diff --git a/requirements/tox/spelling_requirements.txt b/requirements/tox/spelling_requirements.txt index 769415eb2c..94d6bc314e 100644 --- a/requirements/tox/spelling_requirements.txt +++ b/requirements/tox/spelling_requirements.txt @@ -1,2 +1,2 @@ pyenchant==3.2.2 -pylint==2.6.2 +pylint==3.0.3 diff --git a/src/sagemaker/_studio.py b/src/sagemaker/_studio.py index a23fae87e9..22f1c94c5f 100644 --- a/src/sagemaker/_studio.py +++ b/src/sagemaker/_studio.py @@ -65,7 +65,10 @@ def _find_config(working_dir=None): wd = Path(working_dir) if working_dir else Path.cwd() path = None - while path is None and not wd.match("/"): + + # Get the root of the current working directory for both Windows and Unix-like systems + root = Path(wd.anchor) + while path is None and wd != root: candidate = wd / STUDIO_PROJECT_CONFIG if Path.exists(candidate): path = candidate diff --git a/src/sagemaker/amazon/common.py b/src/sagemaker/amazon/common.py index 4632bda628..fc5d355749 100644 --- a/src/sagemaker/amazon/common.py +++ b/src/sagemaker/amazon/common.py @@ -13,282 +13,13 @@ """Placeholder docstring""" from __future__ import absolute_import -import io -import logging -import struct -import sys - -import numpy as np - -from sagemaker.amazon.record_pb2 import Record -from sagemaker.deprecations import deprecated_class -from sagemaker.deserializers import SimpleBaseDeserializer -from sagemaker.serializers import SimpleBaseSerializer -from sagemaker.utils import DeferredError - - -class RecordSerializer(SimpleBaseSerializer): - """Serialize a NumPy array for an inference request.""" - - def __init__(self, content_type="application/x-recordio-protobuf"): - """Initialize a ``RecordSerializer`` instance. - - Args: - content_type (str): The MIME type to signal to the inference endpoint when sending - request data (default: "application/x-recordio-protobuf"). - """ - super(RecordSerializer, self).__init__(content_type=content_type) - - def serialize(self, data): - """Serialize a NumPy array into a buffer containing RecordIO records. - - Args: - data (numpy.ndarray): The data to serialize. - - Returns: - io.BytesIO: A buffer containing the data serialized as records. - """ - if len(data.shape) == 1: - data = data.reshape(1, data.shape[0]) - - if len(data.shape) != 2: - raise ValueError( - "Expected a 1D or 2D array, but got a %dD array instead." % len(data.shape) - ) - - buffer = io.BytesIO() - write_numpy_to_dense_tensor(buffer, data) - buffer.seek(0) - - return buffer - - -class RecordDeserializer(SimpleBaseDeserializer): - """Deserialize RecordIO Protobuf data from an inference endpoint.""" - - def __init__(self, accept="application/x-recordio-protobuf"): - """Initialize a ``RecordDeserializer`` instance. - - Args: - accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that - is expected from the inference endpoint (default: - "application/x-recordio-protobuf"). - """ - super(RecordDeserializer, self).__init__(accept=accept) - - def deserialize(self, data, content_type): - """Deserialize RecordIO Protobuf data from an inference endpoint. - - Args: - data (object): The protobuf message to deserialize. - content_type (str): The MIME type of the data. - Returns: - list: A list of records. - """ - try: - return read_records(data) - finally: - data.close() - - -def _write_feature_tensor(resolved_type, record, vector): - """Placeholder Docstring""" - if resolved_type == "Int32": - record.features["values"].int32_tensor.values.extend(vector) - elif resolved_type == "Float64": - record.features["values"].float64_tensor.values.extend(vector) - elif resolved_type == "Float32": - record.features["values"].float32_tensor.values.extend(vector) - - -def _write_label_tensor(resolved_type, record, scalar): - """Placeholder Docstring""" - if resolved_type == "Int32": - record.label["values"].int32_tensor.values.extend([scalar]) - elif resolved_type == "Float64": - record.label["values"].float64_tensor.values.extend([scalar]) - elif resolved_type == "Float32": - record.label["values"].float32_tensor.values.extend([scalar]) - - -def _write_keys_tensor(resolved_type, record, vector): - """Placeholder Docstring""" - if resolved_type == "Int32": - record.features["values"].int32_tensor.keys.extend(vector) - elif resolved_type == "Float64": - record.features["values"].float64_tensor.keys.extend(vector) - elif resolved_type == "Float32": - record.features["values"].float32_tensor.keys.extend(vector) - - -def _write_shape(resolved_type, record, scalar): - """Placeholder Docstring""" - if resolved_type == "Int32": - record.features["values"].int32_tensor.shape.extend([scalar]) - elif resolved_type == "Float64": - record.features["values"].float64_tensor.shape.extend([scalar]) - elif resolved_type == "Float32": - record.features["values"].float32_tensor.shape.extend([scalar]) - - -def write_numpy_to_dense_tensor(file, array, labels=None): - """Writes a numpy array to a dense tensor - - Args: - file: - array: - labels: - """ - - # Validate shape of array and labels, resolve array and label types - if not len(array.shape) == 2: - raise ValueError("Array must be a Matrix") - if labels is not None: - if not len(labels.shape) == 1: - raise ValueError("Labels must be a Vector") - if labels.shape[0] not in array.shape: - raise ValueError( - "Label shape {} not compatible with array shape {}".format( - labels.shape, array.shape - ) - ) - resolved_label_type = _resolve_type(labels.dtype) - resolved_type = _resolve_type(array.dtype) - - # Write each vector in array into a Record in the file object - record = Record() - for index, vector in enumerate(array): - record.Clear() - _write_feature_tensor(resolved_type, record, vector) - if labels is not None: - _write_label_tensor(resolved_label_type, record, labels[index]) - _write_recordio(file, record.SerializeToString()) - - -def write_spmatrix_to_sparse_tensor(file, array, labels=None): - """Writes a scipy sparse matrix to a sparse tensor - - Args: - file: - array: - labels: - """ - try: - import scipy - except ImportError as e: - logging.warning( - "scipy failed to import. Sparse matrix functions will be impaired or broken." - ) - # Any subsequent attempt to use scipy will raise the ImportError - scipy = DeferredError(e) - - if not scipy.sparse.issparse(array): - raise TypeError("Array must be sparse") - - # Validate shape of array and labels, resolve array and label types - if not len(array.shape) == 2: - raise ValueError("Array must be a Matrix") - if labels is not None: - if not len(labels.shape) == 1: - raise ValueError("Labels must be a Vector") - if labels.shape[0] not in array.shape: - raise ValueError( - "Label shape {} not compatible with array shape {}".format( - labels.shape, array.shape - ) - ) - resolved_label_type = _resolve_type(labels.dtype) - resolved_type = _resolve_type(array.dtype) - - csr_array = array.tocsr() - n_rows, n_cols = csr_array.shape - - record = Record() - for row_idx in range(n_rows): - record.Clear() - row = csr_array.getrow(row_idx) - # Write values - _write_feature_tensor(resolved_type, record, row.data) - # Write keys - _write_keys_tensor(resolved_type, record, row.indices.astype(np.uint64)) - - # Write labels - if labels is not None: - _write_label_tensor(resolved_label_type, record, labels[row_idx]) - - # Write shape - _write_shape(resolved_type, record, n_cols) - - _write_recordio(file, record.SerializeToString()) - - -def read_records(file): - """Eagerly read a collection of amazon Record protobuf objects from file. - - Args: - file: - """ - records = [] - for record_data in read_recordio(file): - record = Record() - record.ParseFromString(record_data) - records.append(record) - return records - - -# MXNet requires recordio records have length in bytes that's a multiple of 4 -# This sets up padding bytes to append to the end of the record, for diferent -# amounts of padding required. -padding = {} -for amount in range(4): - if sys.version_info >= (3,): - padding[amount] = bytes([0x00 for _ in range(amount)]) - else: - padding[amount] = bytearray([0x00 for _ in range(amount)]) - -_kmagic = 0xCED7230A - - -def _write_recordio(f, data): - """Writes a single data point as a RecordIO record to the given file. - - Args: - f: - data: - """ - length = len(data) - f.write(struct.pack("I", _kmagic)) - f.write(struct.pack("I", length)) - pad = (((length + 3) >> 2) << 2) - length - f.write(data) - f.write(padding[pad]) - - -def read_recordio(f): - """Placeholder Docstring""" - while True: - try: - (read_kmagic,) = struct.unpack("I", f.read(4)) - except struct.error: - return - assert read_kmagic == _kmagic - (len_record,) = struct.unpack("I", f.read(4)) - pad = (((len_record + 3) >> 2) << 2) - len_record - yield f.read(len_record) - if pad: - f.read(pad) - - -def _resolve_type(dtype): - """Placeholder Docstring""" - if dtype == np.dtype(int): - return "Int32" - if dtype == np.dtype(float): - return "Float64" - if dtype == np.dtype("float32"): - return "Float32" - raise ValueError("Unsupported dtype {} on array".format(dtype)) - - -numpy_to_record_serializer = deprecated_class(RecordSerializer, "numpy_to_record_serializer") -record_deserializer = deprecated_class(RecordDeserializer, "record_deserializer") +# these imports ensure backward compatibility. +from sagemaker.deserializers import RecordDeserializer # noqa: F401 # pylint: disable=W0611 +from sagemaker.serializers import RecordSerializer # noqa: F401 # pylint: disable=W0611 +from sagemaker.serializer_utils import ( # noqa: F401 # pylint: disable=W0611 + read_recordio, + read_records, + write_numpy_to_dense_tensor, + write_spmatrix_to_sparse_tensor, + _write_recordio, +) diff --git a/src/sagemaker/amazon/factorization_machines.py b/src/sagemaker/amazon/factorization_machines.py index 2b24356ee9..1149cd02b2 100644 --- a/src/sagemaker/amazon/factorization_machines.py +++ b/src/sagemaker/amazon/factorization_machines.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import gt, isin, ge +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT diff --git a/src/sagemaker/amazon/hyperparameter.py b/src/sagemaker/amazon/hyperparameter.py index 856927cb13..b479f8a271 100644 --- a/src/sagemaker/amazon/hyperparameter.py +++ b/src/sagemaker/amazon/hyperparameter.py @@ -28,7 +28,7 @@ def __init__(self, name, validate=lambda _: True, validation_message="", data_ty """Args: name (str): The name of this hyperparameter validate - (callable[object]->[bool]): A validation function or list of validation + (Callable[object]->[bool]): A validation function or list of validation functions. Each function validates an object and returns False if the object diff --git a/src/sagemaker/amazon/ipinsights.py b/src/sagemaker/amazon/ipinsights.py index 737d13dd44..bc8e1b5d86 100644 --- a/src/sagemaker/amazon/ipinsights.py +++ b/src/sagemaker/amazon/ipinsights.py @@ -209,7 +209,7 @@ def __init__( chain. serializer (sagemaker.serializers.BaseSerializer): Optional. Default serializes input data to text/csv. - deserializer (callable): Optional. Default parses JSON responses + deserializer (Callable): Optional. Default parses JSON responses using ``json.load(...)``. component_name (str): Optional. Name of the Amazon SageMaker inference component corresponding the predictor. diff --git a/src/sagemaker/amazon/kmeans.py b/src/sagemaker/amazon/kmeans.py index 144cdc934a..25abb9cb27 100644 --- a/src/sagemaker/amazon/kmeans.py +++ b/src/sagemaker/amazon/kmeans.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import gt, isin, ge, le +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT diff --git a/src/sagemaker/amazon/knn.py b/src/sagemaker/amazon/knn.py index f9c73381b4..89ec979e09 100644 --- a/src/sagemaker/amazon/knn.py +++ b/src/sagemaker/amazon/knn.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import ge, isin +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT diff --git a/src/sagemaker/amazon/lda.py b/src/sagemaker/amazon/lda.py index bd64d3ae2e..c57da9643e 100644 --- a/src/sagemaker/amazon/lda.py +++ b/src/sagemaker/amazon/lda.py @@ -18,11 +18,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer +from sagemaker.deserializers import RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import gt from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT diff --git a/src/sagemaker/amazon/linear_learner.py b/src/sagemaker/amazon/linear_learner.py index 695eb31dc1..4533dcdaea 100644 --- a/src/sagemaker/amazon/linear_learner.py +++ b/src/sagemaker/amazon/linear_learner.py @@ -18,11 +18,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer +from sagemaker.deserializers import RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import isin, gt, lt, ge, le from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT diff --git a/src/sagemaker/amazon/ntm.py b/src/sagemaker/amazon/ntm.py index 4267ac8969..41dde1c33c 100644 --- a/src/sagemaker/amazon/ntm.py +++ b/src/sagemaker/amazon/ntm.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import ge, le, isin +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT diff --git a/src/sagemaker/amazon/pca.py b/src/sagemaker/amazon/pca.py index 953fff9d0b..b724435afa 100644 --- a/src/sagemaker/amazon/pca.py +++ b/src/sagemaker/amazon/pca.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import gt, isin +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT diff --git a/src/sagemaker/amazon/randomcutforest.py b/src/sagemaker/amazon/randomcutforest.py index 21d98741b0..d60d5a7741 100644 --- a/src/sagemaker/amazon/randomcutforest.py +++ b/src/sagemaker/amazon/randomcutforest.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import ge, le +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT diff --git a/src/sagemaker/automl/automl.py b/src/sagemaker/automl/automl.py index bb4059c03a..e18d7ba2b9 100644 --- a/src/sagemaker/automl/automl.py +++ b/src/sagemaker/automl/automl.py @@ -478,7 +478,7 @@ def create_model( training cluster for distributed training. Default: False model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file if the model is repacked - predictor_cls (callable[string, sagemaker.session.Session]): A + Callable[[string, sagemaker.session.Session], Any]: A function to call to create a predictor (default: None). If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. @@ -591,7 +591,7 @@ def deploy( training cluster for distributed training. Default: False model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file if the model is repacked - predictor_cls (callable[string, sagemaker.session.Session]): A + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor (default: None). If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. @@ -609,7 +609,7 @@ def deploy( https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests Returns: - callable[string, sagemaker.session.Session] or ``None``: + Optional[Callable[[string, sagemaker.session.Session], Any]]: If ``predictor_cls`` is specified, the invocation of ``self.predictor_cls`` on the created endpoint name. Otherwise, ``None``. """ diff --git a/src/sagemaker/automl/automlv2.py b/src/sagemaker/automl/automlv2.py index 0819e5384e..b071be3b24 100644 --- a/src/sagemaker/automl/automlv2.py +++ b/src/sagemaker/automl/automlv2.py @@ -1022,7 +1022,7 @@ def create_model( training cluster for distributed training. Default: False model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file if the model is repacked - predictor_cls (callable[string, sagemaker.session.Session]): A + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor (default: None). If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. @@ -1130,7 +1130,7 @@ def deploy( training cluster for distributed training. Default: False model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file if the model is repacked - predictor_cls (callable[string, sagemaker.session.Session]): A + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor (default: None). If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. @@ -1148,7 +1148,7 @@ def deploy( https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests Returns: - callable[string, sagemaker.session.Session] or ``None``: + Optional[Callable[[string, sagemaker.session.Session], Any]]: If ``predictor_cls`` is specified, the invocation of ``self.predictor_cls`` on the created endpoint name. Otherwise, ``None``. """ diff --git a/src/sagemaker/base_deserializers.py b/src/sagemaker/base_deserializers.py index a152f0144d..ded68fc4b0 100644 --- a/src/sagemaker/base_deserializers.py +++ b/src/sagemaker/base_deserializers.py @@ -23,6 +23,7 @@ import numpy as np from six import with_metaclass +from sagemaker.serializer_utils import read_records from sagemaker.utils import DeferredError try: @@ -388,3 +389,31 @@ def deserialize(self, stream, content_type="tensor/pt"): "Unable to deserialize your data to torch.Tensor.\ Please provide custom deserializer in InferenceSpec." ) + + +class RecordDeserializer(SimpleBaseDeserializer): + """Deserialize RecordIO Protobuf data from an inference endpoint.""" + + def __init__(self, accept="application/x-recordio-protobuf"): + """Initialize a ``RecordDeserializer`` instance. + + Args: + accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that + is expected from the inference endpoint (default: + "application/x-recordio-protobuf"). + """ + super(RecordDeserializer, self).__init__(accept=accept) + + def deserialize(self, data, content_type): + """Deserialize RecordIO Protobuf data from an inference endpoint. + + Args: + data (object): The protobuf message to deserialize. + content_type (str): The MIME type of the data. + Returns: + list: A list of records. + """ + try: + return read_records(data) + finally: + data.close() diff --git a/src/sagemaker/base_serializers.py b/src/sagemaker/base_serializers.py index 45fea23493..0e1df120ff 100644 --- a/src/sagemaker/base_serializers.py +++ b/src/sagemaker/base_serializers.py @@ -22,6 +22,7 @@ from pandas import DataFrame from six import with_metaclass +from sagemaker.serializer_utils import write_numpy_to_dense_tensor from sagemaker.utils import DeferredError try: @@ -466,3 +467,39 @@ def serialize(self, data): ) raise ValueError("Object of type %s is not a torch.Tensor" % type(data)) + + +class RecordSerializer(SimpleBaseSerializer): + """Serialize a NumPy array for an inference request.""" + + def __init__(self, content_type="application/x-recordio-protobuf"): + """Initialize a ``RecordSerializer`` instance. + + Args: + content_type (str): The MIME type to signal to the inference endpoint when sending + request data (default: "application/x-recordio-protobuf"). + """ + super(RecordSerializer, self).__init__(content_type=content_type) + + def serialize(self, data): + """Serialize a NumPy array into a buffer containing RecordIO records. + + Args: + data (numpy.ndarray): The data to serialize. + + Returns: + io.BytesIO: A buffer containing the data serialized as records. + """ + if len(data.shape) == 1: + data = data.reshape(1, data.shape[0]) + + if len(data.shape) != 2: + raise ValueError( + "Expected a 1D or 2D array, but got a %dD array instead." % len(data.shape) + ) + + buffer = io.BytesIO() + write_numpy_to_dense_tensor(buffer, data) + buffer.seek(0) + + return buffer diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index 806009b0f6..c2d2187b69 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Optional, Union, List, Dict +from typing import Callable, Optional, Union, List, Dict import sagemaker from sagemaker import image_uris, ModelMetrics @@ -96,7 +96,7 @@ def __init__( image_uri: Optional[Union[str, PipelineVariable]] = None, framework_version: Optional[str] = None, py_version: Optional[str] = None, - predictor_cls: callable = ChainerPredictor, + predictor_cls: Optional[Callable] = ChainerPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs, ): @@ -125,7 +125,7 @@ def __init__( py_version (str): Python version you want to use for executing your model training code. Defaults to ``None``. Required unless ``image_uri`` is provided. - predictor_cls (callable[str, sagemaker.session.Session]): A function + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py index 0e2aabbec4..54bccba55e 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py @@ -51,8 +51,8 @@ "StreamDeserializer": ("sagemaker.deserializers",), "NumpyDeserializer": ("sagemaker.deserializers",), "JSONDeserializer": ("sagemaker.deserializers",), - "RecordSerializer ": ("sagemaker.amazon.common",), - "RecordDeserializer": ("sagemaker.amazon.common",), + "RecordSerializer ": ("sagemaker.serializers",), + "RecordDeserializer": ("sagemaker.deserializers",), } OLD_CLASS_NAME_TO_NEW_CLASS_NAME = { @@ -101,8 +101,8 @@ def node_should_be_modified(self, node): - ``sagemaker.predictor.StreamDeserializer`` - ``sagemaker.predictor._NumpyDeserializer`` - ``sagemaker.predictor._JsonDeserializer`` - - ``sagemaker.amazon.common.numpy_to_record_serializer`` - - ``sagemaker.amazon.common.record_deserializer`` + - ``sagemaker.serializers.numpy_to_record_serializer`` + - ``sagemaker.deserializers.record_deserializer`` Args: node (ast.Call): a node that represents a function call. For more, @@ -128,8 +128,8 @@ def modify_node(self, node): - ``sagemaker.deserializers.StreamDeserializer`` - ``sagemaker.deserializers.NumpyDeserializer`` - ``sagemaker.deserializers._JsonDeserializer`` - - ``sagemaker.amazon.common.RecordSerializer`` - - ``sagemaker.amazon.common.RecordDeserializer`` + - ``sagemaker.serializers.RecordSerializer`` + - ``sagemaker.deserializers.RecordDeserializer`` Args: node (ast.Call): a node that represents a SerDe constructor. @@ -303,8 +303,8 @@ def node_should_be_modified(self, node): """Checks if the import statement imports a SerDe from the ``sagemaker.amazon.common``. This checks for: - - ``sagemaker.amazon.common.numpy_to_record_serializer`` - - ``sagemaker.amazon.common.record_deserializer`` + - ``sagemaker.serializers.numpy_to_record_serializer`` + - ``sagemaker.deserializers.record_deserializer`` Args: node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. @@ -322,8 +322,8 @@ def modify_node(self, node): """Upgrades the ``numpy_to_record_serializer`` and ``record_deserializer`` imports. This upgrades the classes to (if applicable): - - ``sagemaker.amazon.common.RecordSerializer`` - - ``sagemaker.amazon.common.RecordDeserializer`` + - ``sagemaker.serializers.RecordSerializer`` + - ``sagemaker.deserializers.RecordDeserializer`` Args: node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. diff --git a/src/sagemaker/config/config_schema.py b/src/sagemaker/config/config_schema.py index 34a98c0b8e..61da17e7cf 100644 --- a/src/sagemaker/config/config_schema.py +++ b/src/sagemaker/config/config_schema.py @@ -540,7 +540,8 @@ def _simple_path(*args: str): "minItems": 0, "maxItems": 50, }, - # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html#sagemaker-CreateTrainingJob-request-Environment + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/ + # API_CreateTrainingJob.html#sagemaker-CreateTrainingJob-request-Environment "environmentVariables": { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, @@ -553,13 +554,15 @@ def _simple_path(*args: str): }, "maxProperties": 48, }, - # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3DataSource.html#sagemaker-Type-S3DataSource-S3Uri + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/ + # API_S3DataSource.html#sagemaker-Type-S3DataSource-S3Uri "s3Uri": { TYPE: "string", "pattern": "^(https|s3)://([^/]+)/?(.*)$", "maxLength": 1024, }, - # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html#sagemaker-Type-AlgorithmSpecification-ContainerEntrypoint + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/ + # API_AlgorithmSpecification.html#sagemaker-Type-AlgorithmSpecification-ContainerEntrypoint "preExecutionCommand": {TYPE: "string", "pattern": r".*"}, # Regex based on https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_PipelineDefinitionS3Location.html # except with an additional ^ and $ for the beginning and the end to closer align to @@ -570,7 +573,8 @@ def _simple_path(*args: str): "minLength": 3, "maxLength": 63, }, - # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_MonitoringJobDefinition.html#sagemaker-Type-MonitoringJobDefinition-Environment + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/ + # API_MonitoringJobDefinition.html#sagemaker-Type-MonitoringJobDefinition-Environment "environment-Length256-Properties50": { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, @@ -583,7 +587,8 @@ def _simple_path(*args: str): }, "maxProperties": 50, }, - # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTransformJob.html#sagemaker-CreateTransformJob-request-Environment + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/ + # API_CreateTransformJob.html#sagemaker-CreateTransformJob-request-Environment "environment-Length10240-Properties16": { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, @@ -596,7 +601,8 @@ def _simple_path(*args: str): }, "maxProperties": 16, }, - # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ContainerDefinition.html#sagemaker-Type-ContainerDefinition-Environment + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/ + # API_ContainerDefinition.html#sagemaker-Type-ContainerDefinition-Environment "environment-Length1024-Properties16": { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, @@ -609,7 +615,8 @@ def _simple_path(*args: str): }, "maxProperties": 16, }, - # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html#sagemaker-CreateProcessingJob-request-Environment + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/ + # API_CreateProcessingJob.html#sagemaker-CreateProcessingJob-request-Environment "environment-Length256-Properties100": { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, @@ -622,7 +629,8 @@ def _simple_path(*args: str): }, "maxProperties": 100, }, - # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html#sagemaker-CreateTrainingJob-request-Environment + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/ + # API_CreateTrainingJob.html#sagemaker-CreateTrainingJob-request-Environment "environment-Length512-Properties48": { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 957a9dfb0c..dad5137329 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -31,8 +31,10 @@ StreamDeserializer, StringDeserializer, TorchTensorDeserializer, + RecordDeserializer, ) +from sagemaker.deprecations import deprecated_class from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.enums import JumpStartModelType @@ -150,3 +152,6 @@ def retrieve_default( model_type=model_type, config_name=config_name, ) + + +record_deserializer = deprecated_class(RecordDeserializer, "record_deserializer") diff --git a/src/sagemaker/djl_inference/model.py b/src/sagemaker/djl_inference/model.py index 8c724a6502..94db4efe29 100644 --- a/src/sagemaker/djl_inference/model.py +++ b/src/sagemaker/djl_inference/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Optional, Dict, Any +from typing import Callable, Optional, Dict, Any from sagemaker import image_uris from sagemaker.model import Model @@ -54,7 +54,7 @@ def __init__( parallel_loading: bool = False, model_loading_timeout: Optional[int] = None, prediction_timeout: Optional[int] = None, - predictor_cls: callable = DJLPredictor, + predictor_cls: Optional[Callable] = DJLPredictor, huggingface_hub_token: Optional[str] = None, **kwargs, ): @@ -97,10 +97,10 @@ def __init__( None. If not provided, the default is 240 seconds. prediction_timeout (int): The worker predict call (handler) timeout in seconds. Defaults to None. If not provided, the default is 120 seconds. - predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a - predictor with an endpoint name and SageMaker ``Session``. If specified, - ``deploy()`` returns - the result of invoking this function on the created endpoint name. + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call + to create a predictor with an endpoint name and SageMaker ``Session``. If + specified, ``deploy()`` returns the result of invoking this function on the created + endpoint name. huggingface_hub_token (str): The HuggingFace Hub token to use for downloading the model artifacts for a model stored on the huggingface hub. Defaults to None. If not provided, the token must be specified in the diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 6efc04c88e..fa40719c9f 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -387,8 +387,8 @@ def __init__( source_dir (str or PipelineVariable): The absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. The structure within this directory is preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. The structure within this directory + is preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. With the following GitHub repo directory structure: @@ -2550,7 +2550,6 @@ def _get_train_args(cls, estimator, inputs, experiment_config): raise ValueError( "File URIs are supported in local mode only. Please use a S3 URI instead." ) - config = _Job._load_config(inputs, estimator) current_hyperparameters = estimator.hyperparameters() @@ -3421,8 +3420,8 @@ def __init__( source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + are preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. diff --git a/src/sagemaker/experiments/_metrics.py b/src/sagemaker/experiments/_metrics.py index 31dd679cc8..026e73e8a6 100644 --- a/src/sagemaker/experiments/_metrics.py +++ b/src/sagemaker/experiments/_metrics.py @@ -197,8 +197,8 @@ def _send_metrics(self, metrics): response = self._metrics_client.batch_put_metrics(**request) errors = response["Errors"] if "Errors" in response else None if errors: - message = errors[0]["Message"] - raise Exception(f'{len(errors)} errors with message "{message}"') + error_code = errors[0]["Code"] + raise Exception(f'{len(errors)} errors with error code "{error_code}"') def _construct_batch_put_metrics_request(self, batch): """Creates dictionary object used as request to metrics service.""" diff --git a/src/sagemaker/feature_store/dataset_builder.py b/src/sagemaker/feature_store/dataset_builder.py index 289fa1ee0c..fc9f9372b1 100644 --- a/src/sagemaker/feature_store/dataset_builder.py +++ b/src/sagemaker/feature_store/dataset_builder.py @@ -929,7 +929,7 @@ def _construct_query_string(self, base: FeatureGroupToBeMerged) -> str: selected_features += ", " selected_features += ", ".join( [ - f'fg_{i}."{feature_name}" as "{feature_name}.{(i+1)}"' + f'fg_{i}."{feature_name}" as "{feature_name}.{(i + 1)}"' for feature_name in feature_group.projected_feature_names ] ) diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index 39915b60dc..4eb8d82b0c 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -631,7 +631,7 @@ def __str__(self) -> str: class FeatureGroup: """FeatureGroup definition. - This class instantiates a FeatureGroup object that comprises of a name for the FeatureGroup, + This class instantiates a FeatureGroup object that comprises a name for the FeatureGroup, session instance, and a list of feature definition objects i.e., FeatureDefinition. Attributes: diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 0ddb3cd255..234f0c61fa 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -10,30 +10,29 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Utility methods used by framework classes""" +"""Utility methods used by framework classes.""" from __future__ import absolute_import import json import logging import os import re -import time import shutil import tempfile +import time from collections import namedtuple -from typing import List, Optional, Union, Dict +from typing import Dict, List, Optional, Union + from packaging import version import sagemaker.image_uris +import sagemaker.utils +from sagemaker.deprecations import deprecation_warn_base, renamed_kwargs, renamed_warning from sagemaker.instance_group import InstanceGroup from sagemaker.s3_utils import s3_path_join from sagemaker.session_settings import SessionSettings -import sagemaker.utils from sagemaker.workflow import is_pipeline_variable - -from sagemaker.deprecations import renamed_warning, renamed_kwargs from sagemaker.workflow.entities import PipelineVariable -from sagemaker.deprecations import deprecation_warn_base logger = logging.getLogger(__name__) @@ -41,6 +40,7 @@ UploadedCode = namedtuple("UploadedCode", ["s3_prefix", "script_name"]) """sagemaker.fw_utils.UploadedCode: An object containing the S3 prefix and script name. + This is for the source code used for the entry point with an ``Estimator``. It can be instantiated with positional or keyword arguments. """ @@ -152,8 +152,10 @@ "2.1.0", "2.1.2", "2.2.0", + "2.3.0", "2.3.1", "2.4.1", + "2.5.1", ] TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"] @@ -210,7 +212,7 @@ def validate_source_code_input_against_pipeline_variables( git_config: Optional[Dict[str, str]] = None, enable_network_isolation: Union[bool, PipelineVariable] = False, ): - """Validate source code input against pipeline variables + """Validate source code input against pipeline variables. Args: entry_point (str or PipelineVariable): The path to the local Python source file that @@ -251,7 +253,7 @@ def validate_source_code_input_against_pipeline_variables( logger.warning( "The source_dir is a pipeline variable: %s. During pipeline execution, " "the interpreted value of source_dir has to be an S3 URI and " - "must point to a tar.gz file", + "must point to a file with name ``sourcedir.tar.gz``", type(source_dir), ) @@ -480,7 +482,7 @@ def tar_and_upload_dir( def _list_files_to_compress(script, directory): - """Placeholder docstring""" + """Placeholder docstring.""" if directory is None: return [script] @@ -584,7 +586,6 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image): The location returned is a potential concatenation of 2 parts 1. code_location_key_prefix if it exists 2. model_name or a name derived from the image - Args: code_location_key_prefix (str): the s3 key prefix from code_location model_name (str): the name of the model @@ -619,8 +620,6 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution "enabled": True } } - - """ if training_instance_type == "local" or distribution is None: return @@ -645,7 +644,7 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution def profiler_config_deprecation_warning( profiler_config, image_uri, framework_name, framework_version ): - """Put out a deprecation message for if framework profiling is specified TF >= 2.12 and PT >= 2.0""" + """Deprecation message if framework profiling is specified TF >= 2.12 and PT >= 2.0.""" if profiler_config is None or profiler_config.framework_profile_params is None: return @@ -691,6 +690,7 @@ def validate_smdistributed( framework_name (str): A string representing the name of framework selected. framework_version (str): A string representing the framework version selected. py_version (str): A string representing the python version selected. + Ex: `py38, py39, py310, py311` distribution (dict): A dictionary with information to enable distributed training. (Defaults to None if distributed training is not enabled.) For example: @@ -762,7 +762,8 @@ def _validate_smdataparallel_args( instance_type (str): A string representing the type of training instance selected. Ex: `ml.p3.16xlarge` framework_name (str): A string representing the name of framework selected. Ex: `tensorflow` framework_version (str): A string representing the framework version selected. Ex: `2.3.1` - py_version (str): A string representing the python version selected. Ex: `py3` + py_version (str): A string representing the python version selected. + Ex: `py38, py39, py310, py311` distribution (dict): A dictionary with information to enable distributed training. (Defaults to None if distributed training is not enabled.) Ex: @@ -846,6 +847,7 @@ def validate_distribution( framework_name (str): A string representing the name of framework selected. framework_version (str): A string representing the framework version selected. py_version (str): A string representing the python version selected. + Ex: `py38, py39, py310, py311` image_uri (str): A string representing a Docker image URI. kwargs(dict): Additional kwargs passed to this function @@ -952,7 +954,7 @@ def validate_distribution( def validate_distribution_for_instance_type(instance_type, distribution): - """Check if the provided distribution strategy is supported for the instance_type + """Check if the provided distribution strategy is supported for the instance_type. Args: instance_type (str): A string representing the type of training instance selected. @@ -1009,6 +1011,7 @@ def validate_torch_distributed_distribution( } framework_version (str): A string representing the framework version selected. py_version (str): A string representing the python version selected. + Ex: `py38, py39, py310, py311` image_uri (str): A string representing a Docker image URI. entry_point (str or PipelineVariable): The absolute or relative path to the local Python source file that should be executed as the entry point to @@ -1071,7 +1074,7 @@ def validate_torch_distributed_distribution( def _is_gpu_instance(instance_type): - """Returns bool indicating whether instance_type supports GPU + """Returns bool indicating whether instance_type supports GPU. Args: instance_type (str): Name of the instance_type to check against. @@ -1090,7 +1093,7 @@ def _is_gpu_instance(instance_type): def _is_trainium_instance(instance_type): - """Returns bool indicating whether instance_type is a Trainium instance + """Returns bool indicating whether instance_type is a Trainium instance. Args: instance_type (str): Name of the instance_type to check against. @@ -1106,7 +1109,7 @@ def _is_trainium_instance(instance_type): def python_deprecation_warning(framework, latest_supported_version): - """Placeholder docstring""" + """Placeholder docstring.""" return PYTHON_2_DEPRECATION_WARNING.format( framework=framework, latest_supported_version=latest_supported_version ) @@ -1120,7 +1123,6 @@ def _region_supports_debugger(region_name): Returns: bool: Whether or not the region supports Amazon SageMaker Debugger. - """ return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS @@ -1133,7 +1135,6 @@ def _region_supports_profiler(region_name): Returns: bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature. - """ return region_name.lower() not in PROFILER_UNSUPPORTED_REGIONS @@ -1161,7 +1162,8 @@ def validate_version_or_image_args(framework_version, py_version, image_uri): Args: framework_version (str): The version of the framework. - py_version (str): The version of Python. + py_version (str): A string representing the python version selected. + Ex: `py38, py39, py310, py311` image_uri (str): The URI of the image. Raises: @@ -1193,9 +1195,8 @@ def create_image_uri( instance_type (str): SageMaker instance type. Used to determine device type (cpu/gpu/family-specific optimized). framework_version (str): The version of the framework. - py_version (str): Optional. Python version. If specified, should be one - of 'py2' or 'py3'. If not specified, image uri will not include a - python component. + py_version (str): Optional. Python version Ex: `py38, py39, py310, py311`. + If not specified, image uri will not include a python component. account (str): AWS account that contains the image. (default: '520713654638') accelerator_type (str): SageMaker Elastic Inference accelerator type. diff --git a/src/sagemaker/huggingface/estimator.py b/src/sagemaker/huggingface/estimator.py index 86df43d4e9..70cc17b209 100644 --- a/src/sagemaker/huggingface/estimator.py +++ b/src/sagemaker/huggingface/estimator.py @@ -15,17 +15,13 @@ import logging import re -from typing import Optional, Union, Dict +from typing import Dict, Optional, Union -from sagemaker.estimator import Framework, EstimatorBase -from sagemaker.fw_utils import ( - framework_name_from_image, - validate_distribution, -) +from sagemaker.estimator import EstimatorBase, Framework +from sagemaker.fw_utils import framework_name_from_image, validate_distribution from sagemaker.huggingface.model import HuggingFaceModel -from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT - from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig +from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -66,7 +62,7 @@ def __init__( Args: py_version (str): Python version you want to use for executing your model training code. Defaults to ``None``. Required unless ``image_uri`` is provided. If - using PyTorch, the current supported version is ``py36``. If using TensorFlow, + using PyTorch, the current supported version is ``py39``. If using TensorFlow, the current supported version is ``py37``. entry_point (str or PipelineVariable): Path (absolute or relative) to the Python source file which should be executed as the entry point to training. @@ -84,8 +80,8 @@ def __init__( source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. + point to a file with name ``sourcedir.tar.gz``. Structure within this directory are + preserved when training on Amazon SageMaker. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index ea99be2fc0..3ca25fb3ce 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Optional, Union, List, Dict +from typing import Callable, Optional, Union, List, Dict import sagemaker from sagemaker import image_uris, ModelMetrics @@ -123,7 +123,7 @@ def __init__( pytorch_version: Optional[str] = None, py_version: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, - predictor_cls: callable = HuggingFacePredictor, + predictor_cls: Optional[Callable] = HuggingFacePredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs, ): @@ -158,7 +158,7 @@ def __init__( If not specified, a default image for PyTorch will be used. If ``framework_version`` or ``py_version`` are ``None``, then ``image_uri`` is required. If also ``None``, then a ``ValueError`` will be raised. - predictor_cls (callable[str, sagemaker.session.Session]): A function + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. @@ -218,6 +218,7 @@ def deploy( container_startup_health_check_timeout=None, inference_recommendation_id=None, explainer_config=None, + update_endpoint: Optional[bool] = False, **kwargs, ): """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. @@ -296,6 +297,11 @@ def deploy( would like to deploy the model and endpoint with recommended parameters. explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability configuration for use with Amazon SageMaker Clarify. (default: None) + update_endpoint (Optional[bool]): + Flag to update the model in an existing Amazon SageMaker endpoint. + If True, this will deploy a new EndpointConfig to an already existing endpoint + and delete resources corresponding to the previous EndpointConfig. Default: False + Note: Currently this is supported for single model endpoints Raises: ValueError: If arguments combination check failed in these circumstances: - If no role is specified or @@ -304,7 +310,7 @@ def deploy( - If a wrong type of object is provided as serverless inference config or async inference config Returns: - callable[string, sagemaker.session.Session] or None: Invocation of + Optional[Callable[[string, sagemaker.session.Session], Any]]: Invocation of ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` is not None. Otherwise, return None. """ @@ -335,6 +341,7 @@ def deploy( container_startup_health_check_timeout=container_startup_health_check_timeout, inference_recommendation_id=inference_recommendation_id, explainer_config=explainer_config, + update_endpoint=update_endpoint, **kwargs, ) diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json index 68fcd8ca8a..f1edd9d287 100644 --- a/src/sagemaker/image_uri_config/autogluon.json +++ b/src/sagemaker/image_uri_config/autogluon.json @@ -12,7 +12,8 @@ "0.7": "0.7.0", "0.8": "0.8.2", "1.0": "1.0.0", - "1.1": "1.1.1" + "1.1": "1.1.1", + "1.2": "1.2.0" }, "versions": { "0.3.1": { @@ -563,6 +564,47 @@ "py_versions": [ "py311" ] + }, + "1.2.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "autogluon-training", + "processors": [ + "cpu", + "gpu" + ], + "py_versions": [ + "py311" + ] } } }, @@ -575,7 +617,8 @@ "0.7": "0.7.0", "0.8": "0.8.2", "1.0": "1.0.0", - "1.1": "1.1.1" + "1.1": "1.1.1", + "1.2": "1.2.0" }, "versions": { "0.3.1": { @@ -1157,6 +1200,49 @@ "py_versions": [ "py311" ] + }, + "1.2.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "autogluon-inference", + "processors": [ + "cpu", + "gpu" + ], + "py_versions": [ + "py311" + ] } } } diff --git a/src/sagemaker/image_uri_config/djl-neuronx.json b/src/sagemaker/image_uri_config/djl-neuronx.json index 3fd3c7619f..1fd7492ff4 100644 --- a/src/sagemaker/image_uri_config/djl-neuronx.json +++ b/src/sagemaker/image_uri_config/djl-neuronx.json @@ -13,12 +13,14 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -37,12 +39,14 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -61,12 +65,14 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -85,12 +91,14 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -109,12 +117,14 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -133,12 +143,14 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -157,12 +169,14 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -181,12 +195,14 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json b/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json index 7b96b60ff8..d79e7637ed 100644 --- a/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json +++ b/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json @@ -4,7 +4,7 @@ "inf2" ], "version_aliases": { - "0.0": "0.0.25" + "0.0": "0.0.27" }, "versions": { "0.0.16": { @@ -12,27 +12,47 @@ "py310" ], "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", "eu-south-2": "503227376785", "eu-west-1": "763104351884", + "eu-west-2": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" }, "tag_prefix": "1.13.1-optimum0.0.16", "repository": "huggingface-pytorch-tgi-inference", @@ -45,27 +65,47 @@ "py310" ], "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", "eu-south-2": "503227376785", "eu-west-1": "763104351884", + "eu-west-2": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" }, "tag_prefix": "1.13.1-optimum0.0.17", "repository": "huggingface-pytorch-tgi-inference", @@ -78,27 +118,47 @@ "py310" ], "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", "eu-south-2": "503227376785", "eu-west-1": "763104351884", + "eu-west-2": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" }, "tag_prefix": "1.13.1-optimum0.0.18", "repository": "huggingface-pytorch-tgi-inference", @@ -111,27 +171,47 @@ "py310" ], "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", "eu-south-2": "503227376785", "eu-west-1": "763104351884", + "eu-west-2": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" }, "tag_prefix": "1.13.1-optimum0.0.19", "repository": "huggingface-pytorch-tgi-inference", @@ -144,27 +224,47 @@ "py310" ], "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", "eu-south-2": "503227376785", "eu-west-1": "763104351884", + "eu-west-2": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" }, "tag_prefix": "1.13.1-optimum0.0.20", "repository": "huggingface-pytorch-tgi-inference", @@ -177,27 +277,47 @@ "py310" ], "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", "eu-south-2": "503227376785", "eu-west-1": "763104351884", + "eu-west-2": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" }, "tag_prefix": "1.13.1-optimum0.0.21", "repository": "huggingface-pytorch-tgi-inference", @@ -210,25 +330,47 @@ "py310" ], "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", "eu-south-2": "503227376785", "eu-west-1": "763104351884", + "eu-west-2": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" }, "tag_prefix": "2.1.2-optimum0.0.22", "repository": "huggingface-pytorch-tgi-inference", @@ -241,27 +383,47 @@ "py310" ], "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", "eu-south-2": "503227376785", "eu-west-1": "763104351884", + "eu-west-2": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" }, "tag_prefix": "2.1.2-optimum0.0.23", "repository": "huggingface-pytorch-tgi-inference", @@ -274,27 +436,47 @@ "py310" ], "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", "eu-south-2": "503227376785", "eu-west-1": "763104351884", + "eu-west-2": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" }, "tag_prefix": "2.1.2-optimum0.0.24", "repository": "huggingface-pytorch-tgi-inference", @@ -307,34 +489,107 @@ "py310" ], "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", "eu-south-2": "503227376785", "eu-west-1": "763104351884", + "eu-west-2": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" }, "tag_prefix": "2.1.2-optimum0.0.25", "repository": "huggingface-pytorch-tgi-inference", "container_version": { "inf2": "ubuntu22.04" } + }, + "0.0.27": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.1.2-optimum0.0.27", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "inf2": "ubuntu22.04" + } } } } -} +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/huggingface-llm.json b/src/sagemaker/image_uri_config/huggingface-llm.json index 24cbd5ca96..ed85f0d2bf 100644 --- a/src/sagemaker/image_uri_config/huggingface-llm.json +++ b/src/sagemaker/image_uri_config/huggingface-llm.json @@ -12,7 +12,11 @@ "1.2": "1.2.0", "1.3": "1.3.3", "1.4": "1.4.5", - "2.0": "2.3.1" + "2.0": "2.4.0", + "2.3": "2.3.1", + "3.0": "3.0.1", + "3.2": "3.2.3", + "3.1": "3.1.1" }, "versions": { "0.6.0": { @@ -21,8 +25,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -32,19 +36,24 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -52,9 +61,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.0.0-tgi0.6.0", "repository": "huggingface-pytorch-tgi-inference", @@ -68,8 +78,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -79,19 +89,24 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -99,9 +114,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.0.0-tgi0.8.2", "repository": "huggingface-pytorch-tgi-inference", @@ -115,8 +131,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -126,19 +142,24 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -146,9 +167,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.0.1-tgi0.9.3", "repository": "huggingface-pytorch-tgi-inference", @@ -162,8 +184,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -173,19 +195,24 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -193,9 +220,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.0.1-tgi1.0.3", "repository": "huggingface-pytorch-tgi-inference", @@ -209,8 +237,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -220,19 +248,24 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -240,9 +273,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.0.1-tgi1.1.0", "repository": "huggingface-pytorch-tgi-inference", @@ -256,8 +290,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -267,19 +301,24 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -287,9 +326,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.1.1-tgi1.2.0", "repository": "huggingface-pytorch-tgi-inference", @@ -303,8 +343,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -314,19 +354,24 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -334,9 +379,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.1.1-tgi1.3.1", "repository": "huggingface-pytorch-tgi-inference", @@ -350,8 +396,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -361,19 +407,24 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -381,9 +432,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.1.1-tgi1.3.3", "repository": "huggingface-pytorch-tgi-inference", @@ -397,8 +449,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -408,19 +460,24 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", "eu-south-1": "692866216735", "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -428,9 +485,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.1.1-tgi1.4.0", "repository": "huggingface-pytorch-tgi-inference", @@ -444,8 +502,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -455,19 +513,24 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", "eu-south-1": "692866216735", "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -475,9 +538,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.1.1-tgi1.4.2", "repository": "huggingface-pytorch-tgi-inference", @@ -491,8 +555,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -502,19 +566,24 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", "eu-south-1": "692866216735", "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -522,9 +591,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.1.1-tgi1.4.5", "repository": "huggingface-pytorch-tgi-inference", @@ -538,8 +608,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -549,19 +619,24 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", "eu-south-1": "692866216735", "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -569,9 +644,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.1.1-tgi2.0.0", "repository": "huggingface-pytorch-tgi-inference", @@ -585,8 +661,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -596,19 +672,24 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", "eu-south-1": "692866216735", "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -616,9 +697,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.1.1-tgi2.0.1", "repository": "huggingface-pytorch-tgi-inference", @@ -632,8 +714,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -643,19 +725,24 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", "eu-south-1": "692866216735", "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -663,9 +750,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.3.0-tgi2.0.2", "repository": "huggingface-pytorch-tgi-inference", @@ -679,8 +767,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -690,19 +778,24 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", "eu-south-1": "692866216735", "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -710,9 +803,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.3.0-tgi2.2.0", "repository": "huggingface-pytorch-tgi-inference", @@ -726,8 +820,61 @@ ], "registries": { "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.4.0-tgi2.3.1", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu124-ubuntu22.04" + } + }, + "2.4.0": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -737,19 +884,130 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", - "eu-west-3": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.4.0-tgi2.4.0", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu124-ubuntu22.04-v2.2" + } + }, + "3.0.1": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", "eu-south-1": "692866216735", "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.4.0-tgi3.0.1", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu124-ubuntu22.04-v2.1" + } + }, + "3.1.1": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -757,11 +1015,118 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, - "tag_prefix": "2.4.0-tgi2.3.1", + "tag_prefix": "2.6.0-tgi3.1.1", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu124-ubuntu22.04" + } + }, + "3.2.0": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.6.0-tgi3.2.0", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu124-ubuntu22.04" + } + }, + "3.2.3": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.6.0-tgi3.2.3", "repository": "huggingface-pytorch-tgi-inference", "container_version": { "gpu": "cu124-ubuntu22.04" @@ -769,4 +1134,4 @@ } } } -} +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/huggingface-neuron.json b/src/sagemaker/image_uri_config/huggingface-neuron.json index ae38ce209b..4e950bdb70 100644 --- a/src/sagemaker/image_uri_config/huggingface-neuron.json +++ b/src/sagemaker/image_uri_config/huggingface-neuron.json @@ -23,12 +23,15 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface-neuronx.json b/src/sagemaker/image_uri_config/huggingface-neuronx.json index 5b45e37586..a3426d5e0c 100644 --- a/src/sagemaker/image_uri_config/huggingface-neuronx.json +++ b/src/sagemaker/image_uri_config/huggingface-neuronx.json @@ -25,6 +25,8 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -33,6 +35,7 @@ "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -65,6 +68,8 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -73,6 +78,7 @@ "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -105,6 +111,8 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -113,6 +121,7 @@ "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -162,6 +171,8 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -174,6 +185,7 @@ "eu-south-1": "692866216735", "eu-south-2": "503227376785", "me-south-1": "217643126080", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -214,6 +226,8 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -226,6 +240,7 @@ "eu-south-1": "692866216735", "eu-south-2": "503227376785", "me-south-1": "217643126080", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -267,6 +282,8 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -279,6 +296,7 @@ "eu-south-1": "692866216735", "eu-south-2": "503227376785", "me-south-1": "217643126080", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -314,6 +332,8 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -326,6 +346,7 @@ "eu-south-1": "692866216735", "eu-south-2": "503227376785", "me-south-1": "217643126080", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface-training-compiler.json b/src/sagemaker/image_uri_config/huggingface-training-compiler.json index 735e7917b3..fa3a4119ca 100644 --- a/src/sagemaker/image_uri_config/huggingface-training-compiler.json +++ b/src/sagemaker/image_uri_config/huggingface-training-compiler.json @@ -69,6 +69,8 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "eu-central-1": "763104351884", "eu-central-2": "380420809688", @@ -80,6 +82,7 @@ "eu-west-3": "763104351884", "me-south-1": "217643126080", "me-central-1": "914824155844", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -109,6 +112,8 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "eu-central-1": "763104351884", "eu-central-2": "380420809688", @@ -120,6 +125,7 @@ "eu-west-3": "763104351884", "me-south-1": "217643126080", "me-central-1": "914824155844", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -154,6 +160,8 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "eu-central-1": "763104351884", "eu-central-2": "380420809688", @@ -165,6 +173,7 @@ "eu-west-3": "763104351884", "me-south-1": "217643126080", "me-central-1": "914824155844", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface.json b/src/sagemaker/image_uri_config/huggingface.json index 930b24566d..c314436346 100644 --- a/src/sagemaker/image_uri_config/huggingface.json +++ b/src/sagemaker/image_uri_config/huggingface.json @@ -13,7 +13,8 @@ "4.17": "4.17.0", "4.26": "4.26.0", "4.28": "4.28.1", - "4.36": "4.36.0" + "4.36": "4.36.0", + "4.46": "4.46.1" }, "versions": { "4.4.2": { @@ -1018,6 +1019,53 @@ "gpu": "cu121-ubuntu20.04" } } + }, + "4.46.1": { + "version_aliases": { + "pytorch2.3": "pytorch2.3.0" + }, + "pytorch2.3.0": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "huggingface-pytorch-training", + "container_version": { + "gpu": "cu121-ubuntu20.04" + } + } } } }, @@ -1883,6 +1931,58 @@ "cpu": "ubuntu22.04" } } + }, + "4.48.0": { + "version_aliases": { + "pytorch2.3": "pytorch2.3.0" + }, + "pytorch2.3.0": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "huggingface-pytorch-inference", + "container_version": { + "gpu": "cu121-ubuntu22.04", + "cpu": "ubuntu22.04" + } + } } } } diff --git a/src/sagemaker/image_uri_config/instance_gpu_info.json b/src/sagemaker/image_uri_config/instance_gpu_info.json index 9fc005bc47..e64a9bcf88 100644 --- a/src/sagemaker/image_uri_config/instance_gpu_info.json +++ b/src/sagemaker/image_uri_config/instance_gpu_info.json @@ -23,7 +23,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "ap-east-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -49,7 +49,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "ap-northeast-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -75,7 +75,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "ap-northeast-2": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -101,7 +101,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "ap-northeast-3": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -127,7 +127,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "ap-south-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -153,7 +153,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "ap-southeast-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -179,7 +179,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "ap-southeast-2": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -205,7 +205,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "ap-southeast-3": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -231,7 +231,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "ca-central-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -257,7 +257,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "cn-north-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -283,7 +283,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "cn-northwest-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -309,7 +309,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "eu-central-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -335,7 +335,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "eu-central-2": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -361,7 +361,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "eu-north-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -387,7 +387,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "eu-south-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -413,7 +413,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "eu-south-2": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -439,7 +439,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "eu-west-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -465,7 +465,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "eu-west-2": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -491,7 +491,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "eu-west-3": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -517,7 +517,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "il-central-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -543,7 +543,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "me-central-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -569,7 +569,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "me-south-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -595,7 +595,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "sa-east-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -621,7 +621,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "us-east-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -647,7 +647,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "us-east-2": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -673,7 +673,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "us-gov-east-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -699,7 +699,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "us-gov-west-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -725,7 +725,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "us-west-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -751,7 +751,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "us-west-2": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -777,6 +777,6 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} } } \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/pytorch-smp.json b/src/sagemaker/image_uri_config/pytorch-smp.json index 449726927a..53c2a75e13 100644 --- a/src/sagemaker/image_uri_config/pytorch-smp.json +++ b/src/sagemaker/image_uri_config/pytorch-smp.json @@ -9,7 +9,8 @@ "2.2": "2.3.1", "2.2.0": "2.3.1", "2.3.1": "2.5.0", - "2.4.1": "2.7.0" + "2.4.1": "2.7.0", + "2.5.1": "2.8.0" }, "versions": { "2.0.1": { @@ -186,6 +187,31 @@ "us-west-2": "658645717510" }, "repository": "smdistributed-modelparallel" + }, + "2.8.0": { + "py_versions": [ + "py311" + ], + "registries": { + "ap-northeast-1": "658645717510", + "ap-northeast-2": "658645717510", + "ap-northeast-3": "658645717510", + "ap-south-1": "658645717510", + "ap-southeast-1": "658645717510", + "ap-southeast-2": "658645717510", + "ca-central-1": "658645717510", + "eu-central-1": "658645717510", + "eu-north-1": "658645717510", + "eu-west-1": "658645717510", + "eu-west-2": "658645717510", + "eu-west-3": "658645717510", + "sa-east-1": "658645717510", + "us-east-1": "658645717510", + "us-east-2": "658645717510", + "us-west-1": "658645717510", + "us-west-2": "658645717510" + }, + "repository": "smdistributed-modelparallel" } } } diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index 66150da2b0..dbff976442 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -85,7 +85,8 @@ "2.2": "2.2.0", "2.3": "2.3.0", "2.4": "2.4.0", - "2.5": "2.5.1" + "2.5": "2.5.1", + "2.6": "2.6.0" }, "versions": { "0.4.0": { @@ -198,6 +199,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -208,6 +210,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -223,6 +226,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -243,6 +247,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -253,6 +258,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -268,6 +274,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -288,6 +295,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -298,6 +306,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -313,6 +322,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -333,6 +343,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -343,6 +354,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -358,6 +370,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -378,6 +391,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -388,6 +402,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -403,6 +418,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -423,6 +439,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -433,6 +450,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -448,6 +466,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -468,6 +487,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -478,6 +498,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -493,6 +514,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -513,6 +535,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -523,6 +546,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -538,6 +562,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -557,6 +582,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -567,6 +593,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -582,6 +609,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -601,6 +629,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -611,6 +640,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -626,6 +656,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -645,6 +676,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -655,6 +687,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -670,6 +703,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -689,6 +723,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -699,6 +734,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -714,6 +750,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -733,6 +770,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -743,6 +781,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -758,6 +797,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -777,6 +817,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -787,6 +828,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -802,6 +844,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -821,6 +864,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -831,6 +875,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -846,6 +891,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -865,6 +911,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -875,6 +922,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -890,6 +938,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -909,6 +958,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -919,6 +969,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -934,6 +985,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -953,6 +1005,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -963,6 +1016,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -978,6 +1032,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -997,6 +1052,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1007,6 +1063,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1022,6 +1079,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1043,6 +1101,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1053,6 +1112,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1068,6 +1128,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1089,6 +1150,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1099,6 +1161,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1114,6 +1177,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1131,6 +1195,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1141,6 +1206,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1156,6 +1222,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1173,6 +1240,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1183,6 +1251,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1198,6 +1267,52 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-inference" + }, + "2.6.0": { + "py_versions": [ + "py312" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1233,6 +1348,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1243,6 +1359,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1258,6 +1375,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1280,6 +1398,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1290,6 +1409,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1305,6 +1425,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1325,6 +1446,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1335,6 +1457,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1350,6 +1473,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1370,6 +1494,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1380,6 +1505,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1395,6 +1521,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1415,6 +1542,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1425,6 +1553,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1440,6 +1569,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1460,6 +1590,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1470,6 +1601,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1485,6 +1617,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1505,6 +1638,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1515,6 +1649,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1530,6 +1665,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1568,7 +1704,8 @@ "2.2": "2.2.0", "2.3": "2.3.0", "2.4": "2.4.0", - "2.5": "2.5.1" + "2.5": "2.5.1", + "2.6": "2.6.0" }, "versions": { "0.4.0": { @@ -1681,6 +1818,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1691,6 +1829,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1706,6 +1845,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1726,6 +1866,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1736,6 +1877,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1751,6 +1893,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1772,6 +1915,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1782,6 +1926,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1797,6 +1942,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1817,6 +1963,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1827,6 +1974,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1842,6 +1990,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1862,6 +2011,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1872,6 +2022,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1887,6 +2038,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1907,6 +2059,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1917,6 +2070,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1932,6 +2086,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1952,6 +2107,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1962,6 +2118,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1977,6 +2134,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1997,6 +2155,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2007,6 +2166,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2022,6 +2182,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2041,6 +2202,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2051,6 +2213,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2066,6 +2229,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2085,6 +2249,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2095,6 +2260,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2110,6 +2276,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2129,6 +2296,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2139,6 +2307,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2154,6 +2323,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2173,6 +2343,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2183,6 +2354,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2198,6 +2370,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2217,6 +2390,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2227,6 +2401,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2242,6 +2417,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2261,6 +2437,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2271,6 +2448,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2286,6 +2464,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2305,6 +2484,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2315,6 +2495,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2330,6 +2511,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2349,6 +2531,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2359,6 +2542,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2374,6 +2558,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2393,6 +2578,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2403,6 +2589,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2418,6 +2605,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2437,6 +2625,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2447,6 +2636,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2462,6 +2652,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2481,6 +2672,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2491,6 +2683,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2506,6 +2699,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2527,6 +2721,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2537,6 +2732,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2552,6 +2748,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2573,6 +2770,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2583,6 +2781,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2598,6 +2797,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2619,6 +2819,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2629,6 +2830,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2644,6 +2846,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2661,6 +2864,52 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-training" + }, + "2.6.0": { + "py_versions": [ + "py312" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2671,6 +2920,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2686,6 +2936,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/sagemaker-base-python.json b/src/sagemaker/image_uri_config/sagemaker-base-python.json index d4bb35f77b..65b284d25e 100644 --- a/src/sagemaker/image_uri_config/sagemaker-base-python.json +++ b/src/sagemaker/image_uri_config/sagemaker-base-python.json @@ -11,6 +11,8 @@ "ap-southeast-1": "492261229750", "ap-southeast-2": "452832661640", "ap-southeast-3": "276181064229", + "ap-southeast-5": "148761635175", + "ap-southeast-7": "528757812139", "ca-central-1": "310906938811", "cn-north-1": "390048526115", "cn-northwest-1": "390780980154", @@ -25,11 +27,14 @@ "il-central-1": "380164790875", "me-central-1": "103105715889", "me-south-1": "117516905037", + "mx-central-1": "396913743851", "sa-east-1": "782484402741", "us-east-1": "081325390199", "us-east-2": "429704687514", "us-gov-east-1": "107072934176", "us-gov-west-1": "107173498710", + "us-isof-east-1": "840123138293", + "us-isof-south-1": "883091641454", "us-west-1": "742091327244", "us-west-2": "236514542706" }, diff --git a/src/sagemaker/image_uri_config/spark.json b/src/sagemaker/image_uri_config/spark.json index 9a33ca87d9..bbb8c9b123 100644 --- a/src/sagemaker/image_uri_config/spark.json +++ b/src/sagemaker/image_uri_config/spark.json @@ -20,6 +20,8 @@ "ap-southeast-2": "440695851116", "ap-southeast-3": "800295151634", "ap-southeast-4": "819679513684", + "ap-southeast-5": "841784149062", + "ap-southeast-7": "471112967968", "ca-central-1": "446299261295", "ca-west-1": "000907499111", "cn-north-1": "671472414489", @@ -35,6 +37,7 @@ "il-central-1": "408426139102", "me-central-1": "395420993607", "me-south-1": "750251592176", + "mx-central-1": "211125459255", "sa-east-1": "737130764395", "us-east-1": "173754725891", "us-east-2": "314815235551", @@ -61,6 +64,8 @@ "ap-southeast-2": "440695851116", "ap-southeast-3": "800295151634", "ap-southeast-4": "819679513684", + "ap-southeast-5": "841784149062", + "ap-southeast-7": "471112967968", "ca-central-1": "446299261295", "ca-west-1": "000907499111", "cn-north-1": "671472414489", @@ -76,6 +81,7 @@ "il-central-1": "408426139102", "me-central-1": "395420993607", "me-south-1": "750251592176", + "mx-central-1": "211125459255", "sa-east-1": "737130764395", "us-east-1": "173754725891", "us-east-2": "314815235551", @@ -102,6 +108,8 @@ "ap-southeast-2": "440695851116", "ap-southeast-3": "800295151634", "ap-southeast-4": "819679513684", + "ap-southeast-5": "841784149062", + "ap-southeast-7": "471112967968", "ca-central-1": "446299261295", "ca-west-1": "000907499111", "cn-north-1": "671472414489", @@ -117,6 +125,7 @@ "il-central-1": "408426139102", "me-central-1": "395420993607", "me-south-1": "750251592176", + "mx-central-1": "211125459255", "sa-east-1": "737130764395", "us-east-1": "173754725891", "us-east-2": "314815235551", @@ -143,6 +152,8 @@ "ap-southeast-2": "440695851116", "ap-southeast-3": "800295151634", "ap-southeast-4": "819679513684", + "ap-southeast-5": "841784149062", + "ap-southeast-7": "471112967968", "ca-central-1": "446299261295", "ca-west-1": "000907499111", "cn-north-1": "671472414489", @@ -158,6 +169,7 @@ "il-central-1": "408426139102", "me-central-1": "395420993607", "me-south-1": "750251592176", + "mx-central-1": "211125459255", "sa-east-1": "737130764395", "us-east-1": "173754725891", "us-east-2": "314815235551", @@ -184,6 +196,8 @@ "ap-southeast-2": "440695851116", "ap-southeast-3": "800295151634", "ap-southeast-4": "819679513684", + "ap-southeast-5": "841784149062", + "ap-southeast-7": "471112967968", "ca-central-1": "446299261295", "ca-west-1": "000907499111", "cn-north-1": "671472414489", @@ -199,6 +213,7 @@ "il-central-1": "408426139102", "me-central-1": "395420993607", "me-south-1": "750251592176", + "mx-central-1": "211125459255", "sa-east-1": "737130764395", "us-east-1": "173754725891", "us-east-2": "314815235551", diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 5f12889fd0..097baafa9b 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -332,7 +332,8 @@ "2.12": "2.12.1", "2.13": "2.13.0", "2.14": "2.14.1", - "2.16": "2.16.1" + "2.16": "2.16.1", + "2.18": "2.18.0" }, "versions": { "1.4.1": { @@ -630,6 +631,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -640,6 +642,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -655,6 +658,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -671,6 +675,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -681,6 +686,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -696,6 +702,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -712,6 +719,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -722,6 +730,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -737,6 +746,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -753,6 +763,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -763,6 +774,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -778,6 +790,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -794,6 +807,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -804,6 +818,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -819,6 +834,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -835,6 +851,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -845,6 +862,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -860,6 +878,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -876,6 +895,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -886,6 +906,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -901,6 +922,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -917,6 +939,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -927,6 +950,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -942,6 +966,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -958,6 +983,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -968,6 +994,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -983,6 +1010,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -999,6 +1027,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1009,6 +1038,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1024,6 +1054,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1040,6 +1071,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1050,6 +1082,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1065,6 +1098,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1081,6 +1115,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1091,6 +1126,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1106,6 +1142,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1122,6 +1159,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1132,6 +1170,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1147,6 +1186,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1163,6 +1203,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1173,6 +1214,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1188,6 +1230,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1204,6 +1247,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1214,6 +1258,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1229,6 +1274,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1245,6 +1291,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1255,6 +1302,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1270,6 +1318,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1286,6 +1335,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1296,6 +1346,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1311,6 +1362,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1327,6 +1379,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1337,6 +1390,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1352,6 +1406,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1368,6 +1423,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1378,6 +1434,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1393,6 +1450,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1409,6 +1467,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1419,6 +1478,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1434,6 +1494,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1450,6 +1511,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1460,6 +1522,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1475,6 +1538,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1491,6 +1555,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1501,6 +1566,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1516,6 +1582,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1532,6 +1599,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1542,6 +1610,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1557,6 +1626,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1573,6 +1643,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1583,6 +1654,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1598,6 +1670,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1614,6 +1687,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1624,6 +1698,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1639,6 +1714,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1655,6 +1731,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1665,6 +1742,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1680,6 +1758,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1696,6 +1775,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1706,6 +1786,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1721,6 +1802,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1737,6 +1819,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1747,6 +1830,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1762,6 +1846,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1778,6 +1863,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1788,6 +1874,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1803,6 +1890,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1819,6 +1907,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1829,6 +1918,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1844,6 +1934,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1860,6 +1951,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1870,6 +1962,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1885,6 +1978,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1901,6 +1995,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1911,6 +2006,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1926,6 +2022,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1942,6 +2039,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1952,6 +2050,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1967,6 +2066,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1983,6 +2083,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1993,6 +2094,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2008,6 +2110,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2024,6 +2127,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2034,6 +2138,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2049,6 +2154,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2065,6 +2171,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2075,6 +2182,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2090,6 +2198,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2106,6 +2215,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2116,6 +2226,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2131,6 +2242,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2147,6 +2259,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2157,6 +2270,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2172,6 +2286,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2190,6 +2305,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2200,6 +2316,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2215,6 +2332,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2233,6 +2351,49 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" + }, + "2.18.0": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2243,6 +2404,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2258,6 +2420,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2292,6 +2455,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2302,6 +2466,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2317,6 +2482,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2339,6 +2505,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2349,6 +2516,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2364,6 +2532,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2386,6 +2555,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2396,6 +2566,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2411,6 +2582,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2433,6 +2605,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2443,6 +2616,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2458,6 +2632,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2480,6 +2655,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2490,6 +2666,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2505,6 +2682,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2550,7 +2728,8 @@ "2.12": "2.12.0", "2.13": "2.13.0", "2.14": "2.14.1", - "2.16": "2.16.2" + "2.16": "2.16.2", + "2.18": "2.18.0" }, "versions": { "1.4.1": { @@ -2932,6 +3111,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2942,6 +3122,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2957,6 +3138,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2977,6 +3159,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2987,6 +3170,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3002,6 +3186,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3023,6 +3208,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3033,6 +3219,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3048,6 +3235,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3069,6 +3257,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3079,6 +3268,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3094,6 +3284,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3115,6 +3306,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3125,6 +3317,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3140,6 +3333,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3161,6 +3355,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3171,6 +3366,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3186,6 +3382,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3206,6 +3403,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3216,6 +3414,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3231,6 +3430,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3251,6 +3451,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3261,6 +3462,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3276,6 +3478,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3296,6 +3499,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3306,6 +3510,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3321,6 +3526,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3341,6 +3547,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3351,6 +3558,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3366,6 +3574,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3386,6 +3595,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3396,6 +3606,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3411,6 +3622,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3431,6 +3643,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3441,6 +3654,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3456,6 +3670,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3476,6 +3691,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3486,6 +3702,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3501,6 +3718,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3521,6 +3739,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3531,6 +3750,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3546,6 +3766,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3566,6 +3787,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3576,6 +3798,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3591,6 +3814,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3610,6 +3834,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3620,6 +3845,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3635,6 +3861,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3654,6 +3881,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3664,6 +3892,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3679,6 +3908,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3698,6 +3928,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3708,6 +3939,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3723,6 +3955,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3742,6 +3975,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3752,6 +3986,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3767,6 +4002,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3786,6 +4022,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3796,6 +4033,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3811,6 +4049,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3830,6 +4069,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3840,6 +4080,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3855,6 +4096,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3874,6 +4116,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3884,6 +4127,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3899,6 +4143,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3918,6 +4163,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3928,6 +4174,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3943,6 +4190,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3962,6 +4210,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3972,6 +4221,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3987,6 +4237,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4006,6 +4257,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4016,6 +4268,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4031,6 +4284,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4050,6 +4304,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4060,6 +4315,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4075,6 +4331,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4094,6 +4351,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4104,6 +4362,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4119,6 +4378,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4138,6 +4398,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4148,6 +4409,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4163,6 +4425,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4182,6 +4445,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4192,6 +4456,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4207,6 +4472,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4226,6 +4492,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4236,6 +4503,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4251,6 +4519,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4270,6 +4539,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4280,6 +4550,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4295,6 +4566,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4314,6 +4586,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4324,6 +4597,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4339,6 +4613,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4358,6 +4633,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4368,6 +4644,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4383,6 +4660,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4402,6 +4680,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4412,6 +4691,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4427,6 +4707,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4444,6 +4725,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4454,6 +4736,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4469,6 +4752,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4490,6 +4774,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4500,6 +4785,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4515,6 +4801,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4536,6 +4823,52 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-training" + }, + "2.18.0": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4546,6 +4879,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4561,6 +4895,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 8c3449e8c4..de6d622f78 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -101,6 +101,8 @@ def retrieve( https://github.com/aws/deep-learning-containers/blob/master/available_images.md (default: None). distribution (dict): A dictionary with information on how to run distributed training + base_framework_version (str): The base version number of PyTorch or Tensorflow. + (default: None). training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`): A configuration class for the SageMaker Training Compiler (default: None). @@ -699,12 +701,16 @@ def get_training_image_uri( if "modelparallel" in distribution["smdistributed"]: if distribution["smdistributed"]["modelparallel"].get("enabled", True): framework = "pytorch-smp" - if ( - "p5" in instance_type - or "2.1" in framework_version - or "2.2" in framework_version - or "2.3" in framework_version - or "2.4" in framework_version + supported_smp_pt_versions_cu124 = ("2.5",) + supported_smp_pt_versions_cu121 = ("2.1", "2.2", "2.3", "2.4") + if any( + pt_version in framework_version + for pt_version in supported_smp_pt_versions_cu124 + ): + container_version = "cu124" + elif "p5" in instance_type or any( + pt_version in framework_version + for pt_version in supported_smp_pt_versions_cu121 ): container_version = "cu121" else: diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py index 89779bef44..71678021d4 100644 --- a/src/sagemaker/inputs.py +++ b/src/sagemaker/inputs.py @@ -43,6 +43,8 @@ def __init__( attribute_names: Optional[List[Union[str, PipelineVariable]]] = None, target_attribute_name: Optional[Union[str, PipelineVariable]] = None, shuffle_config: Optional["ShuffleConfig"] = None, + hub_access_config: Optional[dict] = None, + model_access_config: Optional[dict] = None, ): r"""Create a definition for input data used by an SageMaker training job. @@ -102,6 +104,13 @@ def __init__( shuffle_config (sagemaker.inputs.ShuffleConfig): If specified this configuration enables shuffling on this channel. See the SageMaker API documentation for more info: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html + hub_access_config (dict): Specify the HubAccessConfig of a + Model Reference for which a training job is being created for. + model_access_config (dict): For models that require a Model Access Config, specify True + or False for to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). """ self.config = { "DataSource": {"S3DataSource": {"S3DataType": s3_data_type, "S3Uri": s3_data}} @@ -129,6 +138,27 @@ def __init__( self.config["TargetAttributeName"] = target_attribute_name if shuffle_config is not None: self.config["ShuffleConfig"] = {"Seed": shuffle_config.seed} + self.add_hub_access_config(hub_access_config) + self.add_model_access_config(model_access_config) + + def add_hub_access_config(self, hub_access_config=None): + """Add Hub Access Config to the channel's configuration. + + Args: + hub_access_config (dict): The HubAccessConfig to be added to the + channel's configuration. + """ + if hub_access_config is not None: + self.config["DataSource"]["S3DataSource"]["HubAccessConfig"] = hub_access_config + + def add_model_access_config(self, model_access_config=None): + """Add Model Access Config to the channel's configuration. + + Args: + model_access_config (dict): Whether model terms of use have been accepted. + """ + if model_access_config is not None: + self.config["DataSource"]["S3DataSource"]["ModelAccessConfig"] = model_access_config class ShuffleConfig(object): diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index 210dd426c5..1ad7e3b981 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -65,6 +65,7 @@ def stop(self): @staticmethod def _load_config(inputs, estimator, expand_role=True, validate_uri=True): """Placeholder docstring""" + model_access_config, hub_access_config = _Job._get_access_configs(estimator) input_config = _Job._format_inputs_to_input_config(inputs, validate_uri) role = ( estimator.sagemaker_session.expand_role(estimator.role) @@ -95,19 +96,23 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True): validate_uri, content_type="application/x-sagemaker-model", input_mode="File", + model_access_config=model_access_config, + hub_access_config=hub_access_config, ) if model_channel: input_config = [] if input_config is None else input_config input_config.append(model_channel) - if estimator.enable_network_isolation(): - code_channel = _Job._prepare_channel( - input_config, estimator.code_uri, estimator.code_channel_name, validate_uri - ) + code_channel = _Job._prepare_channel( + input_config, + estimator.code_uri, + estimator.code_channel_name, + validate_uri, + ) - if code_channel: - input_config = [] if input_config is None else input_config - input_config.append(code_channel) + if code_channel: + input_config = [] if input_config is None else input_config + input_config.append(code_channel) return { "input_config": input_config, @@ -118,6 +123,23 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True): "vpc_config": vpc_config, } + @staticmethod + def _get_access_configs(estimator): + """Return access configs from estimator object. + + JumpStartEstimator uses access configs which need to be added to the model channel, + so they are passed down to the job level. + + Args: + estimator (EstimatorBase): estimator object with access config field if applicable + """ + model_access_config, hub_access_config = None, None + if hasattr(estimator, "model_access_config"): + model_access_config = estimator.model_access_config + if hasattr(estimator, "hub_access_config"): + hub_access_config = estimator.hub_access_config + return model_access_config, hub_access_config + @staticmethod def _format_inputs_to_input_config(inputs, validate_uri=True): """Placeholder docstring""" @@ -173,6 +195,8 @@ def _format_string_uri_input( input_mode=None, compression=None, target_attribute_name=None, + model_access_config=None, + hub_access_config=None, ): """Placeholder docstring""" s3_input_result = TrainingInput( @@ -181,6 +205,8 @@ def _format_string_uri_input( input_mode=input_mode, compression=compression, target_attribute_name=target_attribute_name, + model_access_config=model_access_config, + hub_access_config=hub_access_config, ) if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"): return s3_input_result @@ -193,7 +219,11 @@ def _format_string_uri_input( ) if isinstance(uri_input, str): return s3_input_result - if isinstance(uri_input, (TrainingInput, file_input, FileSystemInput)): + if isinstance(uri_input, (file_input, FileSystemInput)): + return uri_input + if isinstance(uri_input, TrainingInput): + uri_input.add_hub_access_config(hub_access_config=hub_access_config) + uri_input.add_model_access_config(model_access_config=model_access_config) return uri_input if is_pipeline_variable(uri_input): return s3_input_result @@ -211,6 +241,8 @@ def _prepare_channel( validate_uri=True, content_type=None, input_mode=None, + model_access_config=None, + hub_access_config=None, ): """Placeholder docstring""" if not channel_uri: @@ -226,7 +258,12 @@ def _prepare_channel( raise ValueError("Duplicate channel {} not allowed.".format(channel_name)) channel_input = _Job._format_string_uri_input( - channel_uri, validate_uri, content_type, input_mode + channel_uri, + validate_uri, + content_type, + input_mode, + model_access_config=model_access_config, + hub_access_config=hub_access_config, ) channel = _Job._convert_input_to_channel(channel_name, channel_input) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 20a2d16c15..9ebc2880bc 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -25,6 +25,7 @@ from sagemaker.jumpstart.hub.utils import ( construct_hub_model_arn_from_inputs, construct_hub_model_reference_arn_from_inputs, + generate_hub_arn_for_init_kwargs, ) from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.session import Session @@ -288,8 +289,13 @@ def get_model_specs( ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) + # Users only input model id, not contentType, so first try to describe with ModelReference, then with Model if hub_arn: try: + hub_arn = generate_hub_arn_for_init_kwargs( + hub_name=hub_arn, region=region, session=sagemaker_session + ) + hub_model_arn = construct_hub_model_reference_arn_from_inputs( hub_arn=hub_arn, model_name=model_id, version=version ) @@ -308,11 +314,22 @@ def get_model_specs( hub_model_arn = construct_hub_model_arn_from_inputs( hub_arn=hub_arn, model_name=model_id, version=version ) - model_specs = JumpStartModelsAccessor._cache.get_hub_model( - hub_model_arn=hub_model_arn - ) - model_specs.set_hub_content_type(HubContentType.MODEL) - return model_specs + + # Failed to describe ModelReference, try with Model + try: + model_specs = JumpStartModelsAccessor._cache.get_hub_model( + hub_model_arn=hub_model_arn + ) + model_specs.set_hub_content_type(HubContentType.MODEL) + + return model_specs + except Exception as ex: + # Failed with both, throw a custom error message + raise RuntimeError( + f"Cannot get details for {model_id} in Hub {hub_arn}. \ + {model_id} does not exist as a Model or ModelReference: \n" + + str(ex) + ) return JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, version_str=version, model_type=model_type diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 90ee7dea8d..c1ad9710f1 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -29,6 +29,7 @@ get_region_fallback, verify_model_region_and_return_specs, ) +from sagemaker.s3_utils import is_s3_url from sagemaker.session import Session from sagemaker.jumpstart.types import JumpStartModelSpecs @@ -74,7 +75,7 @@ def _retrieve_hosting_artifact_key(model_specs: JumpStartModelSpecs, instance_ty def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str: """Returns instance specific training artifact key or default one as fallback.""" instance_specific_training_artifact_key: Optional[str] = ( - model_specs.training_instance_type_variants.get_instance_specific_artifact_key( + model_specs.training_instance_type_variants.get_instance_specific_training_artifact_key( instance_type=instance_type ) if instance_type @@ -185,8 +186,8 @@ def _retrieve_model_uri( os.environ.get(ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE) or default_jumpstart_bucket ) - - model_s3_uri = f"s3://{bucket}/{model_artifact_key}" + if not is_s3_url(model_artifact_key): + model_s3_uri = f"s3://{bucket}/{model_artifact_key}" return model_s3_uri diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 8ac813a6c5..29a903e00b 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -150,7 +150,8 @@ def __init__( if s3_client_config else boto3.client("s3", region_name=self._region) ) - self._sagemaker_session = sagemaker_session + # Fallback in case a caller overrides sagemaker_session to None + self._sagemaker_session = sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION def set_region(self, region: str) -> None: """Set region for cache. Clears cache after new region is set.""" @@ -262,7 +263,7 @@ def _model_id_retrieval_function( return JumpStartVersionedModelId(model_id, sm_compatible_model_version) versions_incompatible_with_sagemaker = [ - Version(header.version) + header.version for header in manifest.values() # type: ignore if header.model_id == model_id ] @@ -540,9 +541,7 @@ def _select_version( """ if version_str == "*": - if len(available_versions) == 0: - return None - return str(max(available_versions)) + return utils.get_latest_version(available_versions) if model_type == JumpStartModelType.PROPRIETARY: if "*" in version_str: @@ -553,6 +552,12 @@ def _select_version( ) return version_str if version_str in available_versions else None + if version_str[-1] == "*": + # major or minor version is pinned, e.g 1.* or 1.0.* + return utils.get_latest_version( + [version for version in available_versions if version.startswith(version_str[:-1])] + ) + try: spec = SpecifierSet(f"=={version_str}") except InvalidSpecifier: diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index f3f7ecad1b..b81f97ce3a 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -15,6 +15,7 @@ import logging import os from typing import Dict, Set, Type +import json import boto3 from sagemaker.base_deserializers import BaseDeserializer, JSONDeserializer from sagemaker.jumpstart.enums import ( @@ -35,180 +36,58 @@ from sagemaker.session import Session +JUMPSTART_LOGGER = logging.getLogger("sagemaker.jumpstart") + +# disable logging if env var is set +JUMPSTART_LOGGER.addHandler( + type( + "", + (logging.StreamHandler,), + { + "emit": lambda self, *args, **kwargs: ( + logging.StreamHandler.emit(self, *args, **kwargs) + if not os.environ.get(ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING) + else None + ) + }, + )() +) + + +_CURRENT_FILE_DIRECTORY_PATH = os.path.dirname(os.path.realpath(__file__)) +REGION_CONFIG_JSON_FILENAME = "region_config.json" +REGION_CONFIG_JSON_FILEPATH = os.path.join( + _CURRENT_FILE_DIRECTORY_PATH, REGION_CONFIG_JSON_FILENAME +) + + +def _load_region_config(filepath: str) -> Set[JumpStartLaunchedRegionInfo]: + """Load the JumpStart region config from a JSON file.""" + debug_msg = f"Loading JumpStart region config from '{filepath}'." + JUMPSTART_LOGGER.debug(debug_msg) + try: + with open(filepath) as f: + config = json.load(f) + + return { + JumpStartLaunchedRegionInfo( + region_name=region, + content_bucket=data["content_bucket"], + gated_content_bucket=data.get("gated_content_bucket"), + neo_content_bucket=data.get("neo_content_bucket"), + ) + for region, data in config.items() + } + except Exception: # pylint: disable=W0703 + JUMPSTART_LOGGER.error("Unable to load JumpStart region config.", exc_info=True) + return set() + + ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING = "DISABLE_JUMPSTART_LOGGING" ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY = "DISABLE_JUMPSTART_TELEMETRY" -JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set( - [ - JumpStartLaunchedRegionInfo( - region_name="us-west-2", - content_bucket="jumpstart-cache-prod-us-west-2", - gated_content_bucket="jumpstart-private-cache-prod-us-west-2", - neo_content_bucket="sagemaker-sd-models-prod-us-west-2", - ), - JumpStartLaunchedRegionInfo( - region_name="us-east-1", - content_bucket="jumpstart-cache-prod-us-east-1", - gated_content_bucket="jumpstart-private-cache-prod-us-east-1", - neo_content_bucket="sagemaker-sd-models-prod-us-east-1", - ), - JumpStartLaunchedRegionInfo( - region_name="us-east-2", - content_bucket="jumpstart-cache-prod-us-east-2", - gated_content_bucket="jumpstart-private-cache-prod-us-east-2", - neo_content_bucket="sagemaker-sd-models-prod-us-east-2", - ), - JumpStartLaunchedRegionInfo( - region_name="eu-west-1", - content_bucket="jumpstart-cache-prod-eu-west-1", - gated_content_bucket="jumpstart-private-cache-prod-eu-west-1", - neo_content_bucket="sagemaker-sd-models-prod-eu-west-1", - ), - JumpStartLaunchedRegionInfo( - region_name="eu-central-1", - content_bucket="jumpstart-cache-prod-eu-central-1", - gated_content_bucket="jumpstart-private-cache-prod-eu-central-1", - neo_content_bucket="sagemaker-sd-models-prod-eu-central-1", - ), - JumpStartLaunchedRegionInfo( - region_name="eu-central-2", - content_bucket="jumpstart-cache-prod-eu-central-2", - gated_content_bucket="jumpstart-private-cache-prod-eu-central-2", - ), - JumpStartLaunchedRegionInfo( - region_name="eu-north-1", - content_bucket="jumpstart-cache-prod-eu-north-1", - gated_content_bucket="jumpstart-private-cache-prod-eu-north-1", - neo_content_bucket="sagemaker-sd-models-prod-eu-north-1", - ), - JumpStartLaunchedRegionInfo( - region_name="me-south-1", - content_bucket="jumpstart-cache-prod-me-south-1", - gated_content_bucket="jumpstart-private-cache-prod-me-south-1", - ), - JumpStartLaunchedRegionInfo( - region_name="me-central-1", - content_bucket="jumpstart-cache-prod-me-central-1", - gated_content_bucket="jumpstart-private-cache-prod-me-central-1", - ), - JumpStartLaunchedRegionInfo( - region_name="ap-south-1", - content_bucket="jumpstart-cache-prod-ap-south-1", - gated_content_bucket="jumpstart-private-cache-prod-ap-south-1", - neo_content_bucket="sagemaker-sd-models-prod-ap-south-1", - ), - JumpStartLaunchedRegionInfo( - region_name="eu-west-3", - content_bucket="jumpstart-cache-prod-eu-west-3", - gated_content_bucket="jumpstart-private-cache-prod-eu-west-3", - neo_content_bucket="sagemaker-sd-models-prod-eu-west-3", - ), - JumpStartLaunchedRegionInfo( - region_name="af-south-1", - content_bucket="jumpstart-cache-prod-af-south-1", - gated_content_bucket="jumpstart-private-cache-prod-af-south-1", - ), - JumpStartLaunchedRegionInfo( - region_name="sa-east-1", - content_bucket="jumpstart-cache-prod-sa-east-1", - gated_content_bucket="jumpstart-private-cache-prod-sa-east-1", - neo_content_bucket="sagemaker-sd-models-prod-sa-east-1", - ), - JumpStartLaunchedRegionInfo( - region_name="ap-east-1", - content_bucket="jumpstart-cache-prod-ap-east-1", - gated_content_bucket="jumpstart-private-cache-prod-ap-east-1", - ), - JumpStartLaunchedRegionInfo( - region_name="ap-northeast-2", - content_bucket="jumpstart-cache-prod-ap-northeast-2", - gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-2", - neo_content_bucket="sagemaker-sd-models-prod-ap-northeast-2", - ), - JumpStartLaunchedRegionInfo( - region_name="ap-northeast-3", - content_bucket="jumpstart-cache-prod-ap-northeast-3", - gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-3", - neo_content_bucket="sagemaker-sd-models-prod-ap-northeast-3", - ), - JumpStartLaunchedRegionInfo( - region_name="ap-southeast-3", - content_bucket="jumpstart-cache-prod-ap-southeast-3", - gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-3", - neo_content_bucket="sagemaker-sd-models-prod-ap-southeast-3", - ), - JumpStartLaunchedRegionInfo( - region_name="ap-southeast-5", - content_bucket="jumpstart-cache-prod-ap-southeast-5", - gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-5", - ), - JumpStartLaunchedRegionInfo( - region_name="eu-west-2", - content_bucket="jumpstart-cache-prod-eu-west-2", - gated_content_bucket="jumpstart-private-cache-prod-eu-west-2", - neo_content_bucket="sagemaker-sd-models-prod-eu-west-2", - ), - JumpStartLaunchedRegionInfo( - region_name="eu-south-1", - content_bucket="jumpstart-cache-prod-eu-south-1", - gated_content_bucket="jumpstart-private-cache-prod-eu-south-1", - ), - JumpStartLaunchedRegionInfo( - region_name="ap-northeast-1", - content_bucket="jumpstart-cache-prod-ap-northeast-1", - gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-1", - neo_content_bucket="sagemaker-sd-models-prod-ap-northeast-1", - ), - JumpStartLaunchedRegionInfo( - region_name="us-west-1", - content_bucket="jumpstart-cache-prod-us-west-1", - gated_content_bucket="jumpstart-private-cache-prod-us-west-1", - neo_content_bucket="sagemaker-sd-models-prod-us-west-1", - ), - JumpStartLaunchedRegionInfo( - region_name="ap-southeast-1", - content_bucket="jumpstart-cache-prod-ap-southeast-1", - gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-1", - neo_content_bucket="sagemaker-sd-models-prod-ap-southeast-1", - ), - JumpStartLaunchedRegionInfo( - region_name="ap-southeast-2", - content_bucket="jumpstart-cache-prod-ap-southeast-2", - gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-2", - neo_content_bucket="sagemaker-sd-models-prod-ap-southeast-2", - ), - JumpStartLaunchedRegionInfo( - region_name="ca-central-1", - content_bucket="jumpstart-cache-prod-ca-central-1", - gated_content_bucket="jumpstart-private-cache-prod-ca-central-1", - neo_content_bucket="sagemaker-sd-models-prod-ca-central-1", - ), - JumpStartLaunchedRegionInfo( - region_name="cn-north-1", - content_bucket="jumpstart-cache-prod-cn-north-1", - gated_content_bucket="jumpstart-private-cache-prod-cn-north-1", - ), - JumpStartLaunchedRegionInfo( - region_name="cn-northwest-1", - content_bucket="jumpstart-cache-prod-cn-northwest-1", - gated_content_bucket="jumpstart-private-cache-prod-cn-northwest-1", - ), - JumpStartLaunchedRegionInfo( - region_name="il-central-1", - content_bucket="jumpstart-cache-prod-il-central-1", - gated_content_bucket="jumpstart-private-cache-prod-il-central-1", - ), - JumpStartLaunchedRegionInfo( - region_name="us-gov-east-1", - content_bucket="jumpstart-cache-prod-us-gov-east-1", - gated_content_bucket="jumpstart-private-cache-prod-us-gov-east-1", - ), - JumpStartLaunchedRegionInfo( - region_name="us-gov-west-1", - content_bucket="jumpstart-cache-prod-us-gov-west-1", - gated_content_bucket="jumpstart-private-cache-prod-us-gov-west-1", - ), - ] +JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = _load_region_config( + REGION_CONFIG_JSON_FILEPATH ) JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT = { @@ -297,23 +176,6 @@ MODEL_ID_LIST_WEB_URL = "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html" -JUMPSTART_LOGGER = logging.getLogger("sagemaker.jumpstart") - -# disable logging if env var is set -JUMPSTART_LOGGER.addHandler( - type( - "", - (logging.StreamHandler,), - { - "emit": lambda self, *args, **kwargs: ( - logging.StreamHandler.emit(self, *args, **kwargs) - if not os.environ.get(ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING) - else None - ) - }, - )() -) - try: DEFAULT_JUMPSTART_SAGEMAKER_SESSION = Session( boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index a41c9ed952..4daf9b1810 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -14,7 +14,7 @@ from __future__ import absolute_import -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union from sagemaker import session from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer @@ -41,6 +41,9 @@ validate_model_id_and_get_type, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, + remove_env_var_from_estimator_kwargs_if_model_access_config_present, + get_model_access_config, + get_hub_access_config, ) from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model_monitor.data_capture_config import DataCaptureConfig @@ -350,8 +353,8 @@ def __init__( source_dir (Optional[Union[str, PipelineVariable]]): The absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file. If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory is preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + is preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. (Default: None). @@ -613,12 +616,17 @@ def _validate_model_id_and_get_type_hook(): self.tolerate_vulnerable_model = estimator_init_kwargs.tolerate_vulnerable_model self.instance_count = estimator_init_kwargs.instance_count self.region = estimator_init_kwargs.region + self.environment = estimator_init_kwargs.environment self.orig_predictor_cls = None self.role = estimator_init_kwargs.role self.sagemaker_session = estimator_init_kwargs.sagemaker_session self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation self.config_name = estimator_init_kwargs.config_name self.init_kwargs = estimator_init_kwargs.to_kwargs_dict(False) + # Access configs initialized to None, would be given a value when .fit() is called + # if applicable + self.model_access_config = None + self.hub_access_config = None super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict()) @@ -629,6 +637,7 @@ def fit( logs: Optional[str] = None, job_name: Optional[str] = None, experiment_config: Optional[Dict[str, str]] = None, + accept_eula: Optional[bool] = None, ) -> None: """Start training job by calling base ``Estimator`` class ``fit`` method. @@ -679,8 +688,16 @@ def fit( is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`. However, the value of `TrialComponentDisplayName` is honored for display in Studio. (Default: None). + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). """ - + self.model_access_config = get_model_access_config(accept_eula, self.environment) + self.hub_access_config = get_hub_access_config( + hub_content_arn=self.init_kwargs.get("model_reference_arn", None) + ) estimator_fit_kwargs = get_fit_kwargs( model_id=self.model_id, model_version=self.model_version, @@ -695,6 +712,10 @@ def fit( tolerate_deprecated_model=self.tolerate_deprecated_model, sagemaker_session=self.sagemaker_session, config_name=self.config_name, + hub_access_config=self.hub_access_config, + ) + remove_env_var_from_estimator_kwargs_if_model_access_config_present( + self.init_kwargs, self.model_access_config ) return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict()) @@ -817,7 +838,7 @@ def deploy( explainer_config: Optional[ExplainerConfig] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, model_name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, @@ -918,7 +939,7 @@ def deploy( It can be null if this is being used to create a Model to pass to a ``PipelineModel`` which has its own Role field. (Default: None). - predictor_cls (Optional[callable[string, sagemaker.session.Session]]): A + predictor_cls (Optional[Callable[[string, sagemaker.session.Session], Any]]): A function to call to create a predictor (Default: None). If not None, ``deploy`` will return the result of invoking this function on the created endpoint name. (Default: None). @@ -947,8 +968,8 @@ def deploy( source_dir (Optional[str]): The absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (Default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory is preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. Structure within this directory is + preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. If the directory points to S3, no code is uploaded and the S3 location is used instead. (Default: None). diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index e4020a39bd..12eb30daaf 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -14,7 +14,7 @@ from __future__ import absolute_import -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union from sagemaker import ( environment_variables, hyperparameters as hyperparameters_utils, @@ -56,6 +56,7 @@ JUMPSTART_LOGGER, TRAINING_ENTRY_POINT_SCRIPT_NAME, SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, + JUMPSTART_MODEL_HUB_NAME, ) from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.factory import model @@ -71,7 +72,6 @@ from sagemaker.jumpstart.utils import ( add_hub_content_arn_tags, add_jumpstart_model_info_tags, - get_eula_message, get_default_jumpstart_session_with_user_agent_suffix, get_top_ranked_config_name, update_dict_if_key_not_present, @@ -265,6 +265,7 @@ def get_fit_kwargs( tolerate_deprecated_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, config_name: Optional[str] = None, + hub_access_config: Optional[Dict] = None, ) -> JumpStartEstimatorFitKwargs: """Returns kwargs required call `fit` on `sagemaker.estimator.Estimator` object.""" @@ -301,10 +302,47 @@ def get_fit_kwargs( estimator_fit_kwargs = _add_region_to_kwargs(estimator_fit_kwargs) estimator_fit_kwargs = _add_training_job_name_to_kwargs(estimator_fit_kwargs) estimator_fit_kwargs = _add_fit_extra_kwargs(estimator_fit_kwargs) + estimator_fit_kwargs = _add_hub_access_config_to_kwargs_inputs( + estimator_fit_kwargs, hub_access_config + ) return estimator_fit_kwargs +def _add_hub_access_config_to_kwargs_inputs( + kwargs: JumpStartEstimatorFitKwargs, hub_access_config=None +): + """Adds HubAccessConfig to kwargs inputs""" + + dataset_uri = kwargs.specs.default_training_dataset_uri + if isinstance(kwargs.inputs, str): + if dataset_uri is not None and dataset_uri == kwargs.inputs: + kwargs.inputs = TrainingInput( + s3_data=kwargs.inputs, hub_access_config=hub_access_config + ) + elif isinstance(kwargs.inputs, TrainingInput): + if ( + dataset_uri is not None + and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"] + ): + kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config) + elif isinstance(kwargs.inputs, dict): + for k, v in kwargs.inputs.items(): + if isinstance(v, str): + training_input = TrainingInput(s3_data=v) + if dataset_uri is not None and dataset_uri == v: + training_input.add_hub_access_config(hub_access_config=hub_access_config) + kwargs.inputs[k] = training_input + elif isinstance(kwargs.inputs, TrainingInput): + if ( + dataset_uri is not None + and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"] + ): + kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config) + + return kwargs + + def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, @@ -330,7 +368,7 @@ def get_deploy_kwargs( explainer_config: Optional[ExplainerConfig] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, sagemaker_session: Optional[Session] = None, @@ -594,8 +632,13 @@ def _add_model_reference_arn_to_kwargs( def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs: """Sets model uri in kwargs based on default or override, returns full kwargs.""" - - if _model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs)): + # hub_arn is by default None unless the user specifies the hub_name + # If no hub_name is specified, it is assumed the public hub + is_private_hub = JUMPSTART_MODEL_HUB_NAME not in kwargs.hub_arn if kwargs.hub_arn else False + if ( + _model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs)) + or is_private_hub + ): default_model_uri = model_uris.retrieve( model_scope=JumpStartScriptScope.TRAINING, instance_type=kwargs.instance_type, @@ -668,18 +711,6 @@ def _add_env_to_kwargs( value, ) - environment = getattr(kwargs, "environment", {}) or {} - if ( - environment.get(SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY) - and str(environment.get("accept_eula", "")).lower() != "true" - ): - model_specs = kwargs.specs - if model_specs.is_gated_model(): - raise ValueError( - "Need to define ‘accept_eula'='true' within Environment. " - f"{get_eula_message(model_specs, kwargs.region)}" - ) - return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 328e1e8227..53ded3f275 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -15,7 +15,7 @@ import json -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from sagemaker_core.shapes import ModelAccessConfig from sagemaker import environment_variables, image_uris, instance_types, model_uris, script_uris from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig @@ -104,7 +104,7 @@ def get_default_predictor( """ # if there's a non-default predictor, do not mutate -- return as is - if type(predictor) != Predictor: # pylint: disable=C0123 + if not isinstance(predictor, Predictor): raise RuntimeError( "Can only get default predictor from base Predictor class. " f"Using Predictor class '{type(predictor).__name__}'." @@ -855,7 +855,7 @@ def get_init_kwargs( image_uri: Optional[Union[str, PipelineVariable]] = None, model_data: Optional[Union[str, PipelineVariable, dict]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index bc42eebea0..692966cee4 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -16,15 +16,11 @@ from datetime import datetime import logging from typing import Optional, Dict, List, Any, Union -from botocore import exceptions from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.session import Session -from sagemaker.jumpstart.constants import ( - JUMPSTART_LOGGER, -) from sagemaker.jumpstart.types import ( HubContentType, ) @@ -32,9 +28,6 @@ from sagemaker.jumpstart.hub.utils import ( get_hub_model_version, get_info_from_hub_resource_arn, - create_hub_bucket_if_it_does_not_exist, - generate_default_hub_bucket_name, - create_s3_object_reference_from_uri, construct_hub_arn_from_name, ) @@ -42,9 +35,6 @@ list_jumpstart_models, ) -from sagemaker.jumpstart.hub.types import ( - S3ObjectLocation, -) from sagemaker.jumpstart.hub.interfaces import ( DescribeHubResponse, DescribeHubContentResponse, @@ -66,8 +56,8 @@ class Hub: def __init__( self, hub_name: str, + sagemaker_session: Session, bucket_name: Optional[str] = None, - sagemaker_session: Optional[Session] = None, ) -> None: """Instantiates a SageMaker ``Hub``. @@ -78,41 +68,11 @@ def __init__( """ self.hub_name = hub_name self.region = sagemaker_session.boto_region_name + self.bucket_name = bucket_name self._sagemaker_session = ( sagemaker_session or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True) ) - self.hub_storage_location = self._generate_hub_storage_location(bucket_name) - - def _fetch_hub_bucket_name(self) -> str: - """Retrieves hub bucket name from Hub config if exists""" - try: - hub_response = self._sagemaker_session.describe_hub(hub_name=self.hub_name) - hub_output_location = hub_response["S3StorageConfig"].get("S3OutputPath") - if hub_output_location: - location = create_s3_object_reference_from_uri(hub_output_location) - return location.bucket - default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) - JUMPSTART_LOGGER.warning( - "There is not a Hub bucket associated with %s. Using %s", - self.hub_name, - default_bucket_name, - ) - return default_bucket_name - except exceptions.ClientError: - hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) - JUMPSTART_LOGGER.warning( - "There is not a Hub bucket associated with %s. Using %s", - self.hub_name, - hub_bucket_name, - ) - return hub_bucket_name - - def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> None: - """Generates an ``S3ObjectLocation`` given a Hub name.""" - hub_bucket_name = bucket_name or self._fetch_hub_bucket_name() - curr_timestamp = datetime.now().timestamp() - return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}") def _get_latest_model_version(self, model_id: str) -> str: """Populates the lastest version of a model from specs no matter what is passed. @@ -132,19 +92,22 @@ def create( tags: Optional[str] = None, ) -> Dict[str, str]: """Creates a hub with the given description""" + curr_timestamp = datetime.now().timestamp() - create_hub_bucket_if_it_does_not_exist( - self.hub_storage_location.bucket, self._sagemaker_session - ) + request = { + "hub_name": self.hub_name, + "hub_description": description, + "hub_display_name": display_name, + "hub_search_keywords": search_keywords, + "tags": tags, + } - return self._sagemaker_session.create_hub( - hub_name=self.hub_name, - hub_description=description, - hub_display_name=display_name, - hub_search_keywords=search_keywords, - s3_storage_config={"S3OutputPath": self.hub_storage_location.get_uri()}, - tags=tags, - ) + if self.bucket_name: + request["s3_storage_config"] = { + "S3OutputPath": (f"s3://{self.bucket_name}/{self.hub_name}-{curr_timestamp}") + } + + return self._sagemaker_session.create_hub(**request) def describe(self, hub_name: Optional[str] = None) -> DescribeHubResponse: """Returns descriptive information about the Hub""" @@ -272,18 +235,21 @@ def delete_model_reference(self, model_name: str) -> None: def describe_model( self, model_name: str, hub_name: Optional[str] = None, model_version: Optional[str] = None ) -> DescribeHubContentResponse: - """Describe model in the SageMaker Hub.""" + """Describe Model or ModelReference in a Hub.""" + hub_name = hub_name or self.hub_name + + # Users only input model id, not contentType, so first try to describe with ModelReference, then with Model try: model_version = get_hub_model_version( hub_model_name=model_name, hub_model_type=HubContentType.MODEL_REFERENCE.value, - hub_name=self.hub_name if not hub_name else hub_name, + hub_name=hub_name, sagemaker_session=self._sagemaker_session, hub_model_version=model_version, ) hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( - hub_name=self.hub_name if not hub_name else hub_name, + hub_name=hub_name, hub_content_name=model_name, hub_content_version=model_version, hub_content_type=HubContentType.MODEL_REFERENCE.value, @@ -294,19 +260,32 @@ def describe_model( "Received exeption while calling APIs for ContentType ModelReference, retrying with ContentType Model: " + str(ex) ) - model_version = get_hub_model_version( - hub_model_name=model_name, - hub_model_type=HubContentType.MODEL.value, - hub_name=self.hub_name if not hub_name else hub_name, - sagemaker_session=self._sagemaker_session, - hub_model_version=model_version, - ) - hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( - hub_name=self.hub_name if not hub_name else hub_name, - hub_content_name=model_name, - hub_content_version=model_version, - hub_content_type=HubContentType.MODEL.value, - ) + # Failed to describe ModelReference, try with Model + try: + model_version = get_hub_model_version( + hub_model_name=model_name, + hub_model_type=HubContentType.MODEL.value, + hub_name=hub_name, + sagemaker_session=self._sagemaker_session, + hub_model_version=model_version, + ) + + hub_content_description: Dict[str, Any] = ( + self._sagemaker_session.describe_hub_content( + hub_name=hub_name, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL.value, + ) + ) + + except Exception as ex: + # Failed with both, throw a custom error message + raise RuntimeError( + f"Cannot get details for {model_name} in Hub {hub_name}. \ + {model_name} does not exist as a Model or ModelReference in {hub_name}: \n" + + str(ex) + ) return DescribeHubContentResponse(hub_content_description) diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py index fd38868dcc..6ba5a37c3c 100644 --- a/src/sagemaker/jumpstart/hub/interfaces.py +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -630,7 +630,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if json_obj.get("ValidationSupported") else None ) - self.default_training_dataset_uri: Optional[str] = json_obj.get("DefaultTrainingDatasetUri") self.resource_name_base: Optional[str] = json_obj.get("ResourceNameBase") self.gated_bucket: bool = bool(json_obj.get("GatedBucket", False)) self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = ( @@ -671,6 +670,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: ) if self.training_supported: + self.default_training_dataset_uri: Optional[str] = json_obj.get( + "DefaultTrainingDatasetUri" + ) self.training_model_package_artifact_uri: Optional[str] = json_obj.get( "TrainingModelPackageArtifactUri" ) diff --git a/src/sagemaker/jumpstart/hub/parsers.py b/src/sagemaker/jumpstart/hub/parsers.py index 01b6c5fe87..8070b54e87 100644 --- a/src/sagemaker/jumpstart/hub/parsers.py +++ b/src/sagemaker/jumpstart/hub/parsers.py @@ -279,4 +279,10 @@ def make_model_specs_from_describe_hub_content_response( specs["training_instance_type_variants"] = ( hub_model_document.training_instance_type_variants ) + if hub_model_document.default_training_dataset_uri: + _, default_training_dataset_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.default_training_dataset_uri + ) + specs["default_training_dataset_key"] = default_training_dataset_key + specs["default_training_dataset_uri"] = hub_model_document.default_training_dataset_uri return JumpStartModelSpecs(_to_json(specs), is_hub_content=True) diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index 77540926c6..0df5e9d5c3 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -15,13 +15,12 @@ from __future__ import absolute_import import re from typing import Optional, List, Any -from sagemaker.jumpstart.hub.types import S3ObjectLocation -from sagemaker.s3_utils import parse_s3_url from sagemaker.session import Session from sagemaker.utils import aws_partition from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo from sagemaker.jumpstart import constants from packaging.specifiers import SpecifierSet, InvalidSpecifier +from packaging import version PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:" @@ -78,6 +77,9 @@ def construct_hub_arn_from_name( account_id: Optional[str] = None, ) -> str: """Constructs a Hub arn from the Hub name using default Session values.""" + if session is None: + # session is overridden to none by some callers + session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION account_id = account_id or session.account_id() region = region or session.boto_region_name @@ -106,7 +108,7 @@ def construct_hub_model_reference_arn_from_inputs( info = get_info_from_hub_resource_arn(hub_arn) arn = ( f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/" - f"{info.hub_name}/{HubContentType.MODEL_REFERENCE}/{model_name}/{version}" + f"{info.hub_name}/{HubContentType.MODEL_REFERENCE.value}/{model_name}/{version}" ) return arn @@ -135,61 +137,6 @@ def generate_hub_arn_for_init_kwargs( return hub_arn -def generate_default_hub_bucket_name( - sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> str: - """Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions. - - Returns: - str: The name of the default bucket. If the name was not explicitly specified through - the Session or sagemaker_config, the bucket will take the form: - ``sagemaker-hubs-{region}-{AWS account ID}``. - """ - - region: str = sagemaker_session.boto_region_name - account_id: str = sagemaker_session.account_id() - - # TODO: Validate and fast fail - - return f"sagemaker-hubs-{region}-{account_id}" - - -def create_s3_object_reference_from_uri(s3_uri: Optional[str]) -> Optional[S3ObjectLocation]: - """Utiity to help generate an S3 object reference""" - if not s3_uri: - return None - - bucket, key = parse_s3_url(s3_uri) - - return S3ObjectLocation( - bucket=bucket, - key=key, - ) - - -def create_hub_bucket_if_it_does_not_exist( - bucket_name: Optional[str] = None, - sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> str: - """Creates the default SageMaker Hub bucket if it does not exist. - - Returns: - str: The name of the default bucket. Takes the form: - ``sagemaker-hubs-{region}-{AWS account ID}``. - """ - - region: str = sagemaker_session.boto_region_name - if bucket_name is None: - bucket_name: str = generate_default_hub_bucket_name(sagemaker_session) - - sagemaker_session._create_s3_bucket_if_it_does_not_exist( - bucket_name=bucket_name, - region=region, - ) - - return bucket_name - - def is_gated_bucket(bucket_name: str) -> bool: """Returns true if the bucket name is the JumpStart gated bucket.""" return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET @@ -211,11 +158,17 @@ def get_hub_model_version( ClientError: If the specified model is not found in the hub. KeyError: If the specified model version is not found. """ + if sagemaker_session is None: + # sagemaker_session is overridden to none by some callers + sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION try: - hub_content_summaries = sagemaker_session.list_hub_content_versions( - hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type - ).get("HubContentSummaries") + hub_content_summaries = _list_hub_content_versions_helper( + hub_name=hub_name, + hub_content_name=hub_model_name, + hub_content_type=hub_model_type, + sagemaker_session=sagemaker_session, + ) except Exception as ex: raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") @@ -232,13 +185,34 @@ def get_hub_model_version( raise +def _list_hub_content_versions_helper( + hub_name, hub_content_name, hub_content_type, sagemaker_session +): + all_hub_content_summaries = [] + list_hub_content_versions_response = sagemaker_session.list_hub_content_versions( + hub_name=hub_name, hub_content_name=hub_content_name, hub_content_type=hub_content_type + ) + all_hub_content_summaries.extend(list_hub_content_versions_response.get("HubContentSummaries")) + while "NextToken" in list_hub_content_versions_response: + list_hub_content_versions_response = sagemaker_session.list_hub_content_versions( + hub_name=hub_name, + hub_content_name=hub_content_name, + hub_content_type=hub_content_type, + next_token=list_hub_content_versions_response["NextToken"], + ) + all_hub_content_summaries.extend( + list_hub_content_versions_response.get("HubContentSummaries") + ) + return all_hub_content_summaries + + def _get_hub_model_version_for_open_weight_version( hub_content_summaries: List[Any], hub_model_version: Optional[str] = None ) -> str: available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries] if hub_model_version == "*" or hub_model_version is None: - return str(max(available_model_versions)) + return str(max(version.parse(v) for v in available_model_versions)) try: spec = SpecifierSet(f"=={hub_model_version}") diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index b0b54db557..7dec3d78f9 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import -from typing import Dict, List, Optional, Any, Union +from typing import Callable, Dict, List, Optional, Any, Union import pandas as pd from botocore.exceptions import ClientError @@ -95,7 +95,7 @@ def __init__( image_uri: Optional[Union[str, PipelineVariable]] = None, model_data: Optional[Union[str, PipelineVariable, dict]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, @@ -149,7 +149,7 @@ def __init__( It can be null if this is being used to create a Model to pass to a ``PipelineModel`` which has its own Role field. (Default: None). - predictor_cls (Optional[callable[string, sagemaker.session.Session]]): A + predictor_cls (Optional[Callable[[string, sagemaker.session.Session], Any]]): A function to call to create a predictor (Default: None). If not None, ``deploy`` will return the result of invoking this function on the created endpoint name. (Default: None). @@ -178,8 +178,8 @@ def __init__( source_dir (Optional[str]): The absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (Default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory is preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. Structure within this directory is + preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. If the directory points to S3, no code is uploaded and the S3 location is used instead. (Default: None). diff --git a/src/sagemaker/jumpstart/region_config.json b/src/sagemaker/jumpstart/region_config.json new file mode 100644 index 0000000000..30bea6ee70 --- /dev/null +++ b/src/sagemaker/jumpstart/region_config.json @@ -0,0 +1,163 @@ +{ + "af-south-1": { + "content_bucket": "jumpstart-cache-prod-af-south-1", + "gated_content_bucket": "jumpstart-private-cache-prod-af-south-1" + }, + "ap-east-1": { + "content_bucket": "jumpstart-cache-prod-ap-east-1", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-east-1" + }, + "ap-northeast-1": { + "content_bucket": "jumpstart-cache-prod-ap-northeast-1", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-northeast-1", + "neo_content_bucket": "sagemaker-sd-models-prod-ap-northeast-1" + }, + "ap-northeast-2": { + "content_bucket": "jumpstart-cache-prod-ap-northeast-2", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-northeast-2", + "neo_content_bucket": "sagemaker-sd-models-prod-ap-northeast-2" + }, + "ap-northeast-3": { + "content_bucket": "jumpstart-cache-prod-ap-northeast-3", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-northeast-3", + "neo_content_bucket": "sagemaker-sd-models-prod-ap-northeast-3" + }, + "ap-south-1": { + "content_bucket": "jumpstart-cache-prod-ap-south-1", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-south-1", + "neo_content_bucket": "sagemaker-sd-models-prod-ap-south-1" + }, + "ap-south-2": { + "content_bucket": "jumpstart-cache-prod-ap-south-2", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-south-2" + }, + "ap-southeast-1": { + "content_bucket": "jumpstart-cache-prod-ap-southeast-1", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-1", + "neo_content_bucket": "sagemaker-sd-models-prod-ap-southeast-1" + }, + "ap-southeast-2": { + "content_bucket": "jumpstart-cache-prod-ap-southeast-2", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-2", + "neo_content_bucket": "sagemaker-sd-models-prod-ap-southeast-2" + }, + "ap-southeast-3": { + "content_bucket": "jumpstart-cache-prod-ap-southeast-3", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-3" + }, + "ap-southeast-4": { + "content_bucket": "jumpstart-cache-prod-ap-southeast-4", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-4" + }, + "ap-southeast-5": { + "content_bucket": "jumpstart-cache-prod-ap-southeast-5", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-5" + }, + "ap-southeast-7": { + "content_bucket": "jumpstart-cache-prod-ap-southeast-7", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-7" + }, + "ca-central-1": { + "content_bucket": "jumpstart-cache-prod-ca-central-1", + "gated_content_bucket": "jumpstart-private-cache-prod-ca-central-1", + "neo_content_bucket": "sagemaker-sd-models-prod-ca-central-1" + }, + "ca-west-1": { + "content_bucket": "jumpstart-cache-prod-ca-west-1", + "gated_content_bucket": "jumpstart-private-cache-prod-ca-west-1" + }, + "cn-north-1": { + "content_bucket": "jumpstart-cache-prod-cn-north-1", + "gated_content_bucket": "jumpstart-private-cache-prod-cn-north-1" + }, + "cn-northwest-1": { + "content_bucket": "jumpstart-cache-prod-cn-northwest-1", + "gated_content_bucket": "jumpstart-private-cache-prod-cn-northwest-1" + }, + "eu-central-1": { + "content_bucket": "jumpstart-cache-prod-eu-central-1", + "gated_content_bucket": "jumpstart-private-cache-prod-eu-central-1", + "neo_content_bucket": "sagemaker-sd-models-prod-eu-central-1" + }, + "eu-central-2": { + "content_bucket": "jumpstart-cache-prod-eu-central-2", + "gated_content_bucket": "jumpstart-private-cache-prod-eu-central-2" + }, + "eu-north-1": { + "content_bucket": "jumpstart-cache-prod-eu-north-1", + "gated_content_bucket": "jumpstart-private-cache-prod-eu-north-1", + "neo_content_bucket": "sagemaker-sd-models-prod-eu-north-1" + }, + "eu-south-1": { + "content_bucket": "jumpstart-cache-prod-eu-south-1", + "gated_content_bucket": "jumpstart-private-cache-prod-eu-south-1" + }, + "eu-south-2": { + "content_bucket": "jumpstart-cache-prod-eu-south-2", + "gated_content_bucket": "jumpstart-private-cache-prod-eu-south-2" + }, + "eu-west-1": { + "content_bucket": "jumpstart-cache-prod-eu-west-1", + "gated_content_bucket": "jumpstart-private-cache-prod-eu-west-1", + "neo_content_bucket": "sagemaker-sd-models-prod-eu-west-1" + }, + "eu-west-2": { + "content_bucket": "jumpstart-cache-prod-eu-west-2", + "gated_content_bucket": "jumpstart-private-cache-prod-eu-west-2", + "neo_content_bucket": "sagemaker-sd-models-prod-eu-west-2" + }, + "eu-west-3": { + "content_bucket": "jumpstart-cache-prod-eu-west-3", + "gated_content_bucket": "jumpstart-private-cache-prod-eu-west-3", + "neo_content_bucket": "sagemaker-sd-models-prod-eu-west-3" + }, + "il-central-1": { + "content_bucket": "jumpstart-cache-prod-il-central-1", + "gated_content_bucket": "jumpstart-private-cache-prod-il-central-1" + }, + "me-central-1": { + "content_bucket": "jumpstart-cache-prod-me-central-1", + "gated_content_bucket": "jumpstart-private-cache-prod-me-central-1" + }, + "me-south-1": { + "content_bucket": "jumpstart-cache-prod-me-south-1", + "gated_content_bucket": "jumpstart-private-cache-prod-me-south-1" + }, + "mx-central-1": { + "content_bucket": "jumpstart-cache-prod-mx-central-1", + "gated_content_bucket": "jumpstart-private-cache-prod-mx-central-1" + }, + "sa-east-1": { + "content_bucket": "jumpstart-cache-prod-sa-east-1", + "gated_content_bucket": "jumpstart-private-cache-prod-sa-east-1", + "neo_content_bucket": "sagemaker-sd-models-prod-sa-east-1" + }, + "us-east-1": { + "content_bucket": "jumpstart-cache-prod-us-east-1", + "gated_content_bucket": "jumpstart-private-cache-prod-us-east-1", + "neo_content_bucket": "sagemaker-sd-models-prod-us-east-1" + }, + "us-east-2": { + "content_bucket": "jumpstart-cache-prod-us-east-2", + "gated_content_bucket": "jumpstart-private-cache-prod-us-east-2", + "neo_content_bucket": "sagemaker-sd-models-prod-us-east-2" + }, + "us-gov-east-1": { + "content_bucket": "jumpstart-cache-prod-us-gov-east-1", + "gated_content_bucket": "jumpstart-private-cache-prod-us-gov-east-1" + }, + "us-gov-west-1": { + "content_bucket": "jumpstart-cache-prod-us-gov-west-1", + "gated_content_bucket": "jumpstart-private-cache-prod-us-gov-west-1" + }, + "us-west-1": { + "content_bucket": "jumpstart-cache-prod-us-west-1", + "gated_content_bucket": "jumpstart-private-cache-prod-us-west-1", + "neo_content_bucket": "sagemaker-sd-models-prod-us-west-1" + }, + "us-west-2": { + "content_bucket": "jumpstart-cache-prod-us-west-2", + "gated_content_bucket": "jumpstart-private-cache-prod-us-west-2", + "neo_content_bucket": "sagemaker-sd-models-prod-us-west-2" + } +} \ No newline at end of file diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index f59e2eddf4..0cd4bcc902 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -16,7 +16,7 @@ import re from copy import deepcopy from enum import Enum -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, List, Optional, Set, Union from sagemaker_core.shapes import ModelAccessConfig as CoreModelAccessConfig from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard from sagemaker.utils import ( @@ -619,6 +619,19 @@ def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str instance_type=instance_type, property_name="artifact_key" ) + def get_instance_specific_training_artifact_key(self, instance_type: str) -> Optional[str]: + """Returns instance specific training artifact key. + + Returns None if a model, instance type tuple does not have specific + training artifact key. + """ + + return self._get_instance_specific_property( + instance_type=instance_type, property_name="training_artifact_uri" + ) or self._get_instance_specific_property( + instance_type=instance_type, property_name="training_artifact_key" + ) + def get_instance_specific_resource_requirements(self, instance_type: str) -> Optional[str]: """Returns instance specific resource requirements. @@ -1266,6 +1279,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "hosting_neuron_model_version", "hub_content_type", "_is_hub_content", + "default_training_dataset_key", + "default_training_dataset_uri", ] _non_serializable_slots = ["_is_hub_content"] @@ -1363,9 +1378,10 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.deploy_kwargs = deepcopy(json_obj.get("deploy_kwargs", {})) self.predictor_specs: Optional[JumpStartPredictorSpecs] = ( JumpStartPredictorSpecs( - json_obj["predictor_specs"], is_hub_content=self._is_hub_content + json_obj.get("predictor_specs"), + is_hub_content=self._is_hub_content, ) - if "predictor_specs" in json_obj + if json_obj.get("predictor_specs") else None ) self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = ( @@ -1448,6 +1464,12 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: else None ) self.model_subscription_link = json_obj.get("model_subscription_link") + self.default_training_dataset_key: Optional[str] = json_obj.get( + "default_training_dataset_key" + ) + self.default_training_dataset_uri: Optional[str] = json_obj.get( + "default_training_dataset_uri" + ) def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartMetadataBaseFields object.""" @@ -1501,6 +1523,9 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields): "incremental_training_supported", ] + # Map of HubContent fields that map to custom names in MetadataBaseFields + CUSTOM_FIELD_MAP = {"sage_maker_sdk_predictor_specifications": "predictor_specs"} + __slots__ = slots + JumpStartMetadataBaseFields.__slots__ def __init__( @@ -1532,6 +1557,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if field in self.__slots__: setattr(self, field, json_obj[field]) + # Handle custom fields + for custom_field, field in self.CUSTOM_FIELD_MAP.items(): + if custom_field in json_obj: + setattr(self, field, json_obj.get(custom_field)) + class JumpStartMetadataConfig(JumpStartDataHolderType): """Data class of JumpStart metadata config.""" @@ -2150,7 +2180,7 @@ def __init__( image_uri: Optional[Union[str, Any]] = None, model_data: Optional[Union[str, Any, dict]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, Any]]] = None, name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, Any]]]] = None, @@ -2698,7 +2728,7 @@ def __init__( explainer_config: Optional[Any] = None, image_uri: Optional[Union[str, Any]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, Any]]] = None, model_name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, Any]]]] = None, diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 46e5f8a847..15f9e9b52e 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -21,7 +21,7 @@ from urllib.parse import urlparse import boto3 from botocore.exceptions import ClientError -from packaging.version import Version +from packaging.version import Version, InvalidVersion import botocore from sagemaker_core.shapes import ModelAccessConfig import sagemaker @@ -1630,3 +1630,72 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str: return get_jumpstart_gated_content_bucket(region=region) return get_jumpstart_content_bucket(region=region) return neo_bucket + + +def remove_env_var_from_estimator_kwargs_if_model_access_config_present( + init_kwargs: dict, model_access_config: Optional[dict] +): + """Remove env vars if ModelAccessConfig is used + + Args: + init_kwargs (dict): Dictionary of kwargs when Estimator is instantiated. + accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit(). + """ + if ( + model_access_config is not None + and init_kwargs.get("environment") is not None + and init_kwargs.get("model_uri") is not None + ): + if ( + constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY + in init_kwargs["environment"] + ): + del init_kwargs["environment"][ + constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY + ] + if "accept_eula" in init_kwargs["environment"]: + del init_kwargs["environment"]["accept_eula"] + + +def get_hub_access_config(hub_content_arn: Optional[str]): + """Get hub access config + + Args: + hub_content_arn (Optional[bool]): Arn of the model reference hub content + """ + if hub_content_arn is not None: + hub_access_config = {"HubContentArn": hub_content_arn} + else: + hub_access_config = None + + return hub_access_config + + +def get_model_access_config(accept_eula: Optional[bool], environment: Optional[dict]): + """Get access configs + + Args: + accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit(). + """ + env_var_eula = environment.get("accept_eula") if environment else None + if env_var_eula is not None and accept_eula is not None: + raise ValueError( + "Cannot pass in both accept_eula and environment variables. " + "Please remove the environment variable and pass in the accept_eula parameter." + ) + + model_access_config = None + if env_var_eula is not None: + model_access_config = {"AcceptEula": env_var_eula == "true"} + if accept_eula is not None: + model_access_config = {"AcceptEula": accept_eula} + + return model_access_config + + +def get_latest_version(versions: List[str]) -> Optional[str]: + """Returns the latest version using sem-ver when possible.""" + try: + return None if not versions else max(versions, key=Version) + except InvalidVersion: + return max(versions) diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py index a21a375f54..0cf6c6d55a 100644 --- a/src/sagemaker/local/entities.py +++ b/src/sagemaker/local/entities.py @@ -845,10 +845,10 @@ def _initialize_and_validate_parameters(self, overridden_parameters): ) raise ClientError(error_msg, "start_pipeline_execution") parameter_type = default_parameters[param_name].parameter_type - if type(param_value) != parameter_type.python_type: # pylint: disable=C0123 + if not isinstance(param_value, parameter_type.python_type): error_msg = self._construct_validation_exception_message( - "Unexpected type for parameter '{}'. Expected {} but found " - "{}.".format(param_name, parameter_type.python_type, type(param_value)) + f"Unexpected type for parameter '{param_name}'. Expected \ + {parameter_type.python_type} but found {type(param_value)}." ) raise ClientError(error_msg, "start_pipeline_execution") if param_value == "": diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index ef24bb0d99..3d0f8394ab 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -473,7 +473,12 @@ def write_processing_config_files( """ config_path = os.path.join(self.container_root, host, "config") - resource_config = {"current_host": host, "hosts": self.hosts} + resource_config = { + "current_host": host, + "hosts": self.hosts, + "network_interface_name": "eth0", + "current_instance_type": self.instance_type, + } _write_json_file(os.path.join(config_path, "resourceconfig.json"), resource_config) processing_job_config = { @@ -519,7 +524,12 @@ def write_config_files(self, host, hyperparameters, input_data_config): """ config_path = os.path.join(self.container_root, host, "input", "config") - resource_config = {"current_host": host, "hosts": self.hosts} + resource_config = { + "current_host": host, + "hosts": self.hosts, + "network_interface_name": "eth0", + "current_instance_type": self.instance_type, + } json_input_data_config = {} for c in input_data_config: diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 863bbf376c..b281d9f489 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -20,7 +20,7 @@ import os import re import copy -from typing import List, Dict, Optional, Union, Any +from typing import Callable, List, Dict, Optional, Union, Any import sagemaker from sagemaker import ( @@ -53,7 +53,6 @@ from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum from sagemaker.session import Session from sagemaker.model_metrics import ModelMetrics -from sagemaker.deprecations import removed_kwargs from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.explainer import ExplainerConfig from sagemaker.metadata_properties import MetadataProperties @@ -154,7 +153,7 @@ def __init__( image_uri: Optional[Union[str, PipelineVariable]] = None, model_data: Optional[Union[str, PipelineVariable, dict]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, @@ -186,7 +185,7 @@ def __init__( It can be null if this is being used to create a Model to pass to a ``PipelineModel`` which has its own Role field. (default: None) - predictor_cls (callable[string, sagemaker.session.Session]): A + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor (default: None). If not None, ``deploy`` will return the result of invoking this function on the created endpoint name. @@ -215,8 +214,8 @@ def __init__( source_dir (str): The absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory is preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + is preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. If the directory points to S3, no code is uploaded and the S3 location is used instead. @@ -745,6 +744,8 @@ def is_repack(self) -> bool: Returns: bool: if the source need to be repacked or not """ + if self.source_dir is None or self.entry_point is None: + return False return self.source_dir and self.entry_point and not self.git_config def _upload_code(self, key_prefix: str, repack: bool = False) -> None: @@ -1384,6 +1385,7 @@ def deploy( routing_config: Optional[Dict[str, Any]] = None, model_reference_arn: Optional[str] = None, inference_ami_version: Optional[str] = None, + update_endpoint: Optional[bool] = False, **kwargs, ): """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. @@ -1492,6 +1494,14 @@ def deploy( } model_reference_arn (Optional [str]): Hub Content Arn of a Model Reference type content (default: None). + inference_ami_version (Optional [str]): Specifies an option from a collection of preconfigured + Amazon Machine Image (AMI) images. For a full list of options, see: + https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ProductionVariant.html + update_endpoint (Optional[bool]): + Flag to update the model in an existing Amazon SageMaker endpoint. + If True, this will deploy a new EndpointConfig to an already existing endpoint + and delete resources corresponding to the previous EndpointConfig. Default: False + Note: Currently this is supported for single model endpoints Raises: ValueError: If arguments combination check failed in these circumstances: - If no role is specified or @@ -1501,14 +1511,12 @@ def deploy( inference config or - If inference recommendation id is specified along with incompatible parameters Returns: - callable[string, sagemaker.session.Session] or None: Invocation of + Callable[[string, sagemaker.session.Session], Any] or None: Invocation of ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` is not None. Otherwise, return None. """ self.accept_eula = accept_eula - removed_kwargs("update_endpoint", kwargs) - self._init_sagemaker_session_if_does_not_exist(instance_type) # Depending on the instance type, a local session (or) a session is initialized. self.role = resolve_value_from_config( @@ -1623,6 +1631,10 @@ def deploy( # Support multiple models on same endpoint if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED: + if update_endpoint: + raise ValueError( + "Currently update_endpoint is supported for single model endpoints" + ) if endpoint_name: self.endpoint_name = endpoint_name else: @@ -1743,6 +1755,7 @@ def deploy( model_data_download_timeout=model_data_download_timeout, container_startup_health_check_timeout=container_startup_health_check_timeout, routing_config=routing_config, + inference_ami_version=inference_ami_version, ) if endpoint_name: self.endpoint_name = endpoint_name @@ -1777,17 +1790,38 @@ def deploy( if is_explainer_enabled: explainer_config_dict = explainer_config._to_request_dict() - self.sagemaker_session.endpoint_from_production_variants( - name=self.endpoint_name, - production_variants=[production_variant], - tags=tags, - kms_key=kms_key, - wait=wait, - data_capture_config_dict=data_capture_config_dict, - explainer_config_dict=explainer_config_dict, - async_inference_config_dict=async_inference_config_dict, - live_logging=endpoint_logging, - ) + if update_endpoint: + endpoint_config_name = self.sagemaker_session.create_endpoint_config( + name=self.name, + model_name=self.name, + initial_instance_count=initial_instance_count, + instance_type=instance_type, + accelerator_type=accelerator_type, + tags=tags, + kms_key=kms_key, + data_capture_config_dict=data_capture_config_dict, + volume_size=volume_size, + model_data_download_timeout=model_data_download_timeout, + container_startup_health_check_timeout=container_startup_health_check_timeout, + explainer_config_dict=explainer_config_dict, + async_inference_config_dict=async_inference_config_dict, + serverless_inference_config=serverless_inference_config_dict, + routing_config=routing_config, + inference_ami_version=inference_ami_version, + ) + self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name) + else: + self.sagemaker_session.endpoint_from_production_variants( + name=self.endpoint_name, + production_variants=[production_variant], + tags=tags, + kms_key=kms_key, + wait=wait, + data_capture_config_dict=data_capture_config_dict, + explainer_config_dict=explainer_config_dict, + async_inference_config_dict=async_inference_config_dict, + live_logging=endpoint_logging, + ) if self.predictor_cls: predictor = self.predictor_cls(self.endpoint_name, self.sagemaker_session) @@ -1959,7 +1993,7 @@ def __init__( role: Optional[str] = None, entry_point: Optional[str] = None, source_dir: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, name: Optional[str] = None, container_log_level: Union[int, PipelineVariable] = logging.INFO, @@ -1996,11 +2030,11 @@ def __init__( source_dir (str): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. If 'git_config' is provided, - 'source_dir' should be a relative location to a directory in the Git repo. - If the directory points to S3, no code will be uploaded and the S3 location - will be used instead. + point to a file with name ``sourcedir.tar.gz``. Structure within this + directory are preserved when training on Amazon SageMaker. If 'git_config' + is provided, 'source_dir' should be a relative location to a directory in the + Git repo. If the directory points to S3, no code will be uploaded and the S3 + location will be used instead. .. admonition:: Example @@ -2012,7 +2046,7 @@ def __init__( >>> |----- test.py You can assign entry_point='inference.py', source_dir='src'. - predictor_cls (callable[string, sagemaker.session.Session]): A + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor (default: None). If not None, ``deploy`` will return the result of invoking this function on the created endpoint name. @@ -2139,6 +2173,8 @@ def is_repack(self) -> bool: Returns: bool: if the source need to be repacked or not """ + if self.source_dir is None or self.entry_point is None: + return False return self.source_dir and self.entry_point and not (self.key_prefix or self.git_config) diff --git a/src/sagemaker/model_monitor/clarify_model_monitoring.py b/src/sagemaker/model_monitor/clarify_model_monitoring.py index 3edfabc747..2d9a4a69e4 100644 --- a/src/sagemaker/model_monitor/clarify_model_monitoring.py +++ b/src/sagemaker/model_monitor/clarify_model_monitoring.py @@ -86,11 +86,9 @@ def __init__( object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. """ - if type(self) == __class__: # pylint: disable=unidiomatic-typecheck + if self.__class__ is __class__: raise TypeError( - "{} is abstract, please instantiate its subclasses instead.".format( - __class__.__name__ - ) + f"{__class__.__name__} is abstract, please instantiate its subclasses instead." ) session = sagemaker_session or Session() diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index 436377fea5..3bc29a1cf4 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -2413,7 +2413,12 @@ def _update_data_quality_monitoring_schedule( ) self.sagemaker_session.sagemaker_client.create_data_quality_job_definition(**request_dict) try: - self._update_monitoring_schedule(new_job_definition_name, schedule_cron_expression) + self._update_monitoring_schedule( + job_definition_name=new_job_definition_name, + schedule_cron_expression=schedule_cron_expression, + data_analysis_start_time=data_analysis_start_time, + data_analysis_end_time=data_analysis_end_time, + ) self.job_definition_name = new_job_definition_name if role is not None: self.role = role diff --git a/src/sagemaker/modules/configs.py b/src/sagemaker/modules/configs.py index ec0df519f5..458c596a36 100644 --- a/src/sagemaker/modules/configs.py +++ b/src/sagemaker/modules/configs.py @@ -22,7 +22,7 @@ from __future__ import absolute_import from typing import Optional, Union -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, model_validator, ConfigDict import sagemaker_core.shapes as shapes @@ -74,7 +74,13 @@ ] -class SourceCode(BaseModel): +class BaseConfig(BaseModel): + """BaseConfig""" + + model_config = ConfigDict(validate_assignment=True, extra="forbid") + + +class SourceCode(BaseConfig): """SourceCode. The SourceCode class allows the user to specify the source code location, dependencies, @@ -194,7 +200,7 @@ def _to_vpc_config(self) -> shapes.VpcConfig: return shapes.VpcConfig(**filtered_dict) -class InputData(BaseModel): +class InputData(BaseConfig): """InputData. This config allows the user to specify an input data source for the training job. diff --git a/src/sagemaker/modules/distributed.py b/src/sagemaker/modules/distributed.py index 6cdc136dcf..f248b9b77c 100644 --- a/src/sagemaker/modules/distributed.py +++ b/src/sagemaker/modules/distributed.py @@ -13,12 +13,16 @@ """Distributed module.""" from __future__ import absolute_import +import os + +from abc import ABC, abstractmethod from typing import Optional, Dict, Any, List -from pydantic import BaseModel, PrivateAttr from sagemaker.modules.utils import safe_serialize +from sagemaker.modules.constants import SM_DRIVERS_LOCAL_PATH +from sagemaker.modules.configs import BaseConfig -class SMP(BaseModel): +class SMP(BaseConfig): """SMP. This class is used for configuring the SageMaker Model Parallelism v2 parameters. @@ -72,16 +76,37 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]: return hyperparameters -class DistributedConfig(BaseModel): - """Base class for distributed training configurations.""" +class DistributedConfig(BaseConfig, ABC): + """Abstract base class for distributed training configurations. + + This class defines the interface that all distributed training configurations + must implement. It provides a standardized way to specify driver scripts and + their locations for distributed training jobs. + """ + + @property + @abstractmethod + def driver_dir(self) -> str: + """Directory containing the driver script. + + This property should return the path to the directory containing + the driver script, relative to the container's working directory. - _type: str = PrivateAttr() + Returns: + str: Path to directory containing the driver script + """ - def model_dump(self, *args, **kwargs): - """Dump the model to a dictionary.""" - result = super().model_dump(*args, **kwargs) - result["_type"] = self._type - return result + @property + @abstractmethod + def driver_script(self) -> str: + """Name of the driver script. + + This property should return the name of the Python script that implements + the distributed training driver logic. + + Returns: + str: Name of the driver script file + """ class Torchrun(DistributedConfig): @@ -98,11 +123,27 @@ class Torchrun(DistributedConfig): The SageMaker Model Parallelism v2 parameters. """ - _type: str = PrivateAttr(default="torchrun") - process_count_per_node: Optional[int] = None smp: Optional["SMP"] = None + @property + def driver_dir(self) -> str: + """Directory containing the driver script. + + Returns: + str: Path to directory containing the driver script + """ + return os.path.join(SM_DRIVERS_LOCAL_PATH, "distributed_drivers") + + @property + def driver_script(self) -> str: + """Name of the driver script. + + Returns: + str: Name of the driver script file + """ + return "torchrun_driver.py" + class MPI(DistributedConfig): """MPI. @@ -118,7 +159,23 @@ class MPI(DistributedConfig): The custom MPI options to use for the training job. """ - _type: str = PrivateAttr(default="mpi") - process_count_per_node: Optional[int] = None mpi_additional_options: Optional[List[str]] = None + + @property + def driver_dir(self) -> str: + """Directory containing the driver script. + + Returns: + str: Path to directory containing the driver script + """ + return os.path.join(SM_DRIVERS_LOCAL_PATH, "distributed_drivers") + + @property + def driver_script(self) -> str: + """Name of the driver script. + + Returns: + str: Name of the driver script + """ + return "mpi_driver.py" diff --git a/src/sagemaker/modules/local_core/local_container.py b/src/sagemaker/modules/local_core/local_container.py index 5424f4f865..448330092d 100644 --- a/src/sagemaker/modules/local_core/local_container.py +++ b/src/sagemaker/modules/local_core/local_container.py @@ -108,6 +108,8 @@ class _LocalContainer(BaseModel): container_entrypoint: Optional[List[str]] container_arguments: Optional[List[str]] + _temporary_folders: List[str] = [] + def model_post_init(self, __context: Any): """Post init method to perform custom validation and set default values.""" self.hosts = [f"algo-{i}" for i in range(1, self.instance_count + 1)] @@ -201,6 +203,13 @@ def train( # Print our Job Complete line logger.info("Local training job completed, output artifacts saved to %s", artifacts) + + shutil.rmtree(os.path.join(self.container_root, "input")) + shutil.rmtree(os.path.join(self.container_root, "shared")) + for host in self.hosts: + shutil.rmtree(os.path.join(self.container_root, host)) + for folder in self._temporary_folders: + shutil.rmtree(os.path.join(self.container_root, folder)) return artifacts def retrieve_artifacts( @@ -540,6 +549,7 @@ def _get_data_source_local_path(self, data_source: DataSource): uri = data_source.s3_data_source.s3_uri parsed_uri = urlparse(uri) local_dir = TemporaryDirectory(prefix=os.path.join(self.container_root + "/")).name + self._temporary_folders.append(local_dir) download_folder(parsed_uri.netloc, parsed_uri.path, local_dir, self.sagemaker_session) return local_dir else: diff --git a/src/sagemaker/modules/templates.py b/src/sagemaker/modules/templates.py index fba60dda47..d888b7bcb9 100644 --- a/src/sagemaker/modules/templates.py +++ b/src/sagemaker/modules/templates.py @@ -21,17 +21,12 @@ EXECUTE_BASIC_SCRIPT_DRIVER = """ echo "Running Basic Script driver" -$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/basic_script_driver.py +$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/basic_script_driver.py """ -EXEUCTE_TORCHRUN_DRIVER = """ -echo "Running Torchrun driver" -$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/torchrun_driver.py -""" - -EXECUTE_MPI_DRIVER = """ -echo "Running MPI driver" -$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/mpi_driver.py +EXEUCTE_DISTRIBUTED_DRIVER = """ +echo "Running {driver_name} Driver" +$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/{driver_script} """ TRAIN_SCRIPT_TEMPLATE = """ diff --git a/src/sagemaker/modules/train/container_drivers/__init__.py b/src/sagemaker/modules/train/container_drivers/__init__.py index 18557a2eb5..864f3663b8 100644 --- a/src/sagemaker/modules/train/container_drivers/__init__.py +++ b/src/sagemaker/modules/train/container_drivers/__init__.py @@ -10,5 +10,5 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Sagemaker modules container_drivers directory.""" +"""Sagemaker modules container drivers directory.""" from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/common/__init__.py b/src/sagemaker/modules/train/container_drivers/common/__init__.py new file mode 100644 index 0000000000..aab88c6b97 --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/common/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Sagemaker modules container drivers - common directory.""" +from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/utils.py b/src/sagemaker/modules/train/container_drivers/common/utils.py similarity index 98% rename from src/sagemaker/modules/train/container_drivers/utils.py rename to src/sagemaker/modules/train/container_drivers/common/utils.py index e939a6e0b8..c07aa1359a 100644 --- a/src/sagemaker/modules/train/container_drivers/utils.py +++ b/src/sagemaker/modules/train/container_drivers/common/utils.py @@ -99,10 +99,10 @@ def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAME return hyperparameters_dict -def get_process_count(distributed_dict: Dict[str, Any]) -> int: +def get_process_count(process_count: Optional[int] = None) -> int: """Get the number of processes to run on each node in the training job.""" return ( - int(distributed_dict.get("process_count_per_node", 0)) + process_count or int(os.environ.get("SM_NUM_GPUS", 0)) or int(os.environ.get("SM_NUM_NEURONS", 0)) or 1 diff --git a/src/sagemaker/modules/train/container_drivers/distributed_drivers/__init__.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/__init__.py new file mode 100644 index 0000000000..a44e7e81a9 --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/distributed_drivers/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Sagemaker modules container drivers - drivers directory.""" +from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/basic_script_driver.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/basic_script_driver.py similarity index 88% rename from src/sagemaker/modules/train/container_drivers/basic_script_driver.py rename to src/sagemaker/modules/train/container_drivers/distributed_drivers/basic_script_driver.py index cb0278bc9f..0b086a8e4f 100644 --- a/src/sagemaker/modules/train/container_drivers/basic_script_driver.py +++ b/src/sagemaker/modules/train/container_drivers/distributed_drivers/basic_script_driver.py @@ -13,16 +13,19 @@ """This module is the entry point for the Basic Script Driver.""" from __future__ import absolute_import +import os import sys +import json import shlex +from pathlib import Path from typing import List -from utils import ( +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 logger, get_python_executable, - read_source_code_json, - read_hyperparameters_json, execute_commands, write_failure_file, hyperparameters_to_cli_args, @@ -31,11 +34,10 @@ def create_commands() -> List[str]: """Create the commands to execute.""" - source_code = read_source_code_json() - hyperparameters = read_hyperparameters_json() + entry_script = os.environ["SM_ENTRY_SCRIPT"] + hyperparameters = json.loads(os.environ["SM_HPS"]) python_executable = get_python_executable() - entry_script = source_code["entry_script"] args = hyperparameters_to_cli_args(hyperparameters) if entry_script.endswith(".py"): commands = [python_executable, entry_script] diff --git a/src/sagemaker/modules/train/container_drivers/mpi_driver.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_driver.py similarity index 83% rename from src/sagemaker/modules/train/container_drivers/mpi_driver.py rename to src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_driver.py index dceb748cc0..9946272617 100644 --- a/src/sagemaker/modules/train/container_drivers/mpi_driver.py +++ b/src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_driver.py @@ -16,18 +16,8 @@ import os import sys import json +from pathlib import Path -from utils import ( - logger, - read_source_code_json, - read_distributed_json, - read_hyperparameters_json, - hyperparameters_to_cli_args, - get_process_count, - execute_commands, - write_failure_file, - USER_CODE_PATH, -) from mpi_utils import ( start_sshd_daemon, bootstrap_master_node, @@ -38,6 +28,16 @@ ) +sys.path.insert(0, str(Path(__file__).parent.parent)) +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 + logger, + hyperparameters_to_cli_args, + get_process_count, + execute_commands, + write_failure_file, +) + + def main(): """Main function for the MPI driver script. @@ -58,9 +58,9 @@ def main(): 5. Exit """ - source_code = read_source_code_json() - distribution = read_distributed_json() - hyperparameters = read_hyperparameters_json() + entry_script = os.environ["SM_ENTRY_SCRIPT"] + distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) + hyperparameters = json.loads(os.environ["SM_HPS"]) sm_current_host = os.environ["SM_CURRENT_HOST"] sm_hosts = json.loads(os.environ["SM_HOSTS"]) @@ -77,7 +77,8 @@ def main(): host_list = json.loads(os.environ["SM_HOSTS"]) host_count = int(os.environ["SM_HOST_COUNT"]) - process_count = get_process_count(distribution) + process_count = int(distributed_config["process_count_per_node"] or 0) + process_count = get_process_count(process_count) if process_count > 1: host_list = ["{}:{}".format(host, process_count) for host in host_list] @@ -86,8 +87,8 @@ def main(): host_count=host_count, host_list=host_list, num_processes=process_count, - additional_options=distribution.get("mpi_additional_options", []), - entry_script_path=os.path.join(USER_CODE_PATH, source_code["entry_script"]), + additional_options=distributed_config["mpi_additional_options"] or [], + entry_script_path=entry_script, ) args = hyperparameters_to_cli_args(hyperparameters) diff --git a/src/sagemaker/modules/train/container_drivers/mpi_utils.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_utils.py similarity index 81% rename from src/sagemaker/modules/train/container_drivers/mpi_utils.py rename to src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_utils.py index c3c2b7effe..ec9e1fcef9 100644 --- a/src/sagemaker/modules/train/container_drivers/mpi_utils.py +++ b/src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_utils.py @@ -14,12 +14,23 @@ from __future__ import absolute_import import os -import time +import sys import subprocess +import time +from pathlib import Path from typing import List -from utils import logger, SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable +import paramiko + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 + SM_EFA_NCCL_INSTANCES, + SM_EFA_RDMA_INSTANCES, + get_python_executable, + logger, +) FINISHED_STATUS_FILE = "/tmp/done.algo-1" READY_FILE = "/tmp/ready.%s" @@ -75,19 +86,45 @@ def start_sshd_daemon(): logger.info("Started SSH daemon.") +class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy): + """Class to handle host key policy for SageMaker distributed training SSH connections. + + Example: + >>> client = paramiko.SSHClient() + >>> client.set_missing_host_key_policy(CustomHostKeyPolicy()) + >>> # Will succeed for SageMaker algorithm containers + >>> client.connect('algo-1234.internal') + >>> # Will raise SSHException for other unknown hosts + >>> client.connect('unknown-host') # raises SSHException + """ + + def missing_host_key(self, client, hostname, key): + """Accept host keys for algo-* hostnames, reject others. + + Args: + client: The SSHClient instance + hostname: The hostname attempting to connect + key: The host key + + Raises: + paramiko.SSHException: If hostname doesn't match algo-* pattern + """ + if hostname.startswith("algo-"): + client.get_host_keys().add(hostname, key.get_name(), key) + return + raise paramiko.SSHException(f"Unknown host key for {hostname}") + + def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool: """Check if the connection to the provided host and port is possible.""" try: - import paramiko - logger.debug("Testing connection to host %s", host) - client = paramiko.SSHClient() - client.load_system_host_keys() - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - client.connect(host, port=port) - client.close() - logger.info("Can connect to host %s", host) - return True + with paramiko.SSHClient() as client: + client.load_system_host_keys() + client.set_missing_host_key_policy(CustomHostKeyPolicy()) + client.connect(host, port=port) + logger.info("Can connect to host %s", host) + return True except Exception as e: # pylint: disable=W0703 logger.info("Cannot connect to host %s", host) logger.debug(f"Connection failed with exception: {e}") @@ -183,9 +220,9 @@ def validate_smddpmprun() -> bool: def write_env_vars_to_file(): """Write environment variables to /etc/environment file.""" - with open("/etc/environment", "a") as f: + with open("/etc/environment", "a", encoding="utf-8") as f: for name in os.environ: - f.write("{}={}\n".format(name, os.environ.get(name))) + f.write(f"{name}={os.environ.get(name)}\n") def get_mpirun_command( diff --git a/src/sagemaker/modules/train/container_drivers/torchrun_driver.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/torchrun_driver.py similarity index 87% rename from src/sagemaker/modules/train/container_drivers/torchrun_driver.py rename to src/sagemaker/modules/train/container_drivers/distributed_drivers/torchrun_driver.py index 666479ec84..7fcfabe05d 100644 --- a/src/sagemaker/modules/train/container_drivers/torchrun_driver.py +++ b/src/sagemaker/modules/train/container_drivers/distributed_drivers/torchrun_driver.py @@ -15,20 +15,20 @@ import os import sys +import json +from pathlib import Path from typing import List, Tuple -from utils import ( +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 logger, - read_source_code_json, - read_distributed_json, - read_hyperparameters_json, hyperparameters_to_cli_args, get_process_count, get_python_executable, execute_commands, write_failure_file, - USER_CODE_PATH, SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, ) @@ -65,11 +65,12 @@ def setup_env(): def create_commands(): """Create the Torch Distributed command to execute""" - source_code = read_source_code_json() - distribution = read_distributed_json() - hyperparameters = read_hyperparameters_json() + entry_script = os.environ["SM_ENTRY_SCRIPT"] + distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) + hyperparameters = json.loads(os.environ["SM_HPS"]) - process_count = get_process_count(distribution) + process_count = int(distributed_config["process_count_per_node"] or 0) + process_count = get_process_count(process_count) host_count = int(os.environ["SM_HOST_COUNT"]) torch_cmd = [] @@ -94,7 +95,7 @@ def create_commands(): ] ) - torch_cmd.extend([os.path.join(USER_CODE_PATH, source_code["entry_script"])]) + torch_cmd.extend([entry_script]) args = hyperparameters_to_cli_args(hyperparameters) torch_cmd += args diff --git a/src/sagemaker/modules/train/container_drivers/scripts/__init__.py b/src/sagemaker/modules/train/container_drivers/scripts/__init__.py index 1abbce4067..f04c5b17a0 100644 --- a/src/sagemaker/modules/train/container_drivers/scripts/__init__.py +++ b/src/sagemaker/modules/train/container_drivers/scripts/__init__.py @@ -10,5 +10,5 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Sagemaker modules scripts directory.""" +"""Sagemaker modules container drivers - scripts directory.""" from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/scripts/environment.py b/src/sagemaker/modules/train/container_drivers/scripts/environment.py index ea6abac425..897b1f8af4 100644 --- a/src/sagemaker/modules/train/container_drivers/scripts/environment.py +++ b/src/sagemaker/modules/train/container_drivers/scripts/environment.py @@ -19,12 +19,17 @@ import json import os import sys +from pathlib import Path import logging -parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -sys.path.insert(0, parent_dir) +sys.path.insert(0, str(Path(__file__).parent.parent)) -from utils import safe_serialize, safe_deserialize # noqa: E402 # pylint: disable=C0413 +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 + safe_serialize, + safe_deserialize, + read_distributed_json, + read_source_code_json, +) # Initialize logger SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20) @@ -42,6 +47,8 @@ SM_OUTPUT_DIR = "/opt/ml/output" SM_OUTPUT_FAILURE = "/opt/ml/output/failure" SM_OUTPUT_DATA_DIR = "/opt/ml/output/data" +SM_SOURCE_DIR_PATH = "/opt/ml/input/data/code" +SM_DISTRIBUTED_DRIVER_DIR_PATH = "/opt/ml/input/data/sm_drivers/distributed_drivers" SM_MASTER_ADDR = "algo-1" SM_MASTER_PORT = 7777 @@ -158,6 +165,17 @@ def set_env( "SM_MASTER_PORT": SM_MASTER_PORT, } + # SourceCode and DistributedConfig Environment Variables + source_code = read_source_code_json() + if source_code: + env_vars["SM_SOURCE_DIR"] = SM_SOURCE_DIR_PATH + env_vars["SM_ENTRY_SCRIPT"] = source_code.get("entry_script", "") + + distributed = read_distributed_json() + if distributed: + env_vars["SM_DISTRIBUTED_DRIVER_DIR"] = SM_DISTRIBUTED_DRIVER_DIR_PATH + env_vars["SM_DISTRIBUTED_CONFIG"] = distributed + # Data Channels channels = list(input_data_config.keys()) for channel in channels: diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 31decfaca9..aef6e3312b 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -18,8 +18,8 @@ import json import shutil from tempfile import TemporaryDirectory - from typing import Optional, List, Union, Dict, Any, ClassVar +import yaml from graphene.utils.str_converters import to_camel_case, to_snake_case @@ -70,7 +70,7 @@ ) from sagemaker.modules.local_core.local_container import _LocalContainer -from sagemaker.modules.distributed import Torchrun, MPI, DistributedConfig +from sagemaker.modules.distributed import Torchrun, DistributedConfig from sagemaker.modules.utils import ( _get_repo_name_from_image, _get_unique_name, @@ -94,8 +94,7 @@ from sagemaker.modules.templates import ( TRAIN_SCRIPT_TEMPLATE, EXECUTE_BASE_COMMANDS, - EXECUTE_MPI_DRIVER, - EXEUCTE_TORCHRUN_DRIVER, + EXEUCTE_DISTRIBUTED_DRIVER, EXECUTE_BASIC_SCRIPT_DRIVER, ) from sagemaker.telemetry.telemetry_logging import _telemetry_emitter @@ -153,7 +152,7 @@ class ModelTrainer(BaseModel): source_code (Optional[SourceCode]): The source code configuration. This is used to configure the source code for running the training job. - distributed (Optional[Union[MPI, Torchrun]]): + distributed (Optional[DistributedConfig]): The distributed runner for the training job. This is used to configure a distributed training job. If specifed, ``source_code`` must also be provided. @@ -195,8 +194,9 @@ class ModelTrainer(BaseModel): Defaults to "File". environment (Optional[Dict[str, str]]): The environment variables for the training job. - hyperparameters (Optional[Dict[str, Any]]): - The hyperparameters for the training job. + hyperparameters (Optional[Union[Dict[str, Any], str]): + The hyperparameters for the training job. Can be a dictionary of hyperparameters + or a path to hyperparameters json/yaml file. tags (Optional[List[Tag]]): An array of key-value pairs. You can use tags to categorize your AWS resources in different ways, for example, by purpose, owner, or environment. @@ -205,14 +205,16 @@ class ModelTrainer(BaseModel): "LOCAL_CONTAINER" mode. """ - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + model_config = ConfigDict( + arbitrary_types_allowed=True, validate_assignment=True, extra="forbid" + ) training_mode: Mode = Mode.SAGEMAKER_TRAINING_JOB sagemaker_session: Optional[Session] = None role: Optional[str] = None base_job_name: Optional[str] = None source_code: Optional[SourceCode] = None - distributed: Optional[Union[MPI, Torchrun]] = None + distributed: Optional[DistributedConfig] = None compute: Optional[Compute] = None networking: Optional[Networking] = None stopping_condition: Optional[StoppingCondition] = None @@ -224,7 +226,7 @@ class ModelTrainer(BaseModel): checkpoint_config: Optional[CheckpointConfig] = None training_input_mode: Optional[str] = "File" environment: Optional[Dict[str, str]] = {} - hyperparameters: Optional[Dict[str, Any]] = {} + hyperparameters: Optional[Union[Dict[str, Any], str]] = {} tags: Optional[List[Tag]] = None local_container_root: Optional[str] = os.getcwd() @@ -363,9 +365,10 @@ def _populate_intelligent_defaults_from_model_trainer_space(self): def __del__(self): """Destructor method to clean up the temporary directory.""" - # Clean up the temporary directory if it exists - if self._temp_recipe_train_dir is not None: - self._temp_recipe_train_dir.cleanup() + # Clean up the temporary directory if it exists and class was initialized + if hasattr(self, "__pydantic_fields_set__"): + if self._temp_recipe_train_dir is not None: + self._temp_recipe_train_dir.cleanup() def _validate_training_image_and_algorithm_name( self, training_image: Optional[str], algorithm_name: Optional[str] @@ -467,6 +470,29 @@ def model_post_init(self, __context: Any): f"StoppingCondition not provided. Using default:\n{self.stopping_condition}" ) + if self.hyperparameters and isinstance(self.hyperparameters, str): + if not os.path.exists(self.hyperparameters): + raise ValueError(f"Hyperparameters file not found: {self.hyperparameters}") + logger.info(f"Loading hyperparameters from file: {self.hyperparameters}") + with open(self.hyperparameters, "r") as f: + contents = f.read() + try: + self.hyperparameters = json.loads(contents) + logger.debug("Hyperparameters loaded as JSON") + except json.JSONDecodeError: + try: + logger.info(f"contents: {contents}") + self.hyperparameters = yaml.safe_load(contents) + if not isinstance(self.hyperparameters, dict): + raise ValueError("YAML contents must be a valid mapping") + logger.info(f"hyperparameters: {self.hyperparameters}") + logger.debug("Hyperparameters loaded as YAML") + except (yaml.YAMLError, ValueError): + raise ValueError( + f"Invalid hyperparameters file: {self.hyperparameters}. " + "Must be a valid JSON or YAML file." + ) + if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None: session = self.sagemaker_session base_job_name = self.base_job_name @@ -534,12 +560,17 @@ def train( container_arguments = None if self.source_code: if self.training_mode == Mode.LOCAL_CONTAINER: - drivers_dir = TemporaryDirectory( - prefix=os.path.join(self.local_container_root + "/") - ) + tmp_dir = TemporaryDirectory(prefix=os.path.join(self.local_container_root + "/")) else: - drivers_dir = TemporaryDirectory() - shutil.copytree(SM_DRIVERS_LOCAL_PATH, drivers_dir.name, dirs_exist_ok=True) + tmp_dir = TemporaryDirectory() + # Copy everything under container_drivers/ to a temporary directory + shutil.copytree(SM_DRIVERS_LOCAL_PATH, tmp_dir.name, dirs_exist_ok=True) + + # If distributed is provided, overwrite code under /drivers + if self.distributed: + distributed_driver_dir = self.distributed.driver_dir + driver_dir = os.path.join(tmp_dir.name, "distributed_drivers") + shutil.copytree(distributed_driver_dir, driver_dir, dirs_exist_ok=True) # If source code is provided, create a channel for the source code # The source code will be mounted at /opt/ml/input/data/code in the container @@ -552,7 +583,7 @@ def train( input_data_config.append(source_code_channel) self._prepare_train_script( - tmp_dir=drivers_dir, + tmp_dir=tmp_dir, source_code=self.source_code, distributed=self.distributed, ) @@ -561,13 +592,13 @@ def train( mp_parameters = self.distributed.smp._to_mp_hyperparameters() string_hyper_parameters.update(mp_parameters) - self._write_source_code_json(tmp_dir=drivers_dir, source_code=self.source_code) - self._write_distributed_json(tmp_dir=drivers_dir, distributed=self.distributed) + self._write_source_code_json(tmp_dir=tmp_dir, source_code=self.source_code) + self._write_distributed_json(tmp_dir=tmp_dir, distributed=self.distributed) # Create an input channel for drivers packaged by the sdk sm_drivers_channel = self.create_input_data_channel( channel_name=SM_DRIVERS, - data_source=drivers_dir.name, + data_source=tmp_dir.name, key_prefix=input_data_key_prefix, ) input_data_config.append(sm_drivers_channel) @@ -769,7 +800,7 @@ def _write_source_code_json(self, tmp_dir: TemporaryDirectory, source_code: Sour """Write the source code configuration to a JSON file.""" file_path = os.path.join(tmp_dir.name, SOURCE_CODE_JSON) with open(file_path, "w") as f: - dump = source_code.model_dump(exclude_none=True) if source_code else {} + dump = source_code.model_dump() if source_code else {} f.write(json.dumps(dump)) def _write_distributed_json( @@ -780,7 +811,7 @@ def _write_distributed_json( """Write the distributed runner configuration to a JSON file.""" file_path = os.path.join(tmp_dir.name, DISTRIBUTED_JSON) with open(file_path, "w") as f: - dump = distributed.model_dump(exclude_none=True) if distributed else {} + dump = distributed.model_dump() if distributed else {} f.write(json.dumps(dump)) def _prepare_train_script( @@ -792,14 +823,14 @@ def _prepare_train_script( """Prepare the training script to be executed in the training job container. Args: - source_code (SourceCodeConfig): The source code configuration. + source_code (SourceCode): The source code configuration. """ base_command = "" if source_code.command: if source_code.entry_script: logger.warning( - "Both 'command' and 'entry_script' are provided in the SourceCodeConfig. " + "Both 'command' and 'entry_script' are provided in the SourceCode. " + "Defaulting to 'command'." ) base_command = source_code.command.split() @@ -817,13 +848,10 @@ def _prepare_train_script( if base_command: execute_driver = EXECUTE_BASE_COMMANDS.format(base_command=base_command) elif distributed: - distribution_type = distributed._type - if distribution_type == "mpi": - execute_driver = EXECUTE_MPI_DRIVER - elif distribution_type == "torchrun": - execute_driver = EXEUCTE_TORCHRUN_DRIVER - else: - raise ValueError(f"Unsupported distribution type: {distribution_type}.") + execute_driver = EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name=distributed.__class__.__name__, + driver_script=distributed.driver_script, + ) elif source_code.entry_script and not source_code.command and not distributed: if not source_code.entry_script.endswith((".py", ".sh")): raise ValueError( @@ -831,6 +859,13 @@ def _prepare_train_script( + "Only .py and .sh scripts are supported." ) execute_driver = EXECUTE_BASIC_SCRIPT_DRIVER + else: + # This should never be reached, as the source_code should have been validated. + raise ValueError( + f"Unsupported SourceCode or DistributedConfig: {source_code}, {distributed}." + + "Please provide a valid configuration with atleast one of 'command'" + + " or entry_script'." + ) train_script = TRAIN_SCRIPT_TEMPLATE.format( working_dir=working_dir, diff --git a/src/sagemaker/modules/train/sm_recipes/utils.py b/src/sagemaker/modules/train/sm_recipes/utils.py index ff38bcbde8..549645cbe2 100644 --- a/src/sagemaker/modules/train/sm_recipes/utils.py +++ b/src/sagemaker/modules/train/sm_recipes/utils.py @@ -125,6 +125,27 @@ def _register_custom_resolvers(): OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers)) +def _get_trainining_recipe_gpu_model_name_and_script(model_type: str): + """Get the model base name and script for the training recipe.""" + + model_type_to_script = { + "llama_v3": ("llama", "llama_pretrain.py"), + "mistral": ("mistral", "mistral_pretrain.py"), + "mixtral": ("mixtral", "mixtral_pretrain.py"), + "deepseek": ("deepseek", "deepseek_pretrain.py"), + } + + for key in model_type_to_script: + if model_type.startswith(key): + model_type = key + break + + if model_type not in model_type_to_script: + raise ValueError(f"Model type {model_type} not supported") + + return model_type_to_script[model_type][0], model_type_to_script[model_type][1] + + def _configure_gpu_args( training_recipes_cfg: Dict[str, Any], region_name: str, @@ -140,24 +161,16 @@ def _configure_gpu_args( ) _run_clone_command_silent(adapter_repo, recipe_train_dir.name) - model_type_to_entry = { - "llama_v3": ("llama", "llama_pretrain.py"), - "mistral": ("mistral", "mistral_pretrain.py"), - "mixtral": ("mixtral", "mixtral_pretrain.py"), - } - if "model" not in recipe: raise ValueError("Supplied recipe does not contain required field model.") if "model_type" not in recipe["model"]: raise ValueError("Supplied recipe does not contain required field model_type.") model_type = recipe["model"]["model_type"] - if model_type not in model_type_to_entry: - raise ValueError(f"Model type {model_type} not supported") - source_code.source_dir = os.path.join( - recipe_train_dir.name, "examples", model_type_to_entry[model_type][0] - ) - source_code.entry_script = model_type_to_entry[model_type][1] + model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(model_type) + + source_code.source_dir = os.path.join(recipe_train_dir.name, "examples", model_base_name) + source_code.entry_script = script gpu_image_cfg = training_recipes_cfg.get("gpu_image") if isinstance(gpu_image_cfg, str): diff --git a/src/sagemaker/multidatamodel.py b/src/sagemaker/multidatamodel.py index 9ed348c927..43a3588e6f 100644 --- a/src/sagemaker/multidatamodel.py +++ b/src/sagemaker/multidatamodel.py @@ -223,7 +223,7 @@ def deploy( Amazon SageMaker Model Monitoring. Default: None. Returns: - callable[string, sagemaker.session.Session] or None: Invocation of + Optional[Callable[[string, sagemaker.session.Session], Any]]: Invocation of ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` is not None. Otherwise, return None. diff --git a/src/sagemaker/mxnet/estimator.py b/src/sagemaker/mxnet/estimator.py index 104b93e00a..5126a37a85 100644 --- a/src/sagemaker/mxnet/estimator.py +++ b/src/sagemaker/mxnet/estimator.py @@ -84,8 +84,8 @@ def __init__( source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + are preserved when training on Amazon SageMaker. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 0dcd71741d..fa0c691d2d 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Union, Optional, List, Dict +from typing import Callable, Union, Optional, List, Dict import packaging.version @@ -68,9 +68,9 @@ def __init__( manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. - serializer (callable): Optional. Default serializes input data to + serializer (Callable): Optional. Default serializes input data to json. Handles dicts, lists, and numpy arrays. - deserializer (callable): Optional. Default parses the response using + deserializer (Callable): Optional. Default parses the response using ``json.load(...)``. component_name (str): Optional. Name of the Amazon SageMaker inference component corresponding to the predictor. @@ -98,7 +98,7 @@ def __init__( framework_version: str = _LOWEST_MMS_VERSION, py_version: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, - predictor_cls: callable = MXNetPredictor, + predictor_cls: Optional[Callable] = MXNetPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs, ): @@ -127,7 +127,7 @@ def __init__( If ``framework_version`` or ``py_version`` are ``None``, then ``image_uri`` is required. If ``image_uri`` is also ``None``, then a ``ValueError`` will be raised. - predictor_cls (callable[str, sagemaker.session.Session]): A function + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 04fbc1cc93..b36cd4e917 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -13,10 +13,12 @@ """Placeholder docstring""" from __future__ import absolute_import -from typing import Optional, Dict, List, Union +from typing import Callable, Optional, Dict, List, Union import sagemaker from sagemaker import ModelMetrics, Model +from sagemaker import local +from sagemaker import session from sagemaker.config import ( ENDPOINT_CONFIG_KMS_KEY_ID_PATH, MODEL_VPC_CONFIG_PATH, @@ -54,7 +56,7 @@ def __init__( self, models: List[Model], role: str = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, sagemaker_session: Optional[Session] = None, @@ -75,7 +77,7 @@ def __init__( endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource. - predictor_cls (callable[string, sagemaker.session.Session]): A + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor (default: None). If not None, ``deploy`` will return the result of invoking this function on the created endpoint name. @@ -230,7 +232,7 @@ def deploy( https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests Returns: - callable[string, sagemaker.session.Session] or None: Invocation of + Optional[Callable[[string, sagemaker.session.Session], Any]]: Invocation of ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` is not None. Otherwise, return None. """ @@ -560,3 +562,16 @@ def delete_model(self): raise ValueError("The SageMaker model must be created before attempting to delete.") self.sagemaker_session.delete_model(self.name) + + def _init_sagemaker_session_if_does_not_exist(self, instance_type=None): + """Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already. + + The type of session object is determined by the instance type. + """ + if self.sagemaker_session: + return + + if instance_type in ("local", "local_gpu"): + self.sagemaker_session = local.LocalSession(sagemaker_config=self._sagemaker_config) + else: + self.sagemaker_session = session.Session(sagemaker_config=self._sagemaker_config) diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 36cb920dde..103be47caf 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -17,52 +17,51 @@ and interpretation on Amazon SageMaker. """ from __future__ import absolute_import - +import logging import os import pathlib -import logging +import re +from copy import copy from textwrap import dedent from typing import Dict, List, Optional, Union -from copy import copy -import re import attr - from six.moves.urllib.parse import urlparse from six.moves.urllib.request import url2pathname + from sagemaker import s3 +from sagemaker.apiutils._base_types import ApiObject from sagemaker.config import ( + PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + PROCESSING_JOB_ENVIRONMENT_PATH, + PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, PROCESSING_JOB_KMS_KEY_ID_PATH, + PROCESSING_JOB_ROLE_ARN_PATH, PROCESSING_JOB_SECURITY_GROUP_IDS_PATH, PROCESSING_JOB_SUBNETS_PATH, - PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH, PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH, - PROCESSING_JOB_ROLE_ARN_PATH, - PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, - PROCESSING_JOB_ENVIRONMENT_PATH, ) +from sagemaker.dataset_definition.inputs import DatasetDefinition, S3Input from sagemaker.job import _Job from sagemaker.local import LocalSession from sagemaker.network import NetworkConfig +from sagemaker.s3 import S3Uploader +from sagemaker.session import Session from sagemaker.utils import ( + Tags, base_name_from_image, + check_and_get_run_experiment_config, + format_tags, get_config_value, name_from_base, - check_and_get_run_experiment_config, - resolve_value_from_config, resolve_class_attribute_from_config, - Tags, - format_tags, + resolve_value_from_config, ) -from sagemaker.session import Session from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.functions import Join from sagemaker.workflow.pipeline_context import runnable_by_pipeline -from sagemaker.workflow.execution_variables import ExecutionVariables -from sagemaker.workflow.entities import PipelineVariable -from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition -from sagemaker.apiutils._base_types import ApiObject -from sagemaker.s3 import S3Uploader logger = logging.getLogger(__name__) @@ -1416,7 +1415,7 @@ class RunArgs(object): class FeatureStoreOutput(ApiObject): """Configuration for processing job outputs in Amazon SageMaker Feature Store.""" - feature_group_name = None + feature_group_name: Optional[str] = None class FrameworkProcessor(ScriptProcessor): @@ -1465,7 +1464,7 @@ def __init__( instance_type (str or PipelineVariable): The type of EC2 instance to use for processing, for example, 'ml.c4.xlarge'. py_version (str): Python version you want to use for executing your - model training code. One of 'py2' or 'py3'. Defaults to 'py3'. Value + model training code. Ex `py38, py39, py310, py311`. Value is ignored when ``image_uri`` is provided. image_uri (str or PipelineVariable): The URI of the Docker image to use for the processing jobs (default: None). diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 46c57581d1..d56c100546 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -95,6 +95,7 @@ def _get_training_recipe_gpu_script(code_dir, recipe, source_dir): "llama_v3": ("llama", "llama_pretrain.py"), "mistral": ("mistral", "mistral_pretrain.py"), "mixtral": ("mixtral", "mixtral_pretrain.py"), + "deepseek": ("deepseek", "deepseek_pretrain.py"), } if "model" not in recipe: @@ -102,6 +103,12 @@ def _get_training_recipe_gpu_script(code_dir, recipe, source_dir): if "model_type" not in recipe["model"]: raise ValueError("Supplied recipe does not contain required field model_type.") model_type = recipe["model"]["model_type"] + + for key in model_type_to_script: + if model_type.startswith(key): + model_type = key + break + if model_type not in model_type_to_script: raise ValueError(f"Model type {model_type} not supported") @@ -175,8 +182,8 @@ def __init__( unless ``image_uri`` is provided. source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry - point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved + point file (default: None). If ``source_dir`` is an S3 URI, it must point to a + file with name ``sourcedir.tar.gz``. Structure within this directory are preserved when training on Amazon SageMaker. Must be a local path when using training_recipe. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). The hyperparameters are made diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 329f9b83b5..958327ba08 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Optional, Union, List, Dict +from typing import Callable, Optional, Union, List, Dict import packaging.version @@ -99,7 +99,7 @@ def __init__( framework_version: str = "1.3", py_version: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, - predictor_cls: callable = PyTorchPredictor, + predictor_cls: Optional[Callable] = PyTorchPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs, ): @@ -128,7 +128,7 @@ def __init__( If ``framework_version`` or ``py_version`` are ``None``, then ``image_uri`` is required. If ``image_uri`` is also ``None``, then a ``ValueError`` will be raised. - predictor_cls (callable[str, sagemaker.session.Session]): A function + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. diff --git a/src/sagemaker/remote_function/client.py b/src/sagemaker/remote_function/client.py index 73a308ddf5..55b4654aa9 100644 --- a/src/sagemaker/remote_function/client.py +++ b/src/sagemaker/remote_function/client.py @@ -90,8 +90,10 @@ def remote( spark_config: SparkConfig = None, use_spot_instances=False, max_wait_time_in_seconds=None, - use_torchrun=False, - nproc_per_node=1, + disable_output_compression: bool = False, + use_torchrun: bool = False, + use_mpirun: bool = False, + nproc_per_node: Optional[int] = None, ): """Decorator for running the annotated function as a SageMaker training job. @@ -207,7 +209,8 @@ def remote( files are accepted and uploaded to S3. instance_count (int): The number of instances to use. Defaults to 1. - NOTE: Remote function does not support instance_count > 1 for non Spark jobs. + NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and + mpirun utilities instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown. @@ -281,11 +284,18 @@ def remote( After this amount of time Amazon SageMaker will stop waiting for managed spot training job to complete. Defaults to ``None``. + disable_output_compression (bool): Optional. When set to true, Model is uploaded to + Amazon S3 without compression after training finishes. + use_torchrun (bool): Specifies whether to use torchrun for distributed training. Defaults to ``False``. - nproc_per_node (int): Specifies the number of processes per node for distributed training. - Defaults to ``1``. + use_mpirun (bool): Specifies whether to use mpirun for distributed training. + Defaults to ``False``. + + nproc_per_node (int): Optional. Specifies the number of processes per node for + distributed training. Defaults to ``None``. + This is defined automatically configured on the instance type. """ def _remote(func): @@ -318,16 +328,23 @@ def _remote(func): spark_config=spark_config, use_spot_instances=use_spot_instances, max_wait_time_in_seconds=max_wait_time_in_seconds, + disable_output_compression=disable_output_compression, use_torchrun=use_torchrun, + use_mpirun=use_mpirun, nproc_per_node=nproc_per_node, ) @functools.wraps(func) def wrapper(*args, **kwargs): - if instance_count > 1 and not spark_config: + if instance_count > 1 and not ( + (spark_config is not None and not use_torchrun and not use_mpirun) + or (spark_config is None and use_torchrun and not use_mpirun) + or (spark_config is None and not use_torchrun and use_mpirun) + ): raise ValueError( - "Remote function do not support training on multi instances. " + "Remote function do not support training on multi instances " + + "without spark_config or use_torchrun or use_mpirun. " + "Please provide instance_count = 1" ) @@ -531,8 +548,10 @@ def __init__( spark_config: SparkConfig = None, use_spot_instances=False, max_wait_time_in_seconds=None, - use_torchrun=False, - nproc_per_node=1, + disable_output_compression: bool = False, + use_torchrun: bool = False, + use_mpirun: bool = False, + nproc_per_node: Optional[int] = None, ): """Constructor for RemoteExecutor @@ -645,7 +664,8 @@ def __init__( files are accepted and uploaded to S3. instance_count (int): The number of instances to use. Defaults to 1. - NOTE: Remote function does not support instance_count > 1 for non Spark jobs. + NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and + mpirun utilities instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown. @@ -722,20 +742,32 @@ def __init__( After this amount of time Amazon SageMaker will stop waiting for managed spot training job to complete. Defaults to ``None``. + disable_output_compression (bool): Optional. When set to true, Model is uploaded to + Amazon S3 without compression after training finishes. + use_torchrun (bool): Specifies whether to use torchrun for distributed training. Defaults to ``False``. - nproc_per_node (int): Specifies the number of processes per node. - Defaults to ``1``. + use_mpirun (bool): Specifies whether to use mpirun for distributed training. + Defaults to ``False``. + + nproc_per_node (int): Optional. Specifies the number of processes per node for + distributed training. Defaults to ``None``. + This is defined automatically configured on the instance type. """ self.max_parallel_jobs = max_parallel_jobs if self.max_parallel_jobs <= 0: raise ValueError("max_parallel_jobs must be greater than 0.") - if instance_count > 1 and not spark_config: + if instance_count > 1 and not ( + (spark_config is not None and not use_torchrun and not use_mpirun) + or (spark_config is None and use_torchrun and not use_mpirun) + or (spark_config is None and not use_torchrun and use_mpirun) + ): raise ValueError( - "Remote function do not support training on multi instances. " + "Remote function do not support training on multi instances " + + "without spark_config or use_torchrun or use_mpirun. " + "Please provide instance_count = 1" ) @@ -767,7 +799,9 @@ def __init__( spark_config=spark_config, use_spot_instances=use_spot_instances, max_wait_time_in_seconds=max_wait_time_in_seconds, + disable_output_compression=disable_output_compression, use_torchrun=use_torchrun, + use_mpirun=use_mpirun, nproc_per_node=nproc_per_node, ) diff --git a/src/sagemaker/remote_function/core/stored_function.py b/src/sagemaker/remote_function/core/stored_function.py index ade4a9e652..862c67d9ee 100644 --- a/src/sagemaker/remote_function/core/stored_function.py +++ b/src/sagemaker/remote_function/core/stored_function.py @@ -55,8 +55,6 @@ def __init__( hmac_key: str, s3_kms_key: str = None, context: Context = Context(), - use_torchrun: bool = False, - nproc_per_node: int = 1, ): """Construct a StoredFunction object. @@ -67,16 +65,12 @@ def __init__( s3_kms_key: KMS key used to encrypt artifacts uploaded to S3. hmac_key: Key used to encrypt serialized and deserialized function and arguments. context: Build or run context of a pipeline step. - use_torchrun: Whether to use torchrun for distributed training. - nproc_per_node: Number of processes per node for distributed training. """ self.sagemaker_session = sagemaker_session self.s3_base_uri = s3_base_uri self.s3_kms_key = s3_kms_key self.hmac_key = hmac_key self.context = context - self.use_torchrun = use_torchrun - self.nproc_per_node = nproc_per_node self.func_upload_path = s3_path_join( s3_base_uri, context.step_name, context.func_step_s3_dir diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py index 8ab4d420e5..9000ccda08 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -81,6 +81,7 @@ # runtime script names BOOTSTRAP_SCRIPT_NAME = "bootstrap_runtime_environment.py" +MPI_UTILS_SCRIPT_NAME = "mpi_utils_remote.py" ENTRYPOINT_SCRIPT_NAME = "job_driver.sh" PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh" RUNTIME_MANAGER_SCRIPT_NAME = "runtime_environment_manager.py" @@ -130,9 +131,12 @@ export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n" +printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n" +cat /opt/ml/input/config/resourceconfig.json printf "INFO: Bootstraping runtime environment.\\n" python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" +source /opt/ml/input/sm_training.env if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ] then @@ -155,13 +159,108 @@ fi printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n" + printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function \\n" $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function "$@" else printf "INFO: No conda env provided. Invoking remote function\\n" + printf "INFO: python -m sagemaker.remote_function.invoke_function \\n" python -m sagemaker.remote_function.invoke_function "$@" fi """ +ENTRYPOINT_MPIRUN_SCRIPT = f""" +#!/bin/bash + +# Entry point for bootstrapping runtime environment and invoking remote function with mpirun + +set -eu + +PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}} +export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs +printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n" +export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip +printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n" + +printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n" +cat /opt/ml/input/config/resourceconfig.json + +printf "INFO: Bootstraping runtime environment.\\n" +python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" +source /opt/ml/input/sm_training.env + +if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ] +then + if [ -f "remote_function_conda_env.txt" ] + then + cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt + fi + printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n" + cd {JOB_REMOTE_FUNCTION_WORKSPACE} +fi + +if [ -f "remote_function_conda_env.txt" ] +then + conda_env=$(cat remote_function_conda_env.txt) + + if which mamba >/dev/null; then + conda_exe="mamba" + else + conda_exe="conda" + fi + + if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + + printf "INFO: Invoking remote function with mpirun inside conda environment: $conda_env.\\n" + printf "INFO: $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + + python -m mpi4py -m sagemaker.remote_function.invoke_function \\n" + $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ + python -m mpi4py -m sagemaker.remote_function.invoke_function "$@" + + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 + else + printf "INFO: This is the instance $SM_CURRENT_HOST. mpirun command terminated\\n" + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + fi +else + if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + + printf "INFO: No conda env provided. Invoking remote function with mpirun\\n" + printf "INFO: mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ + python -m mpi4py -m sagemaker.remote_function.invoke_function \\n" + + mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ + python -m mpi4py -m sagemaker.remote_function.invoke_function "$@" + + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 + else + printf "INFO: This is the instance $SM_CURRENT_HOST.\\n" + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + fi +fi +""" + ENTRYPOINT_TORCHRUN_SCRIPT = f""" #!/bin/bash @@ -175,9 +274,12 @@ export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n" +printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n" +cat /opt/ml/input/config/resourceconfig.json printf "INFO: Bootstraping runtime environment.\\n" python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" +source /opt/ml/input/sm_training.env if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ] then @@ -200,11 +302,20 @@ fi printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n" - $conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE \ + printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ + --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ + -m sagemaker.remote_function.invoke_function \\n" + + $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ + --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ -m sagemaker.remote_function.invoke_function "$@" else printf "INFO: No conda env provided. Invoking remote function with torchrun\\n" - torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@" + printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ + --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function \\n" + + torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ + --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function "$@" fi """ @@ -262,8 +373,10 @@ def __init__( spark_config: SparkConfig = None, use_spot_instances=False, max_wait_time_in_seconds=None, - use_torchrun=False, - nproc_per_node=1, + disable_output_compression: bool = False, + use_torchrun: bool = False, + use_mpirun: bool = False, + nproc_per_node: Optional[int] = None, ): """Initialize a _JobSettings instance which configures the remote job. @@ -445,6 +558,19 @@ def __init__( max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job. After this amount of time Amazon SageMaker will stop waiting for managed spot training job to complete. Defaults to ``None``. + + disable_output_compression (bool): Optional. When set to true, Model is uploaded to + Amazon S3 without compression after training finishes. + + use_torchrun (bool): Specifies whether to use torchrun for distributed training. + Defaults to ``False``. + + use_mpirun (bool): Specifies whether to use mpirun for distributed training. + Defaults to ``False``. + + nproc_per_node (int): Optional. Specifies the number of processes per node for + distributed training. Defaults to ``None``. + This is defined automatically configured on the instance type. """ self.sagemaker_session = sagemaker_session or Session() self.environment_variables = resolve_value_from_config( @@ -603,7 +729,9 @@ def __init__( tags = format_tags(tags) self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS) + self.disable_output_compression = disable_output_compression self.use_torchrun = use_torchrun + self.use_mpirun = use_mpirun self.nproc_per_node = nproc_per_node @staticmethod @@ -732,6 +860,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non ) logger.info("Creating job: %s", job_name) + job_settings.sagemaker_session.sagemaker_client.create_training_job(**training_job_request) return _Job( @@ -746,7 +875,7 @@ def compile( job_settings: _JobSettings, job_name: str, s3_base_uri: str, - func: callable, + func: Callable, func_args: tuple, func_kwargs: dict, run_info=None, @@ -776,8 +905,6 @@ def compile( s3_base_uri=s3_base_uri, hmac_key=hmac_key, s3_kms_key=job_settings.s3_kms_key, - use_torchrun=job_settings.use_torchrun, - nproc_per_node=job_settings.nproc_per_node, ) stored_function.save(func, *func_args, **func_kwargs) else: @@ -790,8 +917,6 @@ def compile( step_name=step_compilation_context.step_name, func_step_s3_dir=step_compilation_context.pipeline_build_time, ), - use_torchrun=job_settings.use_torchrun, - nproc_per_node=job_settings.nproc_per_node, ) stored_function.save_pipeline_step_function(serialized_data) @@ -834,6 +959,8 @@ def compile( output_config = {"S3OutputPath": s3_base_uri} if job_settings.s3_kms_key is not None: output_config["KmsKeyId"] = job_settings.s3_kms_key + if job_settings.disable_output_compression: + output_config["CompressionType"] = "NONE" request_dict["OutputDataConfig"] = output_config container_args = ["--s3_base_uri", s3_base_uri] @@ -855,6 +982,12 @@ def compile( ).to_string(), ] ) + if job_settings.use_torchrun: + container_args.extend(["--distribution", "torchrun"]) + elif job_settings.use_mpirun: + container_args.extend(["--distribution", "mpirun"]) + if job_settings.nproc_per_node is not None and int(job_settings.nproc_per_node) > 0: + container_args.extend(["--user_nproc_per_node", str(job_settings.nproc_per_node)]) if job_settings.s3_kms_key: container_args.extend(["--s3_kms_key", job_settings.s3_kms_key]) @@ -931,6 +1064,8 @@ def compile( request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key}) extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri) + extended_request = _extend_mpirun_to_request(extended_request, job_settings) + extended_request = _extend_torchrun_to_request(extended_request, job_settings) return extended_request @@ -1011,7 +1146,7 @@ def _prepare_and_upload_runtime_scripts( s3_kms_key: str, sagemaker_session: Session, use_torchrun: bool = False, - nproc_per_node: int = 1, + use_mpirun: bool = False, ): """Copy runtime scripts to a folder and upload to S3. @@ -1030,7 +1165,9 @@ def _prepare_and_upload_runtime_scripts( use_torchrun (bool): Whether to use torchrun or not. - nproc_per_node (int): Number of processes per node. + use_mpirun (bool): Whether to use mpirun or not. + + nproc_per_node (Optional[int]): Number of processes per node """ from sagemaker.workflow.utilities import load_step_compilation_context @@ -1054,7 +1191,9 @@ def _prepare_and_upload_runtime_scripts( if use_torchrun: entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT - entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node)) + + if use_mpirun: + entry_point_script = ENTRYPOINT_MPIRUN_SCRIPT with open(entrypoint_script_path, "w", newline="\n") as file: file.writelines(entry_point_script) @@ -1062,12 +1201,16 @@ def _prepare_and_upload_runtime_scripts( bootstrap_script_path = os.path.join( os.path.dirname(__file__), "runtime_environment", BOOTSTRAP_SCRIPT_NAME ) + mpi_utils_path = os.path.join( + os.path.dirname(__file__), "runtime_environment", MPI_UTILS_SCRIPT_NAME + ) runtime_manager_script_path = os.path.join( os.path.dirname(__file__), "runtime_environment", RUNTIME_MANAGER_SCRIPT_NAME ) # copy runtime scripts to tmpdir shutil.copy2(bootstrap_script_path, bootstrap_scripts) + shutil.copy2(mpi_utils_path, bootstrap_scripts) shutil.copy2(runtime_manager_script_path, bootstrap_scripts) upload_path = S3Uploader.upload( @@ -1094,7 +1237,7 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str): s3_kms_key=job_settings.s3_kms_key, sagemaker_session=job_settings.sagemaker_session, use_torchrun=job_settings.use_torchrun, - nproc_per_node=job_settings.nproc_per_node, + use_mpirun=job_settings.use_mpirun, ) input_data_config = [ @@ -1435,6 +1578,64 @@ def _upload_serialized_spark_configuration( return config_file_s3_uri +def _extend_mpirun_to_request( + request_dict: Dict, + job_settings: _JobSettings, +) -> Dict: + """Extend the create training job request with mpirun configuration. + + Args: + request_dict (Dict): create training job request dict. + job_settings (_JobSettings): the job settings. + """ + use_mpirun = job_settings.use_mpirun + instance_count = job_settings.instance_count + + if not use_mpirun: + return request_dict + + if instance_count == 1: + return request_dict + + extended_request = request_dict.copy() + + for input_channel in extended_request["InputDataConfig"]: + s3_data_source = input_channel["DataSource"].get("S3DataSource", None) + if s3_data_source: + s3_data_source["S3DataDistributionType"] = "FullyReplicated" + + return extended_request + + +def _extend_torchrun_to_request( + request_dict: Dict, + job_settings: _JobSettings, +) -> Dict: + """Extend the create training job request with torchrun configuration. + + Args: + request_dict (Dict): create training job request dict. + job_settings (_JobSettings): the job settings. + """ + use_torchrun = job_settings.use_torchrun + instance_count = job_settings.instance_count + + if not use_torchrun: + return request_dict + + if instance_count == 1: + return request_dict + + extended_request = request_dict.copy() + + for input_channel in extended_request["InputDataConfig"]: + s3_data_source = input_channel["DataSource"].get("S3DataSource", None) + if s3_data_source: + s3_data_source["S3DataDistributionType"] = "FullyReplicated" + + return extended_request + + def _extend_spark_config_to_request( request_dict: Dict, job_settings: _JobSettings, diff --git a/src/sagemaker/remote_function/runtime_environment/__init__.py b/src/sagemaker/remote_function/runtime_environment/__init__.py index e69de29bb2..18557a2eb5 100644 --- a/src/sagemaker/remote_function/runtime_environment/__init__.py +++ b/src/sagemaker/remote_function/runtime_environment/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Sagemaker modules container_drivers directory.""" +from __future__ import absolute_import diff --git a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py index 8fd83bfcfe..da7c493ae5 100644 --- a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -15,10 +15,14 @@ import argparse import getpass -import sys +import json +import multiprocessing import os -import shutil import pathlib +import shutil +import subprocess +import sys +from typing import Any, Dict if __package__ is None or __package__ == "": from runtime_environment_manager import ( @@ -39,64 +43,48 @@ REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws" BASE_CHANNEL_PATH = "/opt/ml/input/data" FAILURE_REASON_PATH = "/opt/ml/output/failure" -JOB_OUTPUT_DIRS = ["/opt/ml/output", "/opt/ml/model", "/tmp"] +JOB_OUTPUT_DIRS = ["/opt/ml/input", "/opt/ml/output", "/opt/ml/model", "/tmp"] PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh" JOB_REMOTE_FUNCTION_WORKSPACE = "sagemaker_remote_function_workspace" SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME = "pre_exec_script_and_dependencies" +SM_MODEL_DIR = "/opt/ml/model" -logger = get_logger() +SM_INPUT_DIR = "/opt/ml/input" +SM_INPUT_DATA_DIR = "/opt/ml/input/data" +SM_INPUT_CONFIG_DIR = "/opt/ml/input/config" +SM_OUTPUT_DIR = "/opt/ml/output" +SM_OUTPUT_FAILURE = "/opt/ml/output/failure" +SM_OUTPUT_DATA_DIR = "/opt/ml/output/data" -def main(sys_args=None): - """Entry point for bootstrap script""" - - exit_code = DEFAULT_FAILURE_CODE +SM_MASTER_ADDR = "algo-1" +SM_MASTER_PORT = 7777 - try: - args = _parse_args(sys_args) - client_python_version = args.client_python_version - client_sagemaker_pysdk_version = args.client_sagemaker_pysdk_version - job_conda_env = args.job_conda_env - pipeline_execution_id = args.pipeline_execution_id - dependency_settings = _DependencySettings.from_string(args.dependency_settings) - func_step_workspace = args.func_step_s3_dir +RESOURCE_CONFIG = f"{SM_INPUT_CONFIG_DIR}/resourceconfig.json" +ENV_OUTPUT_FILE = "/opt/ml/input/sm_training.env" - conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") +SENSITIVE_KEYWORDS = ["SECRET", "PASSWORD", "KEY", "TOKEN", "PRIVATE", "CREDS", "CREDENTIALS"] +HIDDEN_VALUE = "******" - RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env) +SM_EFA_NCCL_INSTANCES = [ + "ml.g4dn.8xlarge", + "ml.g4dn.12xlarge", + "ml.g5.48xlarge", + "ml.p3dn.24xlarge", + "ml.p4d.24xlarge", + "ml.p4de.24xlarge", + "ml.p5.48xlarge", + "ml.trn1.32xlarge", +] - user = getpass.getuser() - if user != "root": - log_message = ( - "The job is running on non-root user: %s. Adding write permissions to the " - "following job output directories: %s." - ) - logger.info(log_message, user, JOB_OUTPUT_DIRS) - RuntimeEnvironmentManager().change_dir_permission( - dirs=JOB_OUTPUT_DIRS, new_permission="777" - ) +SM_EFA_RDMA_INSTANCES = [ + "ml.p4d.24xlarge", + "ml.p4de.24xlarge", + "ml.trn1.32xlarge", +] - if pipeline_execution_id: - _bootstrap_runtime_env_for_pipeline_step( - client_python_version, func_step_workspace, conda_env, dependency_settings - ) - else: - _bootstrap_runtime_env_for_remote_function( - client_python_version, conda_env, dependency_settings - ) - - RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( - client_sagemaker_pysdk_version - ) - - exit_code = SUCCESS_EXIT_CODE - except Exception as e: # pylint: disable=broad-except - logger.exception("Error encountered while bootstrapping runtime environment: %s", e) - - _write_failure_reason_file(str(e)) - finally: - sys.exit(exit_code) +logger = get_logger() def _bootstrap_runtime_env_for_remote_function( @@ -283,9 +271,332 @@ def _parse_args(sys_args): parser.add_argument("--pipeline_execution_id", type=str) parser.add_argument("--dependency_settings", type=str) parser.add_argument("--func_step_s3_dir", type=str) + parser.add_argument("--distribution", type=str, default=None) + parser.add_argument("--user_nproc_per_node", type=str, default=None) args, _ = parser.parse_known_args(sys_args) return args +def log_key_value(key: str, value: str): + """Log a key-value pair, masking sensitive values if necessary.""" + if any(keyword.lower() in key.lower() for keyword in SENSITIVE_KEYWORDS): + logger.info("%s=%s", key, HIDDEN_VALUE) + elif isinstance(value, dict): + masked_value = mask_sensitive_info(value) + logger.info("%s=%s", key, json.dumps(masked_value)) + else: + try: + decoded_value = json.loads(value) + if isinstance(decoded_value, dict): + masked_value = mask_sensitive_info(decoded_value) + logger.info("%s=%s", key, json.dumps(masked_value)) + else: + logger.info("%s=%s", key, decoded_value) + except (json.JSONDecodeError, TypeError): + logger.info("%s=%s", key, value) + + +def log_env_variables(env_vars_dict: Dict[str, Any]): + """Log Environment Variables from the environment and an env_vars_dict.""" + for key, value in os.environ.items(): + log_key_value(key, value) + + for key, value in env_vars_dict.items(): + log_key_value(key, value) + + +def mask_sensitive_info(data): + """Recursively mask sensitive information in a dictionary.""" + if isinstance(data, dict): + for k, v in data.items(): + if isinstance(v, dict): + data[k] = mask_sensitive_info(v) + elif isinstance(v, str) and any( + keyword.lower() in k.lower() for keyword in SENSITIVE_KEYWORDS + ): + data[k] = HIDDEN_VALUE + return data + + +def num_cpus() -> int: + """Return the number of CPUs available in the current container. + + Returns: + int: Number of CPUs available in the current container. + """ + return multiprocessing.cpu_count() + + +def num_gpus() -> int: + """Return the number of GPUs available in the current container. + + Returns: + int: Number of GPUs available in the current container. + """ + try: + cmd = ["nvidia-smi", "--list-gpus"] + output = subprocess.check_output(cmd).decode("utf-8") + return sum(1 for line in output.splitlines() if line.startswith("GPU ")) + except (OSError, subprocess.CalledProcessError): + logger.info("No GPUs detected (normal if no gpus installed)") + return 0 + + +def num_neurons() -> int: + """Return the number of neuron cores available in the current container. + + Returns: + int: Number of Neuron Cores available in the current container. + """ + try: + cmd = ["neuron-ls", "-j"] + output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8") + j = json.loads(output) + neuron_cores = 0 + for item in j: + neuron_cores += item.get("nc_count", 0) + logger.info("Found %s neurons on this instance", neuron_cores) + return neuron_cores + except OSError: + logger.info("No Neurons detected (normal if no neurons installed)") + return 0 + except subprocess.CalledProcessError as e: + if e.output is not None: + try: + msg = e.output.decode("utf-8").partition("error=")[2] + logger.info( + "No Neurons detected (normal if no neurons installed). \ + If neuron installed then %s", + msg, + ) + except AttributeError: + logger.info("No Neurons detected (normal if no neurons installed)") + else: + logger.info("No Neurons detected (normal if no neurons installed)") + + return 0 + + +def safe_serialize(data): + """Serialize the data without wrapping strings in quotes. + + This function handles the following cases: + 1. If `data` is a string, it returns the string as-is without wrapping in quotes. + 2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns + the JSON-encoded string using `json.dumps()`. + 3. If `data` cannot be serialized (e.g., a custom object), it returns the string + representation of the data using `str(data)`. + + Args: + data (Any): The data to serialize. + + Returns: + str: The serialized JSON-compatible string or the string representation of the input. + """ + if isinstance(data, str): + return data + try: + return json.dumps(data) + except TypeError: + return str(data) + + +def set_env( + resource_config: Dict[str, Any], + distribution: str = None, + user_nproc_per_node: bool = None, + output_file: str = ENV_OUTPUT_FILE, +): + """Set environment variables for the training job container. + + Args: + resource_config (Dict[str, Any]): Resource configuration for the training job. + output_file (str): Output file to write the environment variables. + """ + # Constants + env_vars = { + "SM_MODEL_DIR": SM_MODEL_DIR, + "SM_INPUT_DIR": SM_INPUT_DIR, + "SM_INPUT_DATA_DIR": SM_INPUT_DATA_DIR, + "SM_INPUT_CONFIG_DIR": SM_INPUT_CONFIG_DIR, + "SM_OUTPUT_DIR": SM_OUTPUT_DIR, + "SM_OUTPUT_FAILURE": SM_OUTPUT_FAILURE, + "SM_OUTPUT_DATA_DIR": SM_OUTPUT_DATA_DIR, + "SM_MASTER_ADDR": SM_MASTER_ADDR, + "SM_MASTER_PORT": SM_MASTER_PORT, + } + + # Host Variables + current_host = resource_config["current_host"] + current_instance_type = resource_config["current_instance_type"] + hosts = resource_config["hosts"] + sorted_hosts = sorted(hosts) + + env_vars["SM_CURRENT_HOST"] = current_host + env_vars["SM_CURRENT_INSTANCE_TYPE"] = current_instance_type + env_vars["SM_HOSTS"] = sorted_hosts + env_vars["SM_NETWORK_INTERFACE_NAME"] = resource_config["network_interface_name"] + env_vars["SM_HOST_COUNT"] = len(sorted_hosts) + env_vars["SM_CURRENT_HOST_RANK"] = sorted_hosts.index(current_host) + + env_vars["SM_NUM_CPUS"] = num_cpus() + env_vars["SM_NUM_GPUS"] = num_gpus() + env_vars["SM_NUM_NEURONS"] = num_neurons() + + # Misc. + env_vars["SM_RESOURCE_CONFIG"] = resource_config + + if user_nproc_per_node is not None and int(user_nproc_per_node) > 0: + env_vars["SM_NPROC_PER_NODE"] = int(user_nproc_per_node) + else: + if int(env_vars["SM_NUM_GPUS"]) > 0: + env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_GPUS"]) + elif int(env_vars["SM_NUM_NEURONS"]) > 0: + env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_NEURONS"]) + else: + env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_CPUS"]) + + # All Training Environment Variables + env_vars["SM_TRAINING_ENV"] = { + "current_host": env_vars["SM_CURRENT_HOST"], + "current_instance_type": env_vars["SM_CURRENT_INSTANCE_TYPE"], + "hosts": env_vars["SM_HOSTS"], + "host_count": env_vars["SM_HOST_COUNT"], + "nproc_per_node": env_vars["SM_NPROC_PER_NODE"], + "master_addr": env_vars["SM_MASTER_ADDR"], + "master_port": env_vars["SM_MASTER_PORT"], + "input_config_dir": env_vars["SM_INPUT_CONFIG_DIR"], + "input_data_dir": env_vars["SM_INPUT_DATA_DIR"], + "input_dir": env_vars["SM_INPUT_DIR"], + "job_name": os.environ["TRAINING_JOB_NAME"], + "model_dir": env_vars["SM_MODEL_DIR"], + "network_interface_name": env_vars["SM_NETWORK_INTERFACE_NAME"], + "num_cpus": env_vars["SM_NUM_CPUS"], + "num_gpus": env_vars["SM_NUM_GPUS"], + "num_neurons": env_vars["SM_NUM_NEURONS"], + "output_data_dir": env_vars["SM_OUTPUT_DATA_DIR"], + "resource_config": env_vars["SM_RESOURCE_CONFIG"], + } + + if distribution and distribution == "torchrun": + logger.info("Distribution: torchrun") + + instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"] + network_interface_name = env_vars.get("SM_NETWORK_INTERFACE_NAME", "eth0") + + if instance_type in SM_EFA_NCCL_INSTANCES: + # Enable EFA use + env_vars["FI_PROVIDER"] = "efa" + if instance_type in SM_EFA_RDMA_INSTANCES: + # Use EFA's RDMA functionality for one-sided and two-sided transfer + env_vars["FI_EFA_USE_DEVICE_RDMA"] = "1" + env_vars["RDMAV_FORK_SAFE"] = "1" + env_vars["NCCL_SOCKET_IFNAME"] = str(network_interface_name) + env_vars["NCCL_PROTO"] = "simple" + elif distribution and distribution == "mpirun": + logger.info("Distribution: mpirun") + + env_vars["MASTER_ADDR"] = env_vars["SM_MASTER_ADDR"] + env_vars["MASTER_PORT"] = str(env_vars["SM_MASTER_PORT"]) + + host_list = [ + "{}:{}".format(host, int(env_vars["SM_NPROC_PER_NODE"])) for host in sorted_hosts + ] + env_vars["SM_HOSTS_LIST"] = ",".join(host_list) + + instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"] + + if instance_type in SM_EFA_NCCL_INSTANCES: + env_vars["SM_FI_PROVIDER"] = "-x FI_PROVIDER=efa" + env_vars["SM_NCCL_PROTO"] = "-x NCCL_PROTO=simple" + else: + env_vars["SM_FI_PROVIDER"] = "" + env_vars["SM_NCCL_PROTO"] = "" + + if instance_type in SM_EFA_RDMA_INSTANCES: + env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = "-x FI_EFA_USE_DEVICE_RDMA=1" + else: + env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = "" + + with open(output_file, "w") as f: + for key, value in env_vars.items(): + f.write(f"export {key}='{safe_serialize(value)}'\n") + + logger.info("Environment Variables:") + log_env_variables(env_vars_dict=env_vars) + + +def main(sys_args=None): + """Entry point for bootstrap script""" + + exit_code = DEFAULT_FAILURE_CODE + + try: + args = _parse_args(sys_args) + + logger.info("Arguments:") + for arg in vars(args): + logger.info("%s=%s", arg, getattr(args, arg)) + + client_python_version = args.client_python_version + client_sagemaker_pysdk_version = args.client_sagemaker_pysdk_version + job_conda_env = args.job_conda_env + pipeline_execution_id = args.pipeline_execution_id + dependency_settings = _DependencySettings.from_string(args.dependency_settings) + func_step_workspace = args.func_step_s3_dir + distribution = args.distribution + user_nproc_per_node = args.user_nproc_per_node + + conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") + + RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env) + + user = getpass.getuser() + if user != "root": + log_message = ( + "The job is running on non-root user: %s. Adding write permissions to the " + "following job output directories: %s." + ) + logger.info(log_message, user, JOB_OUTPUT_DIRS) + RuntimeEnvironmentManager().change_dir_permission( + dirs=JOB_OUTPUT_DIRS, new_permission="777" + ) + + if pipeline_execution_id: + _bootstrap_runtime_env_for_pipeline_step( + client_python_version, func_step_workspace, conda_env, dependency_settings + ) + else: + _bootstrap_runtime_env_for_remote_function( + client_python_version, conda_env, dependency_settings + ) + + RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( + client_sagemaker_pysdk_version + ) + + if os.path.exists(RESOURCE_CONFIG): + try: + logger.info("Found %s", RESOURCE_CONFIG) + with open(RESOURCE_CONFIG, "r") as f: + resource_config = json.load(f) + set_env( + resource_config=resource_config, + distribution=distribution, + user_nproc_per_node=user_nproc_per_node, + ) + except (json.JSONDecodeError, FileNotFoundError) as e: + # Optionally, you might want to log this error + logger.info("ERROR: Error processing %s: %s", RESOURCE_CONFIG, str(e)) + + exit_code = SUCCESS_EXIT_CODE + except Exception as e: # pylint: disable=broad-except + logger.exception("Error encountered while bootstrapping runtime environment: %s", e) + + _write_failure_reason_file(str(e)) + finally: + sys.exit(exit_code) + + if __name__ == "__main__": main(sys.argv[1:]) diff --git a/src/sagemaker/remote_function/runtime_environment/mpi_utils_remote.py b/src/sagemaker/remote_function/runtime_environment/mpi_utils_remote.py new file mode 100644 index 0000000000..6f3897fb0b --- /dev/null +++ b/src/sagemaker/remote_function/runtime_environment/mpi_utils_remote.py @@ -0,0 +1,252 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""An utils function for runtime environment. This must be kept independent of SageMaker PySDK""" +from __future__ import absolute_import + +import argparse +import json +import os +import subprocess +import sys +import time +from typing import List + +import paramiko + +if __package__ is None or __package__ == "": + from runtime_environment_manager import ( + get_logger, + ) +else: + from sagemaker.remote_function.runtime_environment.runtime_environment_manager import ( + get_logger, + ) + +SUCCESS_EXIT_CODE = 0 +DEFAULT_FAILURE_CODE = 1 + +FINISHED_STATUS_FILE = "/tmp/done.algo-1" +READY_FILE = "/tmp/ready.%s" +DEFAULT_SSH_PORT = 22 + +FAILURE_REASON_PATH = "/opt/ml/output/failure" +FINISHED_STATUS_FILE = "/tmp/done.algo-1" + +logger = get_logger() + + +class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy): + """Class to handle host key policy for SageMaker distributed training SSH connections. + + Example: + >>> client = paramiko.SSHClient() + >>> client.set_missing_host_key_policy(CustomHostKeyPolicy()) + >>> # Will succeed for SageMaker algorithm containers + >>> client.connect('algo-1234.internal') + >>> # Will raise SSHException for other unknown hosts + >>> client.connect('unknown-host') # raises SSHException + """ + + def missing_host_key(self, client, hostname, key): + """Accept host keys for algo-* hostnames, reject others. + + Args: + client: The SSHClient instance + hostname: The hostname attempting to connect + key: The host key + Raises: + paramiko.SSHException: If hostname doesn't match algo-* pattern + """ + if hostname.startswith("algo-"): + client.get_host_keys().add(hostname, key.get_name(), key) + return + raise paramiko.SSHException(f"Unknown host key for {hostname}") + + +def _parse_args(sys_args): + """Parses CLI arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--job_ended", type=str, default="0") + args, _ = parser.parse_known_args(sys_args) + return args + + +def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool: + """Check if the connection to the provided host and port is possible.""" + try: + with paramiko.SSHClient() as client: + client.load_system_host_keys() + client.set_missing_host_key_policy(CustomHostKeyPolicy()) + client.connect(host, port=port) + logger.info("Can connect to host %s", host) + return True + except Exception as e: # pylint: disable=W0703 + logger.info("Cannot connect to host %s", host) + logger.debug("Connection failed with exception: %s", e) + return False + + +def _write_file_to_host(host: str, status_file: str) -> bool: + """Write the a file to the provided host.""" + try: + logger.info("Writing %s to %s", status_file, host) + subprocess.run( + ["ssh", host, "touch", f"{status_file}"], + capture_output=True, + text=True, + check=True, + ) + logger.info("Finished writing status file") + return True + except subprocess.CalledProcessError: + logger.info("Cannot connect to %s", host) + return False + + +def _write_failure_reason_file(failure_msg): + """Create a file 'failure' with failure reason written if bootstrap runtime env failed. + + See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html + Args: + failure_msg: The content of file to be written. + """ + if not os.path.exists(FAILURE_REASON_PATH): + with open(FAILURE_REASON_PATH, "w") as f: + f.write("RuntimeEnvironmentError: " + failure_msg) + + +def _wait_for_master(master_host: str, port: int = DEFAULT_SSH_PORT, timeout: int = 300): + """Worker nodes wait until they can connect to the master node.""" + start_time = time.time() + while True: + logger.info("Worker is attempting to connect to the master node %s...", master_host) + if _can_connect(master_host, port): + logger.info("Worker can connect to master node %s.", master_host) + break + if time.time() - start_time > timeout: + raise TimeoutError("Timed out waiting for master %s to be reachable." % master_host) + + time.sleep(5) # Wait for 5 seconds before trying again + + +def _wait_for_status_file(status_file: str): + """Wait for the status file to be created.""" + logger.info("Waiting for status file %s", status_file) + while not os.path.exists(status_file): + time.sleep(30) + logger.info("Found status file %s", status_file) + + +def _wait_for_workers(worker_hosts: List[str], port: int = DEFAULT_SSH_PORT, timeout: int = 300): + """Master node waits until it can connect to all worker nodes.""" + start_time = time.time() + if not worker_hosts: + logger.info("No worker nodes to connect to.") + return + + while True: + logger.info("Master is attempting to connect to all workers...") + all_workers_connected = all( + _can_connect(worker, port) and os.path.exists(READY_FILE % worker) + for worker in worker_hosts + ) + + if all_workers_connected: + logger.info("Master can connect to all worker nodes.") + break + if time.time() - start_time > timeout: + raise TimeoutError("Timed out waiting for workers to be reachable.") + + time.sleep(5) # Wait for 5 seconds before trying again + + +def bootstrap_master_node(worker_hosts: List[str]): + """Bootstrap the master node.""" + logger.info("Bootstrapping master node...") + _wait_for_workers(worker_hosts) + + +def bootstrap_worker_node( + master_host: str, current_host: str, status_file: str = FINISHED_STATUS_FILE +): + """Bootstrap the worker nodes.""" + logger.info("Bootstrapping worker node...") + _wait_for_master(master_host) + _write_file_to_host(master_host, READY_FILE % current_host) + _wait_for_status_file(status_file) + + +def start_sshd_daemon(): + """Start the SSH daemon on the current node.""" + sshd_executable = "/usr/sbin/sshd" + + if not os.path.exists(sshd_executable): + raise RuntimeError("SSH daemon not found.") + + # Start the sshd in daemon mode (-D) + subprocess.Popen([sshd_executable, "-D"]) + logger.info("Started SSH daemon.") + + +def write_status_file_to_workers(worker_hosts: List[str], status_file: str = FINISHED_STATUS_FILE): + """Write the status file to all worker nodes.""" + for worker in worker_hosts: + retry = 0 + while not _write_file_to_host(worker, status_file): + time.sleep(5) + retry += 1 + if retry > 5: + raise TimeoutError("Timed out waiting for %s to be reachable." % worker) + logger.info("Retrying to write status file to %s", worker) + + +def main(sys_args=None): + """Entry point for bootstrap script""" + try: + args = _parse_args(sys_args) + + job_ended = args.job_ended + + main_host = os.environ["SM_MASTER_ADDR"] + current_host = os.environ["SM_CURRENT_HOST"] + + if job_ended == "0": + logger.info("Job is running, bootstrapping nodes") + + start_sshd_daemon() + + if current_host != main_host: + bootstrap_worker_node(main_host, current_host) + else: + sorted_hosts = json.loads(os.environ["SM_HOSTS"]) + worker_hosts = [host for host in sorted_hosts if host != main_host] + + bootstrap_master_node(worker_hosts) + else: + logger.info("Job ended, writing status file to workers") + + if current_host == main_host: + sorted_hosts = json.loads(os.environ["SM_HOSTS"]) + worker_hosts = [host for host in sorted_hosts if host != main_host] + + write_status_file_to_workers(worker_hosts) + except Exception as e: # pylint: disable=broad-except + logger.exception("Error encountered while bootstrapping runtime environment: %s", e) + + _write_failure_reason_file(str(e)) + + sys.exit(DEFAULT_FAILURE_CODE) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/src/sagemaker/rl/estimator.py b/src/sagemaker/rl/estimator.py index e262604ac3..f1e1407633 100644 --- a/src/sagemaker/rl/estimator.py +++ b/src/sagemaker/rl/estimator.py @@ -120,8 +120,8 @@ def __init__( source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + are preserved when training on Amazon SageMaker. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on diff --git a/src/sagemaker/s3_utils.py b/src/sagemaker/s3_utils.py index e53cdbe02a..f59c8a299f 100644 --- a/src/sagemaker/s3_utils.py +++ b/src/sagemaker/s3_utils.py @@ -45,6 +45,19 @@ def parse_s3_url(url): return parsed_url.netloc, parsed_url.path.lstrip("/") +def is_s3_url(url): + """Returns True if url is an s3 url, False if not + + Args: + url (str): + + Returns: + bool: + """ + parsed_url = urlparse(url) + return parsed_url.scheme == "s3" + + def s3_path_join(*args, with_end_slash: bool = False): """Returns the arguments joined by a slash ("/"), similar to ``os.path.join()`` (on Unix). diff --git a/src/sagemaker/serializer_utils.py b/src/sagemaker/serializer_utils.py new file mode 100644 index 0000000000..96a931084c --- /dev/null +++ b/src/sagemaker/serializer_utils.py @@ -0,0 +1,222 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Placeholder docstring""" +from __future__ import absolute_import + +import logging +import struct +import sys + +import numpy as np + +from sagemaker.amazon.record_pb2 import Record +from sagemaker.utils import DeferredError + + +def _write_feature_tensor(resolved_type, record, vector): + """Placeholder Docstring""" + if resolved_type == "Int32": + record.features["values"].int32_tensor.values.extend(vector) + elif resolved_type == "Float64": + record.features["values"].float64_tensor.values.extend(vector) + elif resolved_type == "Float32": + record.features["values"].float32_tensor.values.extend(vector) + + +def _write_label_tensor(resolved_type, record, scalar): + """Placeholder Docstring""" + if resolved_type == "Int32": + record.label["values"].int32_tensor.values.extend([scalar]) + elif resolved_type == "Float64": + record.label["values"].float64_tensor.values.extend([scalar]) + elif resolved_type == "Float32": + record.label["values"].float32_tensor.values.extend([scalar]) + + +def _write_keys_tensor(resolved_type, record, vector): + """Placeholder Docstring""" + if resolved_type == "Int32": + record.features["values"].int32_tensor.keys.extend(vector) + elif resolved_type == "Float64": + record.features["values"].float64_tensor.keys.extend(vector) + elif resolved_type == "Float32": + record.features["values"].float32_tensor.keys.extend(vector) + + +def _write_shape(resolved_type, record, scalar): + """Placeholder Docstring""" + if resolved_type == "Int32": + record.features["values"].int32_tensor.shape.extend([scalar]) + elif resolved_type == "Float64": + record.features["values"].float64_tensor.shape.extend([scalar]) + elif resolved_type == "Float32": + record.features["values"].float32_tensor.shape.extend([scalar]) + + +def write_numpy_to_dense_tensor(file, array, labels=None): + """Writes a numpy array to a dense tensor + + Args: + file: + array: + labels: + """ + + # Validate shape of array and labels, resolve array and label types + if not len(array.shape) == 2: + raise ValueError("Array must be a Matrix") + if labels is not None: + if not len(labels.shape) == 1: + raise ValueError("Labels must be a Vector") + if labels.shape[0] not in array.shape: + raise ValueError( + "Label shape {} not compatible with array shape {}".format( + labels.shape, array.shape + ) + ) + resolved_label_type = _resolve_type(labels.dtype) + resolved_type = _resolve_type(array.dtype) + + # Write each vector in array into a Record in the file object + record = Record() + for index, vector in enumerate(array): + record.Clear() + _write_feature_tensor(resolved_type, record, vector) + if labels is not None: + _write_label_tensor(resolved_label_type, record, labels[index]) + _write_recordio(file, record.SerializeToString()) + + +def write_spmatrix_to_sparse_tensor(file, array, labels=None): + """Writes a scipy sparse matrix to a sparse tensor + + Args: + file: + array: + labels: + """ + try: + import scipy + except ImportError as e: + logging.warning( + "scipy failed to import. Sparse matrix functions will be impaired or broken." + ) + # Any subsequent attempt to use scipy will raise the ImportError + scipy = DeferredError(e) + + if not scipy.sparse.issparse(array): + raise TypeError("Array must be sparse") + + # Validate shape of array and labels, resolve array and label types + if not len(array.shape) == 2: + raise ValueError("Array must be a Matrix") + if labels is not None: + if not len(labels.shape) == 1: + raise ValueError("Labels must be a Vector") + if labels.shape[0] not in array.shape: + raise ValueError( + "Label shape {} not compatible with array shape {}".format( + labels.shape, array.shape + ) + ) + resolved_label_type = _resolve_type(labels.dtype) + resolved_type = _resolve_type(array.dtype) + + csr_array = array.tocsr() + n_rows, n_cols = csr_array.shape + + record = Record() + for row_idx in range(n_rows): + record.Clear() + row = csr_array.getrow(row_idx) + # Write values + _write_feature_tensor(resolved_type, record, row.data) + # Write keys + _write_keys_tensor(resolved_type, record, row.indices.astype(np.uint64)) + + # Write labels + if labels is not None: + _write_label_tensor(resolved_label_type, record, labels[row_idx]) + + # Write shape + _write_shape(resolved_type, record, n_cols) + + _write_recordio(file, record.SerializeToString()) + + +def read_records(file): + """Eagerly read a collection of amazon Record protobuf objects from file. + + Args: + file: + """ + records = [] + for record_data in read_recordio(file): + record = Record() + record.ParseFromString(record_data) + records.append(record) + return records + + +# MXNet requires recordio records have length in bytes that's a multiple of 4 +# This sets up padding bytes to append to the end of the record, for diferent +# amounts of padding required. +padding = {} +for amount in range(4): + if sys.version_info >= (3,): + padding[amount] = bytes([0x00 for _ in range(amount)]) + else: + padding[amount] = bytearray([0x00 for _ in range(amount)]) + +_kmagic = 0xCED7230A + + +def _write_recordio(f, data): + """Writes a single data point as a RecordIO record to the given file. + + Args: + f: + data: + """ + length = len(data) + f.write(struct.pack("I", _kmagic)) + f.write(struct.pack("I", length)) + pad = (((length + 3) >> 2) << 2) - length + f.write(data) + f.write(padding[pad]) + + +def read_recordio(f): + """Placeholder Docstring""" + while True: + try: + (read_kmagic,) = struct.unpack("I", f.read(4)) + except struct.error: + return + assert read_kmagic == _kmagic + (len_record,) = struct.unpack("I", f.read(4)) + pad = (((len_record + 3) >> 2) << 2) - len_record + yield f.read(len_record) + if pad: + f.read(pad) + + +def _resolve_type(dtype): + """Placeholder Docstring""" + if dtype == np.dtype(int): + return "Int32" + if dtype == np.dtype(float): + return "Float64" + if dtype == np.dtype("float32"): + return "Float32" + raise ValueError("Unsupported dtype {} on array".format(dtype)) diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index ef502dc6f3..be46be0856 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -30,8 +30,10 @@ SparseMatrixSerializer, TorchTensorSerializer, StringSerializer, + RecordSerializer, ) +from sagemaker.deprecations import deprecated_class from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.enums import JumpStartModelType @@ -152,3 +154,6 @@ def retrieve_default( model_type=model_type, config_name=config_name, ) + + +numpy_to_record_serializer = deprecated_class(RecordSerializer, "numpy_to_record_serializer") diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 37a77179cb..bf6fcaa376 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -17,7 +17,7 @@ import re from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import Type, Any, List, Dict, Optional +from typing import Type, Any, List, Dict, Optional, Tuple import logging from botocore.exceptions import ClientError @@ -82,6 +82,7 @@ ModelServer.DJL_SERVING, ModelServer.TGI, } +_JS_MINIMUM_VERSION_IMAGE = "{}:0.31.0-lmi13.0.0-cu124" logger = logging.getLogger(__name__) @@ -156,6 +157,7 @@ def _create_pre_trained_js_model(self) -> Type[Model]: vpc_config=self.vpc_config, sagemaker_session=self.sagemaker_session, name=self.name, + instance_type=self.instance_type, ) self._original_deploy = pysdk_model.deploy @@ -829,7 +831,13 @@ def _optimize_for_jumpstart( self.pysdk_model._enable_network_isolation = False if quantization_config or sharding_config or is_compilation: - return create_optimization_job_args + # only apply default image for vLLM usecases. + # vLLM does not support compilation for now so skip on compilation + return ( + create_optimization_job_args + if is_compilation + else self._set_optimization_image_default(create_optimization_job_args) + ) return None def _is_gated_model(self, model=None) -> bool: @@ -986,3 +994,105 @@ def _get_neuron_model_env_vars( ) return job_model.env return None + + def _set_optimization_image_default( + self, create_optimization_job_args: Dict[str, Any] + ) -> Dict[str, Any]: + """Defaults the optimization image to the JumpStart deployment config default + + Args: + create_optimization_job_args (Dict[str, Any]): create optimization job request + + Returns: + Dict[str, Any]: create optimization job request with image uri default + """ + default_image = self._get_default_vllm_image(self.pysdk_model.init_kwargs["image_uri"]) + + # find the latest vLLM image version + for optimization_config in create_optimization_job_args.get("OptimizationConfigs"): + if optimization_config.get("ModelQuantizationConfig"): + model_quantization_config = optimization_config.get("ModelQuantizationConfig") + provided_image = model_quantization_config.get("Image") + if provided_image and self._get_latest_lmi_version_from_list( + default_image, provided_image + ): + default_image = provided_image + if optimization_config.get("ModelShardingConfig"): + model_sharding_config = optimization_config.get("ModelShardingConfig") + provided_image = model_sharding_config.get("Image") + if provided_image and self._get_latest_lmi_version_from_list( + default_image, provided_image + ): + default_image = provided_image + + # default to latest vLLM version + for optimization_config in create_optimization_job_args.get("OptimizationConfigs"): + if optimization_config.get("ModelQuantizationConfig") is not None: + optimization_config.get("ModelQuantizationConfig")["Image"] = default_image + if optimization_config.get("ModelShardingConfig") is not None: + optimization_config.get("ModelShardingConfig")["Image"] = default_image + + logger.info("Defaulting to %s image for optimization job", default_image) + + return create_optimization_job_args + + def _get_default_vllm_image(self, image: str) -> bool: + """Ensures the minimum working image version for vLLM enabled optimization techniques + + Args: + image (str): JumpStart provided default image + + Returns: + str: minimum working image version + """ + dlc_name, _ = image.split(":") + major_version_number, _, _ = self._parse_lmi_version(image) + + if major_version_number < self._parse_lmi_version(_JS_MINIMUM_VERSION_IMAGE)[0]: + minimum_version_default = _JS_MINIMUM_VERSION_IMAGE.format(dlc_name) + return minimum_version_default + return image + + def _get_latest_lmi_version_from_list(self, version: str, version_to_compare: str) -> bool: + """LMI version comparator + + Args: + version (str): current version + version_to_compare (str): version to compare to + + Returns: + bool: if version_to_compare larger or equal to version + """ + parse_lmi_version = self._parse_lmi_version(version) + parse_lmi_version_to_compare = self._parse_lmi_version(version_to_compare) + + # Check major version + if parse_lmi_version_to_compare[0] > parse_lmi_version[0]: + return True + # Check minor version + if parse_lmi_version_to_compare[0] == parse_lmi_version[0]: + if parse_lmi_version_to_compare[1] > parse_lmi_version[1]: + return True + if parse_lmi_version_to_compare[1] == parse_lmi_version[1]: + # Check patch version + if parse_lmi_version_to_compare[2] >= parse_lmi_version[2]: + return True + return False + return False + return False + + def _parse_lmi_version(self, image: str) -> Tuple[int, int, int]: + """Parse out LMI version + + Args: + image (str): image to parse version out of + + Returns: + Tuple[int, int, int]: LMI version split into major, minor, patch + """ + _, dlc_tag = image.split(":") + _, lmi_version, _ = dlc_tag.split("-") + major_version, minor_version, patch_version = lmi_version.split(".") + major_version_number = major_version[3:] + + return (int(major_version_number), int(minor_version), int(patch_version)) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index e5e850b885..9122f22e44 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -1433,15 +1433,15 @@ def _model_builder_optimize_wrapper( # HF Model ID format = "meta-llama/Meta-Llama-3.1-8B" # JS Model ID format = "meta-textgeneration-llama-3-1-8b" - llama_3_1_keywords = ["llama-3.1", "llama-3-1"] - is_llama_3_1 = self.model and any( - keyword in self.model.lower() for keyword in llama_3_1_keywords + is_llama_3_plus = self.model and bool( + re.search(r"llama-3[\.\-][1-9]\d*", self.model.lower()) ) if is_gpu_instance and self.model and self.is_compiled: - if is_llama_3_1: + if is_llama_3_plus: raise ValueError( - "Compilation is not supported for Llama-3.1 with a GPU instance." + "Compilation is not supported for models greater " + "than Llama-3.0 with a GPU instance." ) if speculative_decoding_config: raise ValueError( @@ -1602,6 +1602,7 @@ def deploy( ResourceRequirements, ] ] = None, + update_endpoint: Optional[bool] = False, ) -> Union[Predictor, Transformer]: """Deploys the built Model. @@ -1615,24 +1616,33 @@ def deploy( AsyncInferenceConfig, BatchTransformInferenceConfig, ResourceRequirements]]) : Additional Config for different deployment types such as serverless, async, batch and multi-model/container + update_endpoint (Optional[bool]): + Flag to update the model in an existing Amazon SageMaker endpoint. + If True, this will deploy a new EndpointConfig to an already existing endpoint + and delete resources corresponding to the previous EndpointConfig. Default: False + Note: Currently this is supported for single model endpoints Returns: Transformer for Batch Deployments Predictors for all others """ if not hasattr(self, "built_model"): raise ValueError("Model Needs to be built before deploying") - endpoint_name = unique_name_from_base(endpoint_name) + if not update_endpoint: + endpoint_name = unique_name_from_base(endpoint_name) + if not inference_config: # Real-time Deployment return self.built_model.deploy( instance_type=self.instance_type, initial_instance_count=initial_instance_count, endpoint_name=endpoint_name, + update_endpoint=update_endpoint, ) if isinstance(inference_config, ServerlessInferenceConfig): return self.built_model.deploy( serverless_inference_config=inference_config, endpoint_name=endpoint_name, + update_endpoint=update_endpoint, ) if isinstance(inference_config, AsyncInferenceConfig): @@ -1641,6 +1651,7 @@ def deploy( initial_instance_count=initial_instance_count, async_inference_config=inference_config, endpoint_name=endpoint_name, + update_endpoint=update_endpoint, ) if isinstance(inference_config, BatchTransformInferenceConfig): @@ -1652,6 +1663,10 @@ def deploy( return transformer if isinstance(inference_config, ResourceRequirements): + if update_endpoint: + raise ValueError( + "Currently update_endpoint is supported for single model endpoints" + ) # Multi Model and MultiContainer endpoints with Inference Component return self.built_model.deploy( instance_type=self.instance_type, @@ -1660,6 +1675,7 @@ def deploy( resources=inference_config, initial_instance_count=initial_instance_count, role=self.role_arn, + update_endpoint=update_endpoint, ) raise ValueError("Deployment Options not supported") diff --git a/src/sagemaker/serve/builder/schema_builder.py b/src/sagemaker/serve/builder/schema_builder.py index 3fd1816d0e..7f70e98747 100644 --- a/src/sagemaker/serve/builder/schema_builder.py +++ b/src/sagemaker/serve/builder/schema_builder.py @@ -4,6 +4,7 @@ import io import logging from pathlib import Path +from typing import Callable import numpy as np from pandas import DataFrame @@ -286,7 +287,7 @@ def _is_path_to_file(data: object) -> bool: def _validate_translations( - payload: object, serialize_callable: callable, deserialize_callable: callable + payload: object, serialize_callable: Callable, deserialize_callable: Callable ) -> None: """Placeholder docstring""" try: diff --git a/src/sagemaker/serve/detector/dependency_manager.py b/src/sagemaker/serve/detector/dependency_manager.py index e72a84da30..8ff37c9185 100644 --- a/src/sagemaker/serve/detector/dependency_manager.py +++ b/src/sagemaker/serve/detector/dependency_manager.py @@ -34,22 +34,34 @@ def capture_dependencies(dependencies: dict, work_dir: Path, capture_all: bool = """Placeholder docstring""" path = work_dir.joinpath("requirements.txt") if "auto" in dependencies and dependencies["auto"]: + import site + + pkl_path = work_dir.joinpath(PKL_FILE_NAME) + dest_path = path + site_packages_dir = site.getsitepackages()[0] + pickle_command_dir = "/sagemaker/serve/detector" + command = [ sys.executable, - Path(__file__).parent.joinpath("pickle_dependencies.py"), - "--pkl_path", - work_dir.joinpath(PKL_FILE_NAME), - "--dest", - path, + "-c", ] if capture_all: - command.append("--capture_all") + command.append( + f"from pickle_dependencies import get_all_requirements;" + f'get_all_requirements("{dest_path}")' + ) + else: + command.append( + f"from pickle_dependencies import get_requirements_for_pkl_file;" + f'get_requirements_for_pkl_file("{pkl_path}", "{dest_path}")' + ) subprocess.run( command, env={"SETUPTOOLS_USE_DISTUTILS": "stdlib"}, check=True, + cwd=site_packages_dir + pickle_command_dir, ) with open(path, "r") as f: diff --git a/src/sagemaker/serve/detector/pickle_dependencies.py b/src/sagemaker/serve/detector/pickle_dependencies.py index 5a1cd43869..8f9da917fd 100644 --- a/src/sagemaker/serve/detector/pickle_dependencies.py +++ b/src/sagemaker/serve/detector/pickle_dependencies.py @@ -3,7 +3,6 @@ from __future__ import absolute_import from pathlib import Path from typing import List -import argparse import email.parser import email.policy import json @@ -129,32 +128,3 @@ def get_all_requirements(dest: Path): version = package_info.get("version") out.write(f"{name}=={version}\n") - - -def parse_args(): - """Placeholder docstring""" - parser = argparse.ArgumentParser( - prog="pkl_requirements", description="Generates a requirements.txt for a cloudpickle file" - ) - parser.add_argument("--pkl_path", required=True, help="path of the pkl file") - parser.add_argument("--dest", required=True, help="path of the destination requirements.txt") - parser.add_argument( - "--capture_all", - action="store_true", - help="capture all dependencies in current environment", - ) - args = parser.parse_args() - return (Path(args.pkl_path), Path(args.dest), args.capture_all) - - -def main(): - """Placeholder docstring""" - pkl_path, dest, capture_all = parse_args() - if capture_all: - get_all_requirements(dest) - else: - get_requirements_for_pkl_file(pkl_path, dest) - - -if __name__ == "__main__": - main() diff --git a/src/sagemaker/serve/model_format/mlflow/constants.py b/src/sagemaker/serve/model_format/mlflow/constants.py index d7ddcd9ef0..ff7553ea5f 100644 --- a/src/sagemaker/serve/model_format/mlflow/constants.py +++ b/src/sagemaker/serve/model_format/mlflow/constants.py @@ -18,6 +18,7 @@ "py38": "1.12.1", "py39": "1.13.1", "py310": "2.2.0", + "py311": "2.3.0", } MODEL_PACKAGE_ARN_REGEX = ( r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/(.*?)(?:/(\d+))?$" diff --git a/src/sagemaker/serve/model_server/multi_model_server/inference.py b/src/sagemaker/serve/model_server/multi_model_server/inference.py index 595b9d9c39..9361765da0 100644 --- a/src/sagemaker/serve/model_server/multi_model_server/inference.py +++ b/src/sagemaker/serve/model_server/multi_model_server/inference.py @@ -21,7 +21,7 @@ METADATA_PATH = Path(__file__).parent.joinpath("metadata.json") -def model_fn(model_dir): +def model_fn(model_dir, context=None): """Overrides default method for loading a model""" shared_libs_path = Path(model_dir + "/shared_libs") @@ -40,16 +40,36 @@ def model_fn(model_dir): return partial(inference_spec.invoke, model=inference_spec.load(model_dir)) -def input_fn(input_data, content_type): +def input_fn(input_data, content_type, context=None): """Deserializes the bytes that were received from the model server""" try: if hasattr(schema_builder, "custom_input_translator"): deserialized_data = schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data), content_type + ( + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) + ), + content_type, ) else: deserialized_data = schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data), content_type[0] + ( + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) + ), + content_type[0], ) # Check if preprocess method is defined and call it @@ -62,12 +82,12 @@ def input_fn(input_data, content_type): raise Exception("Encountered error in deserialize_request.") from e -def predict_fn(input_data, predict_callable): +def predict_fn(input_data, predict_callable, context=None): """Invokes the model that is taken in by model server""" return predict_callable(input_data) -def output_fn(predictions, accept_type): +def output_fn(predictions, accept_type, context=None): """Prediction is serialized to bytes and sent back to the customer""" try: if hasattr(inference_spec, "postprocess"): diff --git a/src/sagemaker/serve/model_server/multi_model_server/prepare.py b/src/sagemaker/serve/model_server/multi_model_server/prepare.py index 48cf5c878a..e3abc70dd6 100644 --- a/src/sagemaker/serve/model_server/multi_model_server/prepare.py +++ b/src/sagemaker/serve/model_server/multi_model_server/prepare.py @@ -84,7 +84,8 @@ def prepare_for_mms( image_uri: str, inference_spec: InferenceSpec = None, ) -> str: - """Prepares for InferenceSpec using model_path, writes inference.py, and captures dependencies to generate secret_key. + """Prepares for InferenceSpec using model_path, writes inference.py, \ + and captures dependencies to generate secret_key. Args:to model_path (str) : Argument diff --git a/src/sagemaker/serve/model_server/torchserve/inference.py b/src/sagemaker/serve/model_server/torchserve/inference.py index cad94cc817..058103a1fd 100644 --- a/src/sagemaker/serve/model_server/torchserve/inference.py +++ b/src/sagemaker/serve/model_server/torchserve/inference.py @@ -67,11 +67,31 @@ def input_fn(input_data, content_type): try: if hasattr(schema_builder, "custom_input_translator"): deserialized_data = schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data), content_type + ( + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) + ), + content_type, ) else: deserialized_data = schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data), content_type[0] + ( + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) + ), + content_type[0], ) # Check if preprocess method is defined and call it diff --git a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py index 4e82ec66b2..49cec5aab5 100644 --- a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py +++ b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py @@ -70,11 +70,31 @@ def input_fn(input_data, content_type): try: if hasattr(schema_builder, "custom_input_translator"): return schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data), content_type + ( + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) + ), + content_type, ) else: return schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data), content_type[0] + ( + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) + ), + content_type[0], ) except Exception as e: raise Exception("Encountered error in deserialize_request.") from e diff --git a/src/sagemaker/serve/utils/conda_in_process.yml b/src/sagemaker/serve/utils/conda_in_process.yml index 61badaa52f..1f3fe322ef 100644 --- a/src/sagemaker/serve/utils/conda_in_process.yml +++ b/src/sagemaker/serve/utils/conda_in_process.yml @@ -12,15 +12,15 @@ dependencies: - boto3>=1.34.142,<2.0 - cloudpickle==2.2.1 - google-pasta - - numpy>=1.9.0,<2.0 + - numpy==1.26.4 - protobuf>=3.12,<5.0 - smdebug_rulesconfig==1.0.1 - importlib-metadata>=1.4.0,<7.0 - - packaging>=20.0 + - packaging>=23.0,<25 - pandas - pathos - schema - - PyYAML~=6.0 + - PyYAML>=6.0.1 - jsonschema - platformdirs - tblib>=1.7.0,<4 @@ -43,7 +43,7 @@ dependencies: - colorama>=0.4.4 - contextlib2>=21.6.0 - decorator>=5.1.1 - - dill>=0.3.6 + - dill>=0.3.9 - docutils>=0.16 - entrypoints>=0.4 - filelock>=3.11.0 @@ -82,7 +82,7 @@ dependencies: - python-dateutil>=2.8.2 - pytz>=2023.3 - pytz-deprecation-shim>=0.1.0.post0 - - pyyaml>=5.4.1 + - pyyaml>=6.0.1 - regex>=2023.3.23 - requests>=2.28.2 - rich>=13.3.4 diff --git a/src/sagemaker/serve/utils/in_process_requirements.txt b/src/sagemaker/serve/utils/in_process_requirements.txt index e356e1720d..da1fd8e617 100644 --- a/src/sagemaker/serve/utils/in_process_requirements.txt +++ b/src/sagemaker/serve/utils/in_process_requirements.txt @@ -11,7 +11,7 @@ cloudpickle==2.2.1 colorama>=0.4.4 contextlib2>=21.6.0 decorator>=5.1.1 -dill>=0.3.6 +dill>=0.3.9 docutils>=0.16 entrypoints>=0.4 filelock>=3.11.0 @@ -50,7 +50,7 @@ pyrsistent>=0.19.3 python-dateutil>=2.8.2 pytz>=2023.3 pytz-deprecation-shim>=0.1.0.post0 -pyyaml>=5.4.1 +pyyaml>=6.0.1 regex>=2023.3.23 requests>=2.28.2 rich>=13.3.4 diff --git a/src/sagemaker/serve/utils/telemetry_logger.py b/src/sagemaker/serve/utils/telemetry_logger.py index a1a0408718..c02fe9bf78 100644 --- a/src/sagemaker/serve/utils/telemetry_logger.py +++ b/src/sagemaker/serve/utils/telemetry_logger.py @@ -19,7 +19,7 @@ from sagemaker import Session, exceptions from sagemaker.serve.mode.function_pointers import Mode -from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH +from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH, MLFLOW_TRACKING_ARN from sagemaker.serve.utils.exceptions import ModelBuilderException from sagemaker.serve.utils.lineage_constants import ( MLFLOW_LOCAL_PATH, @@ -144,6 +144,9 @@ def wrapper(self, *args, **kwargs): mlflow_model_path = self.model_metadata[MLFLOW_MODEL_PATH] mlflow_model_path_type = _get_mlflow_model_path_type(mlflow_model_path) extra += f"&x-mlflowModelPathType={MLFLOW_MODEL_PATH_CODE[mlflow_model_path_type]}" + mlflow_model_tracking_server_arn = self.model_metadata.get(MLFLOW_TRACKING_ARN) + if mlflow_model_tracking_server_arn is not None: + extra += f"&x-mlflowTrackingServerArn={mlflow_model_tracking_server_arn}" if getattr(self, "model_hub", False): extra += f"&x-modelHub={MODEL_HUB_TO_CODE[str(self.model_hub)]}" diff --git a/src/sagemaker/serve/utils/tuning.py b/src/sagemaker/serve/utils/tuning.py index b93c01b522..5a63cfe508 100644 --- a/src/sagemaker/serve/utils/tuning.py +++ b/src/sagemaker/serve/utils/tuning.py @@ -7,6 +7,7 @@ import collections from multiprocessing.pool import ThreadPool from math import ceil +from typing import Callable import pandas as pd from numpy import percentile, std from sagemaker.serve.model_server.djl_serving.utils import _tokens_from_chars, _tokens_from_words @@ -152,7 +153,7 @@ def _tokens_per_second(generated_text: str, max_token_length: int, latency: floa return min(est_tokens, max_token_length) / latency -def _timed_invoke(predict: callable, sample_input: object) -> tuple: +def _timed_invoke(predict: Callable, sample_input: object) -> tuple: """Placeholder docstring""" start_timer = perf_counter() response = predict(sample_input) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 04a7326557..797d559348 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4347,11 +4347,59 @@ def submit(request): if model_package_group_name is not None and not model_package_group_name.startswith( "arn:" ): - _create_resource( - lambda: self.sagemaker_client.create_model_package_group( - ModelPackageGroupName=request["ModelPackageGroupName"] + is_model_package_group_present = False + try: + model_package_groups_response = self.search( + resource="ModelPackageGroup", + search_expression={ + "Filters": [ + { + "Name": "ModelPackageGroupName", + "Value": request["ModelPackageGroupName"], + "Operator": "Equals", + } + ], + }, + ) + if len(model_package_groups_response.get("Results")) > 0: + is_model_package_group_present = True + except Exception: # pylint: disable=W0703 + model_package_groups = [] + model_package_groups_response = self.sagemaker_client.list_model_package_groups( + NameContains=request["ModelPackageGroupName"], + ) + model_package_groups = ( + model_package_groups + + model_package_groups_response["ModelPackageGroupSummaryList"] + ) + next_token = model_package_groups_response.get("NextToken") + + while next_token is not None and next_token != "": + model_package_groups_response = ( + self.sagemaker_client.list_model_package_groups( + NameContains=request["ModelPackageGroupName"], NextToken=next_token + ) + ) + model_package_groups = ( + model_package_groups + + model_package_groups_response["ModelPackageGroupSummaryList"] + ) + next_token = model_package_groups_response.get("NextToken") + + filtered_model_package_group = list( + filter( + lambda mpg: mpg.get("ModelPackageGroupName") + == request["ModelPackageGroupName"], + model_package_groups, + ) + ) + is_model_package_group_present = len(filtered_model_package_group) > 0 + if not is_model_package_group_present: + _create_resource( + lambda: self.sagemaker_client.create_model_package_group( + ModelPackageGroupName=request["ModelPackageGroupName"] + ) ) - ) if "SourceUri" in request and request["SourceUri"] is not None: # Remove inference spec from request if the # given source uri can lead to auto-population of it @@ -4415,6 +4463,49 @@ def wait_for_model_package(self, model_package_name, poll=5): ) return desc + def get_most_recently_created_approved_model_package(self, model_package_group_name): + """Returns the most recently created and Approved model package in a model package group + + Args: + model_package_group_name (str): Name or Arn of the model package group + + Returns: + dict: Returns a "sagemaker.model.ModelPackage" value. + """ + + approved_model_packages = self.sagemaker_client.list_model_packages( + ModelPackageGroupName=model_package_group_name, + ModelApprovalStatus="Approved", + SortBy="CreationTime", + SortOrder="Descending", + MaxResults=1, + ) + next_token = approved_model_packages.get("NextToken") + + while ( + len(approved_model_packages.get("ModelPackageSummaryList")) == 0 + and next_token is not None + and next_token != "" + ): + approved_model_packages = self.sagemaker_client.list_model_packages( + ModelPackageGroupName=model_package_group_name, + ModelApprovalStatus="Approved", + SortBy="CreationTime", + SortOrder="Descending", + MaxResults=1, + NextToken=next_token, + ) + next_token = approved_model_packages.get("NextToken") + + if len(approved_model_packages.get("ModelPackageSummaryList")) == 0: + return None + + return sagemaker.model.ModelPackage( + model_package_arn=approved_model_packages.get("ModelPackageSummaryList")[0].get( + "ModelPackageArn" + ) + ) + def describe_model(self, name): """Calls the DescribeModel API for the given model name. @@ -4440,6 +4531,10 @@ def create_endpoint_config( model_data_download_timeout=None, container_startup_health_check_timeout=None, explainer_config_dict=None, + async_inference_config_dict=None, + serverless_inference_config_dict=None, + routing_config: Optional[Dict[str, Any]] = None, + inference_ami_version: Optional[str] = None, ): """Create an Amazon SageMaker endpoint configuration. @@ -4477,6 +4572,30 @@ def create_endpoint_config( -inference-algo-ping-requests explainer_config_dict (dict): Specifies configuration to enable explainers. Default: None. + async_inference_config_dict (dict): Specifies + configuration related to async endpoint. Use this configuration when trying + to create async endpoint and make async inference. If empty config object + passed through, will use default config to deploy async endpoint. Deploy a + real-time endpoint if it's None. (default: None). + serverless_inference_config_dict (dict): + Specifies configuration related to serverless endpoint. Use this configuration + when trying to create serverless endpoint and make serverless inference. If + empty object passed through, will use pre-defined values in + ``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an + instance based endpoint if it's None. (default: None). + routing_config (Optional[Dict[str, Any]): Settings the control how the endpoint routes + incoming traffic to the instances that the endpoint hosts. + Currently, support dictionary key ``RoutingStrategy``. + + .. code:: python + + { + "RoutingStrategy": sagemaker.enums.RoutingStrategy.RANDOM + } + inference_ami_version (Optional [str]): + Specifies an option from a collection of preconfigured + Amazon Machine Image (AMI) images. For a full list of options, see: + https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ProductionVariant.html Example: >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}] @@ -4496,9 +4615,12 @@ def create_endpoint_config( instance_type, initial_instance_count, accelerator_type=accelerator_type, + serverless_inference_config=serverless_inference_config_dict, volume_size=volume_size, model_data_download_timeout=model_data_download_timeout, container_startup_health_check_timeout=container_startup_health_check_timeout, + routing_config=routing_config, + inference_ami_version=inference_ami_version, ) production_variants = [provided_production_variant] # Currently we just inject CoreDumpConfig.KmsKeyId from the config for production variant. @@ -4538,6 +4660,14 @@ def create_endpoint_config( ) request["DataCaptureConfig"] = inferred_data_capture_config_dict + if async_inference_config_dict is not None: + inferred_async_inference_config_dict = update_nested_dictionary_with_values_from_config( + async_inference_config_dict, + ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH, + sagemaker_session=self, + ) + request["AsyncInferenceConfig"] = inferred_async_inference_config_dict + if explainer_config_dict is not None: request["ExplainerConfig"] = explainer_config_dict @@ -5286,7 +5416,7 @@ def get_tagging_resources(self, tag_filters, resource_type_filters): resource_tag_response = self.resource_group_tagging_client.get_resources( TagFilters=tag_filters, ResourceTypeFilters=resource_type_filters, - NextToken=next_token, + PaginationToken=next_token, ) resource_list = resource_list + resource_tag_response["ResourceTagMappingList"] next_token = resource_tag_response.get("PaginationToken") diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py index ae66bc8338..586e50da88 100644 --- a/src/sagemaker/sklearn/estimator.py +++ b/src/sagemaker/sklearn/estimator.py @@ -83,8 +83,8 @@ def __init__( source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + are preserved when training on Amazon SageMaker. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index c3727b2fb5..a9b0e2e8f0 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Union, Optional, List, Dict +from typing import Callable, Union, Optional, List, Dict import sagemaker from sagemaker import image_uris, ModelMetrics @@ -92,7 +92,7 @@ def __init__( framework_version: Optional[str] = None, py_version: str = "py3", image_uri: Optional[Union[str, PipelineVariable]] = None, - predictor_cls: callable = SKLearnPredictor, + predictor_cls: Optional[Callable] = SKLearnPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs, ): @@ -122,7 +122,7 @@ def __init__( If ``framework_version`` or ``py_version`` are ``None``, then ``image_uri`` is required. If ``image_uri`` is also ``None``, then a ``ValueError`` will be raised. - predictor_cls (callable[str, sagemaker.session.Session]): A function + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. diff --git a/src/sagemaker/telemetry/constants.py b/src/sagemaker/telemetry/constants.py index 2108ff9fd6..cb83a78279 100644 --- a/src/sagemaker/telemetry/constants.py +++ b/src/sagemaker/telemetry/constants.py @@ -42,3 +42,40 @@ class Status(Enum): def __str__(self): # pylint: disable=E0307 """Return the status name.""" return self.name + + +class Region(str, Enum): + """Telemetry: List of all supported AWS regions.""" + + # Classic + US_EAST_1 = "us-east-1" # IAD + US_EAST_2 = "us-east-2" # CMH + US_WEST_1 = "us-west-1" # SFO + US_WEST_2 = "us-west-2" # PDX + AP_NORTHEAST_1 = "ap-northeast-1" # NRT + AP_NORTHEAST_2 = "ap-northeast-2" # ICN + AP_NORTHEAST_3 = "ap-northeast-3" # KIX + AP_SOUTH_1 = "ap-south-1" # BOM + AP_SOUTHEAST_1 = "ap-southeast-1" # SIN + AP_SOUTHEAST_2 = "ap-southeast-2" # SYD + CA_CENTRAL_1 = "ca-central-1" # YUL + EU_CENTRAL_1 = "eu-central-1" # FRA + EU_NORTH_1 = "eu-north-1" # ARN + EU_WEST_1 = "eu-west-1" # DUB + EU_WEST_2 = "eu-west-2" # LHR + EU_WEST_3 = "eu-west-3" # CDG + SA_EAST_1 = "sa-east-1" # GRU + # Opt-in + AP_EAST_1 = "ap-east-1" # HKG + AP_SOUTHEAST_3 = "ap-southeast-3" # CGK + AF_SOUTH_1 = "af-south-1" # CPT + EU_SOUTH_1 = "eu-south-1" # MXP + ME_SOUTH_1 = "me-south-1" # BAH + MX_CENTRAL_1 = "mx-central-1" # QRO + AP_SOUTHEAST_7 = "ap-southeast-7" # BKK + AP_SOUTH_2 = "ap-south-2" # HYD + AP_SOUTHEAST_4 = "ap-southeast-4" # MEL + EU_CENTRAL_2 = "eu-central-2" # ZRH + EU_SOUTH_2 = "eu-south-2" # ZAZ + IL_CENTRAL_1 = "il-central-1" # TLV + ME_CENTRAL_1 = "me-central-1" # DXB diff --git a/src/sagemaker/telemetry/telemetry_logging.py b/src/sagemaker/telemetry/telemetry_logging.py index b45550b2c2..b0ecedee4c 100644 --- a/src/sagemaker/telemetry/telemetry_logging.py +++ b/src/sagemaker/telemetry/telemetry_logging.py @@ -27,6 +27,7 @@ from sagemaker.telemetry.constants import ( Feature, Status, + Region, DEFAULT_AWS_REGION, ) from sagemaker.user_agent import SDK_VERSION, process_studio_metadata_file @@ -189,8 +190,16 @@ def _send_telemetry_request( """Make GET request to an empty object in S3 bucket""" try: accountId = _get_accountId(session) if session else "NotAvailable" - # telemetry will be sent to us-west-2 if no session availale - region = _get_region_or_default(session) if session else DEFAULT_AWS_REGION + region = _get_region_or_default(session) + + try: + Region(region) # Validate the region + except ValueError: + logger.warning( + "Region not found in supported regions. Telemetry request will not be emitted." + ) + return + url = _construct_url( accountId, region, @@ -268,6 +277,7 @@ def _get_region_or_default(session): def _get_default_sagemaker_session(): """Return the default sagemaker session""" + boto_session = boto3.Session(region_name=DEFAULT_AWS_REGION) sagemaker_session = Session(boto_session=boto_session) diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index fe20994e20..b384cbbbb5 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Union, Optional, List, Dict +from typing import Callable, Union, Optional, List, Dict import sagemaker from sagemaker import image_uris, s3, ModelMetrics @@ -62,9 +62,9 @@ def __init__( manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. - serializer (callable): Optional. Default serializes input data to + serializer (Callable): Optional. Default serializes input data to json. Handles dicts, lists, and numpy arrays. - deserializer (callable): Optional. Default parses the response using + deserializer (Callable): Optional. Default parses the response using ``json.load(...)``. model_name (str): Optional. The name of the SavedModel model that should handle the request. If not specified, the endpoint's @@ -146,7 +146,7 @@ def __init__( image_uri: Optional[Union[str, PipelineVariable]] = None, framework_version: Optional[str] = None, container_log_level: Optional[int] = None, - predictor_cls: callable = TensorFlowPredictor, + predictor_cls: Optional[Callable] = TensorFlowPredictor, **kwargs, ): """Initialize a Model. @@ -174,7 +174,7 @@ def __init__( container_log_level (int): Log level to use within the container (default: logging.ERROR). Valid values are defined in the Python logging module. - predictor_cls (callable[str, sagemaker.session.Session]): A function + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. @@ -358,6 +358,7 @@ def deploy( container_startup_health_check_timeout=None, inference_recommendation_id=None, explainer_config=None, + update_endpoint: Optional[bool] = False, **kwargs, ): """Deploy a Tensorflow ``Model`` to a SageMaker ``Endpoint``.""" @@ -383,6 +384,7 @@ def deploy( container_startup_health_check_timeout=container_startup_health_check_timeout, inference_recommendation_id=inference_recommendation_id, explainer_config=explainer_config, + update_endpoint=update_endpoint, **kwargs, ) diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 4b0f38f36f..fa8f9b8555 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -18,21 +18,20 @@ import inspect import json import logging - from enum import Enum -from typing import Union, Dict, Optional, List, Set +from typing import Dict, List, Optional, Set, Union import sagemaker from sagemaker.amazon.amazon_estimator import ( - RecordSet, AmazonAlgorithmEstimatorBase, FileSystemRecordSet, + RecordSet, ) from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.analytics import HyperparameterTuningJobAnalytics from sagemaker.deprecations import removed_function -from sagemaker.estimator import Framework, EstimatorBase -from sagemaker.inputs import TrainingInput, FileSystemInput +from sagemaker.estimator import EstimatorBase, Framework +from sagemaker.inputs import FileSystemInput, TrainingInput from sagemaker.job import _Job from sagemaker.jumpstart.utils import ( add_jumpstart_uri_tags, @@ -44,18 +43,17 @@ IntegerParameter, ParameterRange, ) -from sagemaker.workflow.entities import PipelineVariable -from sagemaker.workflow.pipeline_context import runnable_by_pipeline - from sagemaker.session import Session from sagemaker.utils import ( + Tags, base_from_name, base_name_from_image, + format_tags, name_from_base, to_string, - format_tags, - Tags, ) +from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow.pipeline_context import runnable_by_pipeline AMAZON_ESTIMATOR_MODULE = "sagemaker" AMAZON_ESTIMATOR_CLS_NAMES = { @@ -133,15 +131,12 @@ def __init__( if warm_start_type not in list(WarmStartTypes): raise ValueError( - "Invalid type: {}, valid warm start types are: {}".format( - warm_start_type, list(WarmStartTypes) - ) + f"Invalid type: {warm_start_type}, " + f"valid warm start types are: {list(WarmStartTypes)}" ) if not parents: - raise ValueError( - "Invalid parents: {}, parents should not be None/empty".format(parents) - ) + raise ValueError(f"Invalid parents: {parents}, parents should not be None/empty") self.type = warm_start_type self.parents = set(parents) @@ -1455,9 +1450,7 @@ def _get_best_training_job(self): return tuning_job_describe_result["BestTrainingJob"] except KeyError: raise Exception( - "Best training job not available for tuning job: {}".format( - self.latest_tuning_job.name - ) + f"Best training job not available for tuning job: {self.latest_tuning_job.name}" ) def _ensure_last_tuning_job(self): @@ -1920,8 +1913,11 @@ def create( :meth:`~sagemaker.tuner.HyperparameterTuner.fit` method launches. If not specified, a default job name is generated, based on the training image name and current timestamp. - strategy (str): Strategy to be used for hyperparameter estimations - (default: 'Bayesian'). + strategy (str or PipelineVariable): Strategy to be used for hyperparameter estimations. + More information about different strategies: + https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html. + Available options are: 'Bayesian', 'Random', 'Hyperband', + 'Grid' (default: 'Bayesian') strategy_config (dict): The configuration for a training job launched by a hyperparameter tuning job. completion_criteria_config (dict): The configuration for tuning job completion criteria. @@ -2080,21 +2076,19 @@ def _validate_dict_argument(cls, name, value, allowed_keys, require_same_keys=Fa return if not isinstance(value, dict): - raise ValueError( - "Argument '{}' must be a dictionary using {} as keys".format(name, allowed_keys) - ) + raise ValueError(f"Argument '{name}' must be a dictionary using {allowed_keys} as keys") value_keys = sorted(value.keys()) if require_same_keys: if value_keys != allowed_keys: raise ValueError( - "The keys of argument '{}' must be the same as {}".format(name, allowed_keys) + f"The keys of argument '{name}' must be the same as {allowed_keys}" ) else: if not set(value_keys).issubset(set(allowed_keys)): raise ValueError( - "The keys of argument '{}' must be a subset of {}".format(name, allowed_keys) + f"The keys of argument '{name}' must be a subset of {allowed_keys}" ) def _add_estimator( diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index e8602de8d7..1a75a3a5cc 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -397,8 +397,7 @@ def download_folder(bucket_name, prefix, target, sagemaker_session): sagemaker_session (sagemaker.session.Session): a sagemaker session to interact with S3. """ - boto_session = sagemaker_session.boto_session - s3 = boto_session.resource("s3", region_name=boto_session.region_name) + s3 = sagemaker_session.s3_resource prefix = prefix.lstrip("/") @@ -726,7 +725,7 @@ def retry_with_backoff(callable_func, num_attempts=8, botocore_client_error_code """Retry with backoff until maximum attempts are reached Args: - callable_func (callable): The callable function to retry. + callable_func (Callable): The callable function to retry. num_attempts (int): The maximum number of attempts to retry.(Default: 8) botocore_client_error_code (str): The specific Botocore ClientError exception error code on which to retry on. diff --git a/src/sagemaker/workflow/notebook_job_step.py b/src/sagemaker/workflow/notebook_job_step.py index 8a1dd6bc53..ca0ecac15b 100644 --- a/src/sagemaker/workflow/notebook_job_step.py +++ b/src/sagemaker/workflow/notebook_job_step.py @@ -13,49 +13,33 @@ """The notebook job step definitions for workflow.""" from __future__ import absolute_import +import os import re import shutil -import os +from typing import Dict, List, Optional, Union -from typing import ( - List, - Optional, - Union, - Dict, +from sagemaker import vpc_utils +from sagemaker.config.config_schema import ( + NOTEBOOK_JOB_ROLE_ARN, + NOTEBOOK_JOB_S3_KMS_KEY_ID, + NOTEBOOK_JOB_S3_ROOT_URI, + NOTEBOOK_JOB_VOLUME_KMS_KEY_ID, + NOTEBOOK_JOB_VPC_CONFIG_SECURITY_GROUP_IDS, + NOTEBOOK_JOB_VPC_CONFIG_SUBNETS, ) - +from sagemaker.s3 import S3Uploader +from sagemaker.s3_utils import s3_path_join +from sagemaker.session import get_execution_role +from sagemaker.utils import Tags, _tmpdir, format_tags, name_from_base, resolve_value_from_config +from sagemaker.workflow.entities import PipelineVariable, RequestType from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.functions import Join from sagemaker.workflow.properties import Properties from sagemaker.workflow.retry import RetryPolicy -from sagemaker.workflow.steps import ( - Step, - ConfigurableRetryStep, - StepTypeEnum, -) from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.step_outputs import StepOutput - -from sagemaker.workflow.entities import ( - RequestType, - PipelineVariable, -) +from sagemaker.workflow.steps import ConfigurableRetryStep, Step, StepTypeEnum from sagemaker.workflow.utilities import _collect_parameters, load_step_compilation_context -from sagemaker.session import get_execution_role - -from sagemaker.s3_utils import s3_path_join -from sagemaker.s3 import S3Uploader -from sagemaker.utils import _tmpdir, name_from_base, resolve_value_from_config, format_tags, Tags -from sagemaker import vpc_utils - -from sagemaker.config.config_schema import ( - NOTEBOOK_JOB_ROLE_ARN, - NOTEBOOK_JOB_S3_ROOT_URI, - NOTEBOOK_JOB_S3_KMS_KEY_ID, - NOTEBOOK_JOB_VOLUME_KMS_KEY_ID, - NOTEBOOK_JOB_VPC_CONFIG_SUBNETS, - NOTEBOOK_JOB_VPC_CONFIG_SECURITY_GROUP_IDS, -) # disable E1101 as collect_parameters decorator sets the attributes @@ -374,7 +358,7 @@ def _prepare_env_variables(self): execution mechanism. """ - job_envs = self.environment_variables if self.environment_variables else {} + job_envs = dict(self.environment_variables or {}) system_envs = { "AWS_DEFAULT_REGION": self._region_from_session, "SM_JOB_DEF_VERSION": "1.0", diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index 62167b96e7..9749014531 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -383,7 +383,11 @@ def start( ) def definition(self) -> str: - """Converts a request structure to string representation for workflow service calls.""" + """Converts a request structure to string representation for workflow service calls. + + Returns: + A JSON formatted string of pipeline definition. + """ compiled_steps = StepsCompiler( pipeline_name=self.name, sagemaker_session=self.sagemaker_session, diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index a80b5440c7..dbc37371db 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -18,7 +18,6 @@ from enum import Enum from typing import Dict, List, Set, Union, Optional, Any, TYPE_CHECKING -from urllib.parse import urlparse import attr @@ -465,6 +464,7 @@ def __init__( self.step_args = step_args self.estimator = estimator self.inputs = inputs + self.job_name = None self._properties = Properties( step_name=name, step=self, shape_name="DescribeTrainingJobResponse" @@ -493,19 +493,6 @@ def __init__( DeprecationWarning, ) - self.job_name = None - if estimator and (estimator.source_dir or estimator.entry_point): - # By default, `Estimator` will upload the local code to an S3 path - # containing a timestamp. This causes cache misses whenever a - # pipeline is updated, even if the underlying script hasn't changed. - # To avoid this, hash the contents of the training script and include it - # in the `job_name` passed to the `Estimator`, which will be used - # instead of the timestamped path. - if not is_pipeline_variable(estimator.source_dir) and not is_pipeline_variable( - estimator.entry_point - ): - self.job_name = self._generate_code_upload_path() - @property def arguments(self) -> RequestType: """The arguments dictionary that is used to call `create_training_job`. @@ -554,26 +541,6 @@ def to_request(self) -> RequestType: return request_dict - def _generate_code_upload_path(self) -> str or None: - """Generate an upload path for local training scripts based on their content.""" - from sagemaker.workflow.utilities import hash_files_or_dirs - - if self.estimator.source_dir: - source_dir_url = urlparse(self.estimator.source_dir) - if source_dir_url.scheme == "" or source_dir_url.scheme == "file": - code_hash = hash_files_or_dirs( - [self.estimator.source_dir] + self.estimator.dependencies - ) - return f"{self.name}-{code_hash}"[:1024] - elif self.estimator.entry_point: - entry_point_url = urlparse(self.estimator.entry_point) - if entry_point_url.scheme == "" or entry_point_url.scheme == "file": - code_hash = hash_files_or_dirs( - [self.estimator.entry_point] + self.estimator.dependencies - ) - return f"{self.name}-{code_hash}"[:1024] - return None - class CreateModelStep(ConfigurableRetryStep): """`CreateModelStep` for SageMaker Pipelines Workflows.""" @@ -645,6 +612,7 @@ def arguments(self) -> RequestType: request_dict = self.step_args else: if isinstance(self.model, PipelineModel): + self.model._init_sagemaker_session_if_does_not_exist() request_dict = self.model.sagemaker_session._create_model_request( name="", role=self.model.role, @@ -653,6 +621,7 @@ def arguments(self) -> RequestType: enable_network_isolation=self.model.enable_network_isolation, ) else: + self.model._init_sagemaker_session_if_does_not_exist() request_dict = self.model.sagemaker_session._create_model_request( name="", role=self.model.role, @@ -893,16 +862,6 @@ def __init__( "code argument has to be a valid S3 URI or local file path " + "rather than a pipeline variable" ) - code_url = urlparse(code) - if code_url.scheme == "" or code_url.scheme == "file": - # By default, `Processor` will upload the local code to an S3 path - # containing a timestamp. This causes cache misses whenever a - # pipeline is updated, even if the underlying script hasn't changed. - # To avoid this, hash the contents of the script and include it - # in the `job_name` passed to the `Processor`, which will be used - # instead of the timestamped path. - self.job_name = self._generate_code_upload_path() - warnings.warn( ( 'We are deprecating the instantiation of ProcessingStep using "processor".' diff --git a/src/sagemaker/workflow/utilities.py b/src/sagemaker/workflow/utilities.py index 4ef5ad5dd2..4fc98eb29a 100644 --- a/src/sagemaker/workflow/utilities.py +++ b/src/sagemaker/workflow/utilities.py @@ -268,29 +268,29 @@ def get_config_hash(step: Entity): def hash_object(obj) -> str: - """Get the MD5 hash of an object. + """Get the SHA256 hash of an object. Args: obj (dict): The object Returns: - str: The MD5 hash of the object + str: The SHA256 hash of the object """ - return hashlib.md5(str(obj).encode()).hexdigest() + return hashlib.sha256(str(obj).encode()).hexdigest() def hash_file(path: str) -> str: - """Get the MD5 hash of a file. + """Get the SHA256 hash of a file. Args: path (str): The local path for the file. Returns: - str: The MD5 hash of the file. + str: The SHA256 hash of the file. """ - return _hash_file(path, hashlib.md5()).hexdigest() + return _hash_file(path, hashlib.sha256()).hexdigest() def hash_files_or_dirs(paths: List[str]) -> str: - """Get the MD5 hash of the contents of a list of files or directories. + """Get the SHA256 hash of the contents of a list of files or directories. Hash is changed if: * input list is changed @@ -301,58 +301,58 @@ def hash_files_or_dirs(paths: List[str]) -> str: Args: paths: List of file or directory paths Returns: - str: The MD5 hash of the list of files or directories. + str: The SHA256 hash of the list of files or directories. """ - md5 = hashlib.md5() + sha256 = hashlib.sha256() for path in sorted(paths): - md5 = _hash_file_or_dir(path, md5) - return md5.hexdigest() + sha256 = _hash_file_or_dir(path, sha256) + return sha256.hexdigest() -def _hash_file_or_dir(path: str, md5: Hash) -> Hash: +def _hash_file_or_dir(path: str, sha256: Hash) -> Hash: """Updates the inputted Hash with the contents of the current path. Args: path: path of file or directory Returns: - str: The MD5 hash of the file or directory + str: The SHA256 hash of the file or directory """ if isinstance(path, str) and path.lower().startswith("file://"): path = unquote(urlparse(path).path) - md5.update(path.encode()) + sha256.update(path.encode()) if Path(path).is_dir(): - md5 = _hash_dir(path, md5) + sha256 = _hash_dir(path, sha256) elif Path(path).is_file(): - md5 = _hash_file(path, md5) - return md5 + sha256 = _hash_file(path, sha256) + return sha256 -def _hash_dir(directory: Union[str, Path], md5: Hash) -> Hash: +def _hash_dir(directory: Union[str, Path], sha256: Hash) -> Hash: """Updates the inputted Hash with the contents of the current path. Args: directory: path of the directory Returns: - str: The MD5 hash of the directory + str: The SHA256 hash of the directory """ if not Path(directory).is_dir(): raise ValueError(str(directory) + " is not a valid directory") for path in sorted(Path(directory).iterdir()): - md5.update(path.name.encode()) + sha256.update(path.name.encode()) if path.is_file(): - md5 = _hash_file(path, md5) + sha256 = _hash_file(path, sha256) elif path.is_dir(): - md5 = _hash_dir(path, md5) - return md5 + sha256 = _hash_dir(path, sha256) + return sha256 -def _hash_file(file: Union[str, Path], md5: Hash) -> Hash: +def _hash_file(file: Union[str, Path], sha256: Hash) -> Hash: """Updates the inputted Hash with the contents of the current path. Args: file: path of the file Returns: - str: The MD5 hash of the file + str: The SHA256 hash of the file """ if isinstance(file, str) and file.lower().startswith("file://"): file = unquote(urlparse(file).path) @@ -363,8 +363,8 @@ def _hash_file(file: Union[str, Path], md5: Hash) -> Hash: data = f.read(BUF_SIZE) if not data: break - md5.update(data) - return md5 + sha256.update(data) + return sha256 def validate_step_args_input( diff --git a/src/sagemaker/xgboost/estimator.py b/src/sagemaker/xgboost/estimator.py index 2921dbc2db..9385acf745 100644 --- a/src/sagemaker/xgboost/estimator.py +++ b/src/sagemaker/xgboost/estimator.py @@ -78,8 +78,8 @@ def __init__( source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + are preserved when training on Amazon SageMaker. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index ea532b4c39..f4797c79e7 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Optional, Union, List, Dict +from typing import Callable, Optional, Union, List, Dict import sagemaker from sagemaker import image_uris, ModelMetrics @@ -91,7 +91,7 @@ def __init__( framework_version: str = None, image_uri: Optional[Union[str, PipelineVariable]] = None, py_version: str = "py3", - predictor_cls: callable = XGBoostPredictor, + predictor_cls: Optional[Callable] = XGBoostPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs, ): @@ -113,8 +113,8 @@ def __init__( (default: 'py3'). framework_version (str): XGBoost version you want to use for executing your model training code. - predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create - a predictor with an endpoint name and SageMaker ``Session``. + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call + to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. model_server_workers (int or PipelineVariable): Optional. The number of worker processes diff --git a/tests/conftest.py b/tests/conftest.py index db890d1a14..7557c87fbe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -254,6 +254,8 @@ def mxnet_eia_latest_py_version(): @pytest.fixture(scope="module", params=["py2", "py3"]) def pytorch_training_py_version(pytorch_training_version, request): + if Version(pytorch_training_version) >= Version("2.6"): + return "py312" if Version(pytorch_training_version) >= Version("2.3"): return "py311" elif Version(pytorch_training_version) >= Version("2.0"): @@ -270,7 +272,9 @@ def pytorch_training_py_version(pytorch_training_version, request): @pytest.fixture(scope="module", params=["py2", "py3"]) def pytorch_inference_py_version(pytorch_inference_version, request): - if Version(pytorch_inference_version) >= Version("2.3"): + if Version(pytorch_inference_version) >= Version("2.6"): + return "py312" + elif Version(pytorch_inference_version) >= Version("2.3"): return "py311" elif Version(pytorch_inference_version) >= Version("2.0"): return "py310" @@ -293,6 +297,8 @@ def huggingface_pytorch_training_version(huggingface_training_version): @pytest.fixture(scope="module") def huggingface_pytorch_training_py_version(huggingface_pytorch_training_version): + if Version(huggingface_pytorch_training_version) >= Version("2.3"): + return "py311" if Version(huggingface_pytorch_training_version) >= Version("2.0"): return "py310" elif Version(huggingface_pytorch_training_version) >= Version("1.13"): @@ -355,6 +361,8 @@ def huggingface_training_compiler_pytorch_py_version( def huggingface_pytorch_latest_training_py_version( huggingface_training_pytorch_latest_version, ): + if Version(huggingface_training_pytorch_latest_version) >= Version("2.3"): + return "py311" if Version(huggingface_training_pytorch_latest_version) >= Version("2.0"): return "py310" elif Version(huggingface_training_pytorch_latest_version) >= Version("1.13"): diff --git a/tests/data/modules/custom_drivers/driver.py b/tests/data/modules/custom_drivers/driver.py new file mode 100644 index 0000000000..3395b80da9 --- /dev/null +++ b/tests/data/modules/custom_drivers/driver.py @@ -0,0 +1,34 @@ +import json +import os +import subprocess +import sys + + +def main(): + driver_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) + process_count_per_node = driver_config["process_count_per_node"] + assert process_count_per_node != None + + hps = json.loads(os.environ["SM_HPS"]) + assert hps != None + assert isinstance(hps, dict) + + source_dir = os.environ["SM_SOURCE_DIR"] + assert source_dir == "/opt/ml/input/data/code" + sm_drivers_dir = os.environ["SM_DISTRIBUTED_DRIVER_DIR"] + assert sm_drivers_dir == "/opt/ml/input/data/sm_drivers/distributed_drivers" + + entry_script = os.environ["SM_ENTRY_SCRIPT"] + assert entry_script != None + + python = sys.executable + + command = [python, entry_script] + print(f"Running command: {command}") + subprocess.run(command, check=True) + + +if __name__ == "__main__": + print("Running custom driver script") + main() + print("Finished running custom driver script") diff --git a/tests/data/modules/params_script/hyperparameters.json b/tests/data/modules/params_script/hyperparameters.json new file mode 100644 index 0000000000..f637288dbe --- /dev/null +++ b/tests/data/modules/params_script/hyperparameters.json @@ -0,0 +1,15 @@ +{ + "integer": 1, + "boolean": true, + "float": 3.14, + "string": "Hello World", + "list": [1, 2, 3], + "dict": { + "string": "value", + "integer": 3, + "float": 3.14, + "list": [1, 2, 3], + "dict": {"key": "value"}, + "boolean": true + } +} \ No newline at end of file diff --git a/tests/data/modules/params_script/hyperparameters.yaml b/tests/data/modules/params_script/hyperparameters.yaml new file mode 100644 index 0000000000..9e3011daf2 --- /dev/null +++ b/tests/data/modules/params_script/hyperparameters.yaml @@ -0,0 +1,19 @@ +integer: 1 +boolean: true +float: 3.14 +string: "Hello World" +list: + - 1 + - 2 + - 3 +dict: + string: value + integer: 3 + float: 3.14 + list: + - 1 + - 2 + - 3 + dict: + key: value + boolean: true \ No newline at end of file diff --git a/tests/data/modules/params_script/requirements.txt b/tests/data/modules/params_script/requirements.txt new file mode 100644 index 0000000000..3d2e72e354 --- /dev/null +++ b/tests/data/modules/params_script/requirements.txt @@ -0,0 +1 @@ +omegaconf diff --git a/tests/data/modules/params_script/train.py b/tests/data/modules/params_script/train.py index 8d3924a325..9b8cb2c82f 100644 --- a/tests/data/modules/params_script/train.py +++ b/tests/data/modules/params_script/train.py @@ -16,6 +16,9 @@ import argparse import json import os +from typing import List, Dict, Any +from dataclasses import dataclass +from omegaconf import OmegaConf EXPECTED_HYPERPARAMETERS = { "integer": 1, @@ -26,6 +29,7 @@ "dict": { "string": "value", "integer": 3, + "float": 3.14, "list": [1, 2, 3], "dict": {"key": "value"}, "boolean": True, @@ -117,7 +121,7 @@ def main(): assert isinstance(params["dict"], dict) params = json.loads(os.environ["SM_TRAINING_ENV"])["hyperparameters"] - print(params) + print(f"SM_TRAINING_ENV -> hyperparameters: {params}") assert params["string"] == EXPECTED_HYPERPARAMETERS["string"] assert params["integer"] == EXPECTED_HYPERPARAMETERS["integer"] assert params["boolean"] == EXPECTED_HYPERPARAMETERS["boolean"] @@ -132,9 +136,96 @@ def main(): assert isinstance(params["float"], float) assert isinstance(params["list"], list) assert isinstance(params["dict"], dict) - print(f"SM_TRAINING_ENV -> hyperparameters: {params}") - print("Test passed.") + # Local JSON - DictConfig OmegaConf + params = OmegaConf.load("hyperparameters.json") + + print(f"Local hyperparameters.json: {params}") + assert params.string == EXPECTED_HYPERPARAMETERS["string"] + assert params.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert params.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert params.float == EXPECTED_HYPERPARAMETERS["float"] + assert params.list == EXPECTED_HYPERPARAMETERS["list"] + assert params.dict == EXPECTED_HYPERPARAMETERS["dict"] + assert params.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"] + assert params.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"] + assert params.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"] + assert params.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"] + assert params.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"] + assert params.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"] + + @dataclass + class DictConfig: + string: str + integer: int + boolean: bool + float: float + list: List[int] + dict: Dict[str, Any] + + @dataclass + class HPConfig: + string: str + integer: int + boolean: bool + float: float + list: List[int] + dict: DictConfig + + # Local JSON - Structured OmegaConf + hp_config: HPConfig = OmegaConf.merge( + OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.json") + ) + print(f"Local hyperparameters.json - Structured: {hp_config}") + assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"] + assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"] + assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"] + assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"] + assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"] + assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"] + assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"] + assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"] + assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"] + assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"] + + # Local YAML - Structured OmegaConf + hp_config: HPConfig = OmegaConf.merge( + OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.yaml") + ) + print(f"Local hyperparameters.yaml - Structured: {hp_config}") + assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"] + assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"] + assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"] + assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"] + assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"] + assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"] + assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"] + assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"] + assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"] + assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"] + print(f"hyperparameters.yaml -> hyperparameters: {hp_config}") + + # HP Dict - Structured OmegaConf + hp_dict = json.loads(os.environ["SM_HPS"]) + hp_config: HPConfig = OmegaConf.merge(OmegaConf.structured(HPConfig), OmegaConf.create(hp_dict)) + print(f"SM_HPS - Structured: {hp_config}") + assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"] + assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"] + assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"] + assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"] + assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"] + assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"] + assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"] + assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"] + assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"] + assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"] + print(f"SM_HPS -> hyperparameters: {hp_config}") if __name__ == "__main__": diff --git a/tests/data/modules/scripts/entry_script.py b/tests/data/modules/scripts/entry_script.py new file mode 100644 index 0000000000..3c972bd956 --- /dev/null +++ b/tests/data/modules/scripts/entry_script.py @@ -0,0 +1,19 @@ +import json +import os +import time + + +def main(): + hps = json.loads(os.environ["SM_HPS"]) + assert hps != None + print(f"Hyperparameters: {hps}") + + print("Running pseudo training script") + for epochs in range(hps["epochs"]): + print(f"Epoch: {epochs}") + time.sleep(1) + print("Finished running pseudo training script") + + +if __name__ == "__main__": + main() diff --git a/tests/data/pipeline/model_step/pytorch_mnist/requirements.txt b/tests/data/pipeline/model_step/pytorch_mnist/requirements.txt index 56d09228be..c25fca7e9f 100644 --- a/tests/data/pipeline/model_step/pytorch_mnist/requirements.txt +++ b/tests/data/pipeline/model_step/pytorch_mnist/requirements.txt @@ -1 +1 @@ -scipy>=1.8.1 +scipy>=1.11.3 diff --git a/tests/data/remote_function/requirements.txt b/tests/data/remote_function/requirements.txt index 0e99587e6e..44ce1d9331 100644 --- a/tests/data/remote_function/requirements.txt +++ b/tests/data/remote_function/requirements.txt @@ -1 +1 @@ -scipy==1.10.1 +scipy==1.11.3 diff --git a/tests/data/serve_resources/mlflow/pytorch/conda.yaml b/tests/data/serve_resources/mlflow/pytorch/conda.yaml index be61456197..b740d25b70 100644 --- a/tests/data/serve_resources/mlflow/pytorch/conda.yaml +++ b/tests/data/serve_resources/mlflow/pytorch/conda.yaml @@ -9,7 +9,7 @@ dependencies: - cffi==1.16.0 - cloudpickle==2.2.1 - defusedxml==0.7.1 - - dill==0.3.8 + - dill==0.3.9 - gmpy2==2.1.2 - numpy==1.26.4 - opt-einsum==3.3.0 @@ -17,8 +17,8 @@ dependencies: - pandas==2.2.1 - pyyaml==6.0.1 - requests==2.31.0 - - torch==2.0.1 - - torchvision==0.15.2 + - torch>=2.6.0 + - torchvision>=0.17.0 - tqdm==4.66.2 - scikit-learn==1.3.2 name: mlflow-env diff --git a/tests/data/serve_resources/mlflow/pytorch/requirements.txt b/tests/data/serve_resources/mlflow/pytorch/requirements.txt index 0446ed5053..aacc85cb91 100644 --- a/tests/data/serve_resources/mlflow/pytorch/requirements.txt +++ b/tests/data/serve_resources/mlflow/pytorch/requirements.txt @@ -3,14 +3,14 @@ astunparse==1.6.3 cffi==1.16.0 cloudpickle==2.2.1 defusedxml==0.7.1 -dill==0.3.8 +dill==0.3.9 gmpy2==2.1.2 -numpy==1.24.4 +numpy==1.26.4 opt-einsum==3.3.0 -packaging==21.3 +packaging>=23.0,<25 pandas==2.2.1 pyyaml==6.0.1 requests==2.32.2 -torch==2.2.0 -torchvision==0.17.0 +torch>=2.6.0 +torchvision>=0.17.0 tqdm==4.66.3 diff --git a/tests/data/serve_resources/mlflow/xgboost/requirements.txt b/tests/data/serve_resources/mlflow/xgboost/requirements.txt index 1130dcaec5..6f879340a7 100644 --- a/tests/data/serve_resources/mlflow/xgboost/requirements.txt +++ b/tests/data/serve_resources/mlflow/xgboost/requirements.txt @@ -1,8 +1,8 @@ mlflow==2.13.2 lz4==4.3.2 -numpy==1.24.4 +numpy==1.26.4 pandas==2.0.3 psutil==5.9.8 scikit-learn==1.3.2 -scipy==1.10.1 +scipy==1.11.3 xgboost==1.7.1 diff --git a/tests/data/workflow/requirements.txt b/tests/data/workflow/requirements.txt index 0e99587e6e..44ce1d9331 100644 --- a/tests/data/workflow/requirements.txt +++ b/tests/data/workflow/requirements.txt @@ -1 +1 @@ -scipy==1.10.1 +scipy==1.11.3 diff --git a/tests/integ/sagemaker/experiments/helpers.py b/tests/integ/sagemaker/experiments/helpers.py index 9a22c3a30c..c8f35471b1 100644 --- a/tests/integ/sagemaker/experiments/helpers.py +++ b/tests/integ/sagemaker/experiments/helpers.py @@ -13,9 +13,12 @@ from __future__ import absolute_import from contextlib import contextmanager +import pytest +import logging from sagemaker import utils from sagemaker.experiments.experiment import Experiment +from sagemaker.experiments._run_context import _RunContext EXP_INTEG_TEST_NAME_PREFIX = "experiments-integ" @@ -40,3 +43,17 @@ def cleanup_exp_resources(exp_names, sagemaker_session): for exp_name in exp_names: exp = Experiment.load(experiment_name=exp_name, sagemaker_session=sagemaker_session) exp._delete_all(action="--force") + + +@pytest.fixture +def clear_run_context(): + current_run = _RunContext.get_current_run() + if current_run is None: + return + + logging.info( + f"RunContext already populated by run {current_run.run_name}" + f" in experiment {current_run.experiment_name}." + " Clearing context manually" + ) + _RunContext.drop_current_run() diff --git a/tests/integ/sagemaker/experiments/test_run.py b/tests/integ/sagemaker/experiments/test_run.py index 4f59d11c54..f00f53a5ad 100644 --- a/tests/integ/sagemaker/experiments/test_run.py +++ b/tests/integ/sagemaker/experiments/test_run.py @@ -720,8 +720,8 @@ def _generate_processor( ) return FrameworkProcessor( estimator_cls=PyTorch, - framework_version="1.10", - py_version="py38", + framework_version="1.13.1", + py_version="py39", instance_count=1, instance_type="ml.m5.xlarge", role=execution_role, diff --git a/tests/integ/sagemaker/jumpstart/constants.py b/tests/integ/sagemaker/jumpstart/constants.py index 1ffb1d8dc0..740d88e9c0 100644 --- a/tests/integ/sagemaker/jumpstart/constants.py +++ b/tests/integ/sagemaker/jumpstart/constants.py @@ -47,7 +47,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str: ("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"), ("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"), ("huggingface-spc-bert-base-cased", "2.0.3"): ("training-datasets/QNLI-tiny/"), - ("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"), + ("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI/"), ("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"), ("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"), ("meta-textgeneration-llama-2-7b", "2.*"): ("training-datasets/sec_amazon/"), diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index b938f489df..c9a39ac3dc 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -170,7 +170,7 @@ def test_jumpstart_gated_model(setup): model = JumpStartModel( model_id=model_id, - model_version="3.*", # version >=3.0.0 stores artifacts in jumpstart-private-cache-* buckets + model_version="*", # version >=3.0.0 stores artifacts in jumpstart-private-cache-* buckets role=get_sm_session().get_caller_identity_arn(), sagemaker_session=get_sm_session(), ) @@ -197,7 +197,7 @@ def test_jumpstart_gated_model_inference_component_enabled(setup): model = JumpStartModel( model_id=model_id, - model_version="3.*", # version >=3.0.0 stores artifacts in jumpstart-private-cache-* buckets + model_version="*", # version >=3.0.0 stores artifacts in jumpstart-private-cache-* buckets role=get_sm_session().get_caller_identity_arn(), sagemaker_session=get_sm_session(), ) @@ -477,6 +477,7 @@ def _teardown_test_hub_with_reference(public_hub_model_id: str): raise e +@pytest.mark.skip # Currently JumpStartModel does not pull from HubService for the Public Hub. def test_model_reference_marketplace_model(setup): session = get_sm_session() @@ -535,6 +536,7 @@ def test_model_reference_marketplace_model(setup): # _teardown_test_hub_with_reference(public_hub_marketplace_model_id) +@pytest.mark.skip def test_bedrock_store_model_tags_from_hub_service(setup): session = get_sm_session() diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/__init__.py b/tests/integ/sagemaker/jumpstart/private_hub/estimator/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py new file mode 100644 index 0000000000..a6e33f1bdf --- /dev/null +++ b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py @@ -0,0 +1,204 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import os +import time + +import pytest +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.hub.hub import Hub + +from sagemaker.jumpstart.estimator import JumpStartEstimator +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket + +from tests.integ.sagemaker.jumpstart.constants import ( + ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME, + ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, + JUMPSTART_TAG, +) +from tests.integ.sagemaker.jumpstart.utils import ( + get_public_hub_model_arn, + get_sm_session, + with_exponential_backoff, + get_training_dataset_for_model_and_version, +) + +MAX_INIT_TIME_SECONDS = 5 + +TEST_MODEL_IDS = { + "huggingface-spc-bert-base-cased", + "meta-textgeneration-llama-2-7b", + "catboost-regression-model", +} + + +@with_exponential_backoff() +def create_model_reference(hub_instance, model_arn): + try: + hub_instance.create_model_reference(model_arn=model_arn) + except Exception: + pass + + +@pytest.fixture(scope="session") +def add_model_references(): + # Create Model References to test in Hub + hub_instance = Hub( + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() + ) + for model in TEST_MODEL_IDS: + model_arn = get_public_hub_model_arn(hub_instance, model) + create_model_reference(hub_instance, model_arn) + + +def test_jumpstart_hub_estimator(setup, add_model_references): + model_id, model_version = "huggingface-spc-bert-base-cased", "*" + + estimator = JumpStartEstimator( + model_id=model_id, + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + estimator.fit( + inputs={ + "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" + f"{get_training_dataset_for_model_and_version(model_id, model_version)}", + } + ) + + # test that we can create a JumpStartEstimator from existing job with `attach` + estimator = JumpStartEstimator.attach( + training_job_name=estimator.latest_training_job.name, + model_id=model_id, + model_version=model_version, + ) + + # uses ml.p3.2xlarge instance + predictor = estimator.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + response = predictor.predict(["hello", "world"]) + + assert response is not None + + +def test_jumpstart_hub_estimator_with_session(setup, add_model_references): + + model_id, model_version = "huggingface-spc-bert-base-cased", "*" + + sagemaker_session = get_sm_session() + + estimator = JumpStartEstimator( + model_id=model_id, + role=sagemaker_session.get_caller_identity_arn(), + sagemaker_session=sagemaker_session, + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + ) + + estimator.fit( + inputs={ + "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" + f"{get_training_dataset_for_model_and_version(model_id, model_version)}", + } + ) + + # test that we can create a JumpStartEstimator from existing job with `attach` + estimator = JumpStartEstimator.attach( + training_job_name=estimator.latest_training_job.name, + model_id=model_id, + model_version=model_version, + sagemaker_session=get_sm_session(), + ) + + # uses ml.p3.2xlarge instance + predictor = estimator.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + response = predictor.predict(["hello", "world"]) + + assert response is not None + + +def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references): + + model_id, model_version = "meta-textgeneration-llama-2-7b", "*" + + estimator = JumpStartEstimator( + model_id=model_id, + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + estimator.fit( + accept_eula=True, + inputs={ + "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" + f"{get_training_dataset_for_model_and_version(model_id, model_version)}", + }, + ) + + predictor = estimator.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + payload = { + "inputs": "some-payload", + "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6}, + } + + response = predictor.predict(payload, custom_attributes="accept_eula=true") + + assert response is not None + + +def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references): + + model_id, model_version = "meta-textgeneration-llama-2-7b", "*" + + estimator = JumpStartEstimator( + model_id=model_id, + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + with pytest.raises(Exception): + estimator.fit( + inputs={ + "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" + f"{get_training_dataset_for_model_and_version(model_id, model_version)}", + } + ) + + +def test_instantiating_estimator(setup, add_model_references): + + model_id = "catboost-regression-model" + + start_time = time.perf_counter() + + JumpStartEstimator( + model_id=model_id, + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + ) + + elapsed_time = time.perf_counter() - start_time + + assert elapsed_time <= MAX_INIT_TIME_SECONDS diff --git a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py index 751162d2e6..c7e039693b 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py @@ -48,7 +48,10 @@ @with_exponential_backoff() def create_model_reference(hub_instance, model_arn): - hub_instance.create_model_reference(model_arn=model_arn) + try: + hub_instance.create_model_reference(model_arn=model_arn) + except Exception: + pass @pytest.fixture(scope="session") @@ -82,6 +85,23 @@ def test_jumpstart_hub_model(setup, add_model_references): assert sagemaker_session.endpoint_in_service_or_not(predictor.endpoint_name) +def test_jumpstart_hub_model_with_default_session(setup, add_model_references): + model_version = "*" + hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME] + + model_id = "catboost-classification-model" + + sagemaker_session = get_sm_session() + + model = JumpStartModel(model_id=model_id, model_version=model_version, hub_name=hub_name) + + predictor = model.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + assert sagemaker_session.endpoint_in_service_or_not(predictor.endpoint_name) + + def test_jumpstart_hub_gated_model(setup, add_model_references): model_id = "meta-textgeneration-llama-3-2-1b" @@ -105,9 +125,10 @@ def test_jumpstart_hub_gated_model(setup, add_model_references): assert response is not None +@pytest.mark.skip(reason="blocking PR checks and release pipeline.") def test_jumpstart_gated_model_inference_component_enabled(setup, add_model_references): - model_id = "meta-textgeneration-llama-2-7b" + model_id = "meta-textgeneration-llama-3-2-1b" hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME] diff --git a/tests/integ/sagemaker/jumpstart/private_hub/test_hub.py b/tests/integ/sagemaker/jumpstart/private_hub/test_hub.py index 2bccb96524..db5d868c06 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/test_hub.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/test_hub.py @@ -32,6 +32,7 @@ def hub_instance(): yield hub +@pytest.mark.skip def test_private_hub(setup, hub_instance): # Createhub create_hub_response = hub_instance.create( diff --git a/tests/integ/sagemaker/jumpstart/private_hub/test_hub_content.py b/tests/integ/sagemaker/jumpstart/private_hub/test_hub_content.py index b25cff2d62..04b945a457 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/test_hub_content.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/test_hub_content.py @@ -38,7 +38,7 @@ def test_hub_model_reference(setup): describe_model_response = hub_instance.describe_model(model_name=model_id) assert describe_model_response is not None - assert type(describe_model_response) == DescribeHubContentResponse + assert isinstance(describe_model_response, DescribeHubContentResponse) assert describe_model_response.hub_content_name == model_id assert describe_model_response.hub_content_type == "ModelReference" diff --git a/tests/integ/sagemaker/modules/train/test_local_model_trainer.py b/tests/integ/sagemaker/modules/train/test_local_model_trainer.py index adb5f85f3e..7947b2fc87 100644 --- a/tests/integ/sagemaker/modules/train/test_local_model_trainer.py +++ b/tests/integ/sagemaker/modules/train/test_local_model_trainer.py @@ -92,10 +92,7 @@ def test_single_container_local_mode_local_data(modules_sagemaker_session): "compressed_artifacts", "artifacts", "model", - "shared", - "input", "output", - "algo-1", ] for directory in directories: @@ -149,14 +146,16 @@ def test_single_container_local_mode_s3_data(modules_sagemaker_session): assert os.path.exists(os.path.join(CWD, "compressed_artifacts/model.tar.gz")) finally: subprocess.run(["docker", "compose", "down", "-v"]) + + assert not os.path.exists(os.path.join(CWD, "shared")) + assert not os.path.exists(os.path.join(CWD, "input")) + assert not os.path.exists(os.path.join(CWD, "algo-1")) + directories = [ "compressed_artifacts", "artifacts", "model", - "shared", - "input", "output", - "algo-1", ] for directory in directories: @@ -204,20 +203,20 @@ def test_multi_container_local_mode(modules_sagemaker_session): model_trainer.train() assert os.path.exists(os.path.join(CWD, "compressed_artifacts/model.tar.gz")) - assert os.path.exists(os.path.join(CWD, "algo-1")) - assert os.path.exists(os.path.join(CWD, "algo-2")) finally: subprocess.run(["docker", "compose", "down", "-v"]) + + assert not os.path.exists(os.path.join(CWD, "shared")) + assert not os.path.exists(os.path.join(CWD, "input")) + assert not os.path.exists(os.path.join(CWD, "algo-1")) + assert not os.path.exists(os.path.join(CWD, "algo-2")) + directories = [ "compressed_artifacts", "artifacts", "model", - "shared", - "input", "output", - "algo-1", - "algo-2", ] for directory in directories: diff --git a/tests/integ/sagemaker/modules/train/test_model_trainer.py b/tests/integ/sagemaker/modules/train/test_model_trainer.py index cd298402b2..a1e3106553 100644 --- a/tests/integ/sagemaker/modules/train/test_model_trainer.py +++ b/tests/integ/sagemaker/modules/train/test_model_trainer.py @@ -17,7 +17,7 @@ from sagemaker.modules.train import ModelTrainer from sagemaker.modules.configs import SourceCode, Compute -from sagemaker.modules.distributed import MPI, Torchrun +from sagemaker.modules.distributed import MPI, Torchrun, DistributedConfig EXPECTED_HYPERPARAMETERS = { "integer": 1, @@ -28,26 +28,29 @@ "dict": { "string": "value", "integer": 3, + "float": 3.14, "list": [1, 2, 3], "dict": {"key": "value"}, "boolean": True, }, } +PARAM_SCRIPT_SOURCE_DIR = f"{DATA_DIR}/modules/params_script" +PARAM_SCRIPT_SOURCE_CODE = SourceCode( + source_dir=PARAM_SCRIPT_SOURCE_DIR, + requirements="requirements.txt", + entry_script="train.py", +) + DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310" def test_hp_contract_basic_py_script(modules_sagemaker_session): - source_code = SourceCode( - source_dir=f"{DATA_DIR}/modules/params_script", - entry_script="train.py", - ) - model_trainer = ModelTrainer( sagemaker_session=modules_sagemaker_session, training_image=DEFAULT_CPU_IMAGE, hyperparameters=EXPECTED_HYPERPARAMETERS, - source_code=source_code, + source_code=PARAM_SCRIPT_SOURCE_CODE, base_job_name="hp-contract-basic-py-script", ) @@ -57,6 +60,7 @@ def test_hp_contract_basic_py_script(modules_sagemaker_session): def test_hp_contract_basic_sh_script(modules_sagemaker_session): source_code = SourceCode( source_dir=f"{DATA_DIR}/modules/params_script", + requirements="requirements.txt", entry_script="train.sh", ) model_trainer = ModelTrainer( @@ -71,17 +75,13 @@ def test_hp_contract_basic_sh_script(modules_sagemaker_session): def test_hp_contract_mpi_script(modules_sagemaker_session): - source_code = SourceCode( - source_dir=f"{DATA_DIR}/modules/params_script", - entry_script="train.py", - ) compute = Compute(instance_type="ml.m5.xlarge", instance_count=2) model_trainer = ModelTrainer( sagemaker_session=modules_sagemaker_session, training_image=DEFAULT_CPU_IMAGE, compute=compute, hyperparameters=EXPECTED_HYPERPARAMETERS, - source_code=source_code, + source_code=PARAM_SCRIPT_SOURCE_CODE, distributed=MPI(), base_job_name="hp-contract-mpi-script", ) @@ -90,19 +90,71 @@ def test_hp_contract_mpi_script(modules_sagemaker_session): def test_hp_contract_torchrun_script(modules_sagemaker_session): - source_code = SourceCode( - source_dir=f"{DATA_DIR}/modules/params_script", - entry_script="train.py", - ) compute = Compute(instance_type="ml.m5.xlarge", instance_count=2) model_trainer = ModelTrainer( sagemaker_session=modules_sagemaker_session, training_image=DEFAULT_CPU_IMAGE, compute=compute, hyperparameters=EXPECTED_HYPERPARAMETERS, - source_code=source_code, + source_code=PARAM_SCRIPT_SOURCE_CODE, distributed=Torchrun(), base_job_name="hp-contract-torchrun-script", ) model_trainer.train() + + +def test_hp_contract_hyperparameter_json(modules_sagemaker_session): + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + hyperparameters=f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.json", + source_code=PARAM_SCRIPT_SOURCE_CODE, + base_job_name="hp-contract-hyperparameter-json", + ) + assert model_trainer.hyperparameters == EXPECTED_HYPERPARAMETERS + model_trainer.train() + + +def test_hp_contract_hyperparameter_yaml(modules_sagemaker_session): + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + hyperparameters=f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.yaml", + source_code=PARAM_SCRIPT_SOURCE_CODE, + base_job_name="hp-contract-hyperparameter-yaml", + ) + assert model_trainer.hyperparameters == EXPECTED_HYPERPARAMETERS + model_trainer.train() + + +def test_custom_distributed_driver(modules_sagemaker_session): + class CustomDriver(DistributedConfig): + process_count_per_node: int = None + + @property + def driver_dir(self) -> str: + return f"{DATA_DIR}/modules/custom_drivers" + + @property + def driver_script(self) -> str: + return "driver.py" + + source_code = SourceCode( + source_dir=f"{DATA_DIR}/modules/scripts", + entry_script="entry_script.py", + ) + + hyperparameters = {"epochs": 10} + + custom_driver = CustomDriver(process_count_per_node=2) + + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + hyperparameters=hyperparameters, + source_code=source_code, + distributed=custom_driver, + base_job_name="custom-distributed-driver", + ) + model_trainer.train() diff --git a/tests/integ/sagemaker/remote_function/test_decorator.py b/tests/integ/sagemaker/remote_function/test_decorator.py index 2717bb9afe..fa55d7dfa7 100644 --- a/tests/integ/sagemaker/remote_function/test_decorator.py +++ b/tests/integ/sagemaker/remote_function/test_decorator.py @@ -825,7 +825,7 @@ def test_decorator_torchrun( dummy_container_without_error, gpu_instance_type, use_torchrun=False, - nproc_per_node=1, + use_mpirun=False, ): @remote( role=ROLE, @@ -834,7 +834,7 @@ def test_decorator_torchrun( sagemaker_session=sagemaker_session, keep_alive_period_in_seconds=60, use_torchrun=use_torchrun, - nproc_per_node=nproc_per_node, + use_mpirun=use_mpirun, ) def divide(x, y): return x / y diff --git a/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py b/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py index 10f338c4b5..80f9c50e4b 100644 --- a/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py +++ b/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py @@ -12,38 +12,72 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -import pytest - -from sagemaker import get_execution_role -from sklearn.datasets import load_iris -from sklearn.model_selection import train_test_split - import os +import uuid +from typing import Generator +import numpy as np +import pandas as pd +import pytest +from sagemaker_core.main.resources import TrainingJob from sagemaker_core.main.shapes import ( AlgorithmSpecification, Channel, DataSource, - S3DataSource, OutputDataConfig, ResourceConfig, + S3DataSource, StoppingCondition, ) -import uuid -from sagemaker.serve.builder.model_builder import ModelBuilder -import pandas as pd -import numpy as np -from sagemaker.serve import InferenceSpec, SchemaBuilder -from sagemaker_core.main.resources import TrainingJob +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split from xgboost import XGBClassifier -from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig - -from sagemaker.s3_utils import s3_path_join +from sagemaker import get_execution_role from sagemaker.async_inference import AsyncInferenceConfig +from sagemaker.s3_utils import s3_path_join +from sagemaker.serve import InferenceSpec, SchemaBuilder +from sagemaker.serve.builder.model_builder import ModelBuilder +from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig from tests.integ.utils import cleanup_model_resources +@pytest.fixture(autouse=True) +def cleanup_endpoints(mb_sagemaker_session) -> Generator[None, None, None]: + """Clean up any existing endpoints before and after tests.""" + sagemaker_client = mb_sagemaker_session.sagemaker_client + + # Pre-test cleanup + try: + endpoints = sagemaker_client.list_endpoints() + for endpoint in endpoints["Endpoints"]: + try: + sagemaker_client.delete_endpoint(EndpointName=endpoint["EndpointName"]) + sagemaker_client.delete_endpoint_config( + EndpointConfigName=endpoint["EndpointConfigName"] + ) + except Exception as e: + print(f"Error cleaning up endpoint {endpoint['EndpointName']}: {e}") + except Exception as e: + print(f"Error listing endpoints: {e}") + + yield + + # Post-test cleanup + try: + endpoints = sagemaker_client.list_endpoints() + for endpoint in endpoints["Endpoints"]: + try: + sagemaker_client.delete_endpoint(EndpointName=endpoint["EndpointName"]) + sagemaker_client.delete_endpoint_config( + EndpointConfigName=endpoint["EndpointConfigName"] + ) + except Exception as e: + print(f"Error cleaning up endpoint {endpoint['EndpointName']}: {e}") + except Exception as e: + print(f"Error listing endpoints: {e}") + + @pytest.fixture(scope="module") def xgboost_model_builder(mb_sagemaker_session): sagemaker_session = mb_sagemaker_session diff --git a/tests/integ/sagemaker/serve/test_schema_builder.py b/tests/integ/sagemaker/serve/test_schema_builder.py index 1a2bbe2355..6d3e8281d5 100644 --- a/tests/integ/sagemaker/serve/test_schema_builder.py +++ b/tests/integ/sagemaker/serve/test_schema_builder.py @@ -34,7 +34,9 @@ def test_model_builder_happy_path_with_only_model_id_text_generation(sagemaker_session): model_builder = ModelBuilder( - model="HuggingFaceH4/zephyr-7b-beta", sagemaker_session=sagemaker_session + model="HuggingFaceH4/zephyr-7b-beta", + sagemaker_session=sagemaker_session, + instance_type=None, ) model = model_builder.build(sagemaker_session=sagemaker_session) diff --git a/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py b/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py index 348c57745f..ea65f998c8 100644 --- a/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py +++ b/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py @@ -24,14 +24,17 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_expected( sagemaker_session, ): - with patch.object( - Session, "create_model", return_value="mock_model" - ) as mock_create_model, patch.object( - Session, "endpoint_from_production_variants" - ) as mock_endpoint_from_production_variants: + with ( + patch.object(Session, "create_model", return_value="mock_model") as mock_create_model, + patch.object( + Session, "endpoint_from_production_variants" + ) as mock_endpoint_from_production_variants, + ): iam_client = sagemaker_session.boto_session.client("iam") role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"] + sagemaker_session.sagemaker_client.create_optimization_job = MagicMock() + schema_builder = SchemaBuilder("test", "test") model_builder = ModelBuilder( model="meta-textgeneration-llama-3-1-8b-instruct", @@ -50,6 +53,8 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e accept_eula=True, ) + assert not sagemaker_session.sagemaker_client.create_optimization_job.called + optimized_model.deploy() mock_create_model.assert_called_once_with( @@ -96,17 +101,18 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_expected( sagemaker_session, ): - with patch.object( - Session, - "wait_for_optimization_job", - return_value={"OptimizationJobName": "mock_optimization_job"}, - ), patch.object( - Session, "create_model", return_value="mock_model" - ) as mock_create_model, patch.object( - Session, "endpoint_from_production_variants", return_value="mock_endpoint_name" - ) as mock_endpoint_from_production_variants, patch.object( - Session, "create_inference_component" - ) as mock_create_inference_component: + with ( + patch.object( + Session, + "wait_for_optimization_job", + return_value={"OptimizationJobName": "mock_optimization_job"}, + ), + patch.object(Session, "create_model", return_value="mock_model") as mock_create_model, + patch.object( + Session, "endpoint_from_production_variants", return_value="mock_endpoint_name" + ) as mock_endpoint_from_production_variants, + patch.object(Session, "create_inference_component") as mock_create_inference_component, + ): iam_client = sagemaker_session.boto_session.client("iam") role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"] @@ -126,6 +132,13 @@ def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_ accept_eula=True, ) + assert ( + sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][ + "OptimizationConfigs" + ][0]["ModelShardingConfig"]["Image"] + is not None + ) + optimized_model.deploy( resources=ResourceRequirements(requests={"memory": 196608, "num_accelerators": 8}) ) @@ -174,15 +187,17 @@ def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_ def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are_expected( sagemaker_session, ): - with patch.object( - Session, - "wait_for_optimization_job", - return_value={"OptimizationJobName": "mock_optimization_job"}, - ), patch.object( - Session, "create_model", return_value="mock_model" - ) as mock_create_model, patch.object( - Session, "endpoint_from_production_variants", return_value="mock_endpoint_name" - ) as mock_endpoint_from_production_variants: + with ( + patch.object( + Session, + "wait_for_optimization_job", + return_value={"OptimizationJobName": "mock_optimization_job"}, + ), + patch.object(Session, "create_model", return_value="mock_model") as mock_create_model, + patch.object( + Session, "endpoint_from_production_variants", return_value="mock_endpoint_name" + ) as mock_endpoint_from_production_variants, + ): iam_client = sagemaker_session.boto_session.client("iam") role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"] @@ -206,6 +221,13 @@ def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are accept_eula=True, ) + assert ( + sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][ + "OptimizationConfigs" + ][0]["ModelQuantizationConfig"]["Image"] + is not None + ) + optimized_model.deploy() mock_create_model.assert_called_once_with( diff --git a/tests/integ/sagemaker/serve/test_serve_mlflow_pytorch_flavor_happy.py b/tests/integ/sagemaker/serve/test_serve_mlflow_pytorch_flavor_happy.py index e6beb76d6e..345d5e5af9 100644 --- a/tests/integ/sagemaker/serve/test_serve_mlflow_pytorch_flavor_happy.py +++ b/tests/integ/sagemaker/serve/test_serve_mlflow_pytorch_flavor_happy.py @@ -31,7 +31,7 @@ PYTORCH_SQUEEZENET_MLFLOW_RESOURCE_DIR, SERVE_SAGEMAKER_ENDPOINT_TIMEOUT, # SERVE_LOCAL_CONTAINER_TIMEOUT, - PYTHON_VERSION_IS_NOT_310, + # PYTHON_VERSION_IS_NOT_310, ) from tests.integ.timeout import timeout from tests.integ.utils import cleanup_model_resources @@ -166,9 +166,9 @@ def model_builder(request): # ), f"{caught_ex} was thrown when running pytorch squeezenet local container test" -@pytest.mark.skipif( - PYTHON_VERSION_IS_NOT_310, # or NOT_RUNNING_ON_INF_EXP_DEV_PIPELINE, - reason="The goal of these test are to test the serving components of our feature", +@pytest.mark.skip( + reason="Testing against Python version 310 which is not supported anymore" + " https://github.com/aws/deep-learning-containers/blob/master/available_images.md", ) def test_happy_pytorch_sagemaker_endpoint_with_torch_serve( sagemaker_session, diff --git a/tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py b/tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py index 8724fc5116..cf1eb65325 100644 --- a/tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py +++ b/tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py @@ -96,6 +96,8 @@ def model_builder(request): def test_non_text_generation_model_single_GPU( sagemaker_session, model_builder, model_input, **kwargs ): + if kwargs["instance_type"] == "ml.p2.xlarge": + pytest.skip("Instance type ml.p2.xlarge has been deprecated") iam_client = sagemaker_session.boto_session.client("iam") role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] model = model_builder.build(role_arn=role_arn, sagemaker_session=sagemaker_session) @@ -147,6 +149,8 @@ def test_non_text_generation_model_single_GPU( def test_non_text_generation_model_multi_GPU( sagemaker_session, model_builder, model_input, **kwargs ): + if kwargs["instance_type"] == "ml.p2.xlarge": + pytest.skip("Instance type ml.p2.xlarge has been deprecated") iam_client = sagemaker_session.boto_session.client("iam") role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] caught_ex = None diff --git a/tests/integ/sagemaker/serve/test_serve_transformers.py b/tests/integ/sagemaker/serve/test_serve_transformers.py index 5f172f3edb..9405934474 100644 --- a/tests/integ/sagemaker/serve/test_serve_transformers.py +++ b/tests/integ/sagemaker/serve/test_serve_transformers.py @@ -97,6 +97,9 @@ def model_builder(request): def test_pytorch_transformers_sagemaker_endpoint( sagemaker_session, model_builder, model_input, **kwargs ): + if kwargs["instance_type"] == "ml.p2.xlarge": + pytest.skip("Instance type ml.p2.xlarge has been deprecated") + logger.info("Running in SAGEMAKER_ENDPOINT mode...") caught_ex = None diff --git a/tests/integ/sagemaker/serve/utils/test_hardware_detector.py b/tests/integ/sagemaker/serve/utils/test_hardware_detector.py index 9102927c55..bab26a25d1 100644 --- a/tests/integ/sagemaker/serve/utils/test_hardware_detector.py +++ b/tests/integ/sagemaker/serve/utils/test_hardware_detector.py @@ -19,7 +19,7 @@ REGION = "us-west-2" VALID_INSTANCE_TYPE = "ml.g5.48xlarge" INVALID_INSTANCE_TYPE = "fl.c5.57xxlarge" -EXPECTED_INSTANCE_GPU_INFO = (8, 196608) +EXPECTED_INSTANCE_GPU_INFO = (8, 183104) def test_get_gpu_info_success(sagemaker_session): diff --git a/tests/integ/sagemaker/workflow/helpers.py b/tests/integ/sagemaker/workflow/helpers.py index 20365ef169..9f0176c5c2 100644 --- a/tests/integ/sagemaker/workflow/helpers.py +++ b/tests/integ/sagemaker/workflow/helpers.py @@ -70,8 +70,8 @@ def create_and_execute_pipeline( assert execution_steps[0]["StepStatus"] == step_status if step_result_type: result = execution.result(execution_steps[0]["StepName"]) - assert ( - type(result) == step_result_type + assert isinstance( + result, step_result_type ), f"Expected {step_result_type}, instead found {type(result)}" if step_result_value: diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index 7f85c0066c..8f98cd076d 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -26,7 +26,6 @@ import pytest from packaging.version import Version -from packaging.specifiers import SpecifierSet from sagemaker.model_card.model_card import ModelCard, ModelOverview, ModelPackageModelCard from sagemaker.model_card.schema_constraints import ModelCardStatusEnum @@ -1422,7 +1421,7 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model( pipeline_name, region_name, ): - if Version(tf_full_version) in SpecifierSet("==2.16.*"): + if Version(tf_full_version) >= Version("2.16"): pytest.skip( "This test is failing in TensorFlow 2.16 beacuse of an upstream bug: " "https://github.com/tensorflow/io/issues/2039" diff --git a/tests/integ/sagemaker/workflow/test_model_steps.py b/tests/integ/sagemaker/workflow/test_model_steps.py index 089cdaf08f..02f7613f85 100644 --- a/tests/integ/sagemaker/workflow/test_model_steps.py +++ b/tests/integ/sagemaker/workflow/test_model_steps.py @@ -18,7 +18,6 @@ import pytest from packaging.version import Version -from packaging.specifiers import SpecifierSet from tests.integ.sagemaker.workflow.helpers import wait_pipeline_execution from sagemaker.workflow.fail_step import FailStep @@ -592,7 +591,7 @@ def test_model_registration_with_drift_check_baselines_and_model_metrics( def test_model_registration_with_tensorflow_model_with_pipeline_model( pipeline_session, role, tf_full_version, tf_full_py_version, pipeline_name ): - if Version(tf_full_version) in SpecifierSet("==2.16.*"): + if Version(tf_full_version) >= Version("2.16"): pytest.skip( "This test is failing in TensorFlow 2.16 beacuse of an upstream bug: " "https://github.com/tensorflow/io/issues/2039" diff --git a/tests/integ/sagemaker/workflow/test_training_steps.py b/tests/integ/sagemaker/workflow/test_training_steps.py index bcff221afe..4b442c6d93 100644 --- a/tests/integ/sagemaker/workflow/test_training_steps.py +++ b/tests/integ/sagemaker/workflow/test_training_steps.py @@ -19,7 +19,6 @@ import pytest from packaging.version import Version -from packaging.specifiers import SpecifierSet from tests.integ.sagemaker.workflow.helpers import wait_pipeline_execution from sagemaker import TrainingInput, get_execution_role, utils, image_uris @@ -238,7 +237,7 @@ def test_training_step_with_output_path_as_join( def test_tensorflow_training_step_with_parameterized_code_input( pipeline_session, role, tf_full_version, tf_full_py_version, pipeline_name ): - if Version(tf_full_version) in SpecifierSet("==2.16.*"): + if Version(tf_full_version) >= Version("2.16"): pytest.skip( "This test is failing in TensorFlow 2.16 beacuse of an upstream bug: " "https://github.com/tensorflow/io/issues/2039" diff --git a/tests/integ/sagemaker/workflow/test_workflow.py b/tests/integ/sagemaker/workflow/test_workflow.py index 2643a3b88e..9ef0b14a04 100644 --- a/tests/integ/sagemaker/workflow/test_workflow.py +++ b/tests/integ/sagemaker/workflow/test_workflow.py @@ -1122,8 +1122,8 @@ def test_model_registration_with_tuning_model( entry_point=entry_point, source_dir=base_dir, role=role, - framework_version="1.10", - py_version="py38", + framework_version="1.13.1", + py_version="py39", instance_count=instance_count, instance_type=instance_type, sagemaker_session=pipeline_session, @@ -1159,8 +1159,8 @@ def test_model_registration_with_tuning_model( ), entry_point=entry_point, source_dir=base_dir, - framework_version="1.10", - py_version="py38", + framework_version="1.13.1", + py_version="py39", sagemaker_session=pipeline_session, ) step_model_regis_args = model.register( diff --git a/tests/integ/test_collection.py b/tests/integ/test_collection.py index 2ee1d90e34..9a6db645cf 100644 --- a/tests/integ/test_collection.py +++ b/tests/integ/test_collection.py @@ -19,20 +19,22 @@ def test_create_collection_root_success(sagemaker_session): collection = Collection(sagemaker_session) collection_name = unique_name_from_base("test-collection") - collection.create(collection_name) - collection_filter = [ - { - "Name": "resource-type", - "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], - }, - ] - collection_details = sagemaker_session.list_group_resources( - group=collection_name, filters=collection_filter - ) - assert collection_details["ResponseMetadata"]["HTTPStatusCode"] == 200 - delete_response = collection.delete([collection_name]) - assert len(delete_response["deleted_collections"]) == 1 - assert len(delete_response["delete_collection_failures"]) == 0 + try: + collection.create(collection_name) + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + assert collection_details["ResponseMetadata"]["HTTPStatusCode"] == 200 + finally: + delete_response = collection.delete([collection_name]) + assert len(delete_response["deleted_collections"]) == 1 + assert len(delete_response["delete_collection_failures"]) == 0 def test_create_collection_nested_success(sagemaker_session): @@ -41,25 +43,27 @@ def test_create_collection_nested_success(sagemaker_session): child_collection_name = unique_name_from_base("test-collection-2") collection.create(collection_name) collection.create(collection_name=child_collection_name, parent_collection_name=collection_name) - collection_filter = [ - { - "Name": "resource-type", - "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], - }, - ] - collection_details = sagemaker_session.list_group_resources( - group=collection_name, filters=collection_filter - ) - # has one child i.e child collection - assert len(collection_details["Resources"]) == 1 - - collection_details = sagemaker_session.list_group_resources( - group=child_collection_name, filters=collection_filter - ) - collection_details["ResponseMetadata"]["HTTPStatusCode"] - delete_response = collection.delete([child_collection_name, collection_name]) - assert len(delete_response["deleted_collections"]) == 2 - assert len(delete_response["delete_collection_failures"]) == 0 + try: + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + # has one child i.e child collection + assert len(collection_details["Resources"]) == 1 + + collection_details = sagemaker_session.list_group_resources( + group=child_collection_name, filters=collection_filter + ) + collection_details["ResponseMetadata"]["HTTPStatusCode"] + finally: + delete_response = collection.delete([child_collection_name, collection_name]) + assert len(delete_response["deleted_collections"]) == 2 + assert len(delete_response["delete_collection_failures"]) == 0 def test_add_remove_model_groups_in_collection_success(sagemaker_session): @@ -70,40 +74,42 @@ def test_add_remove_model_groups_in_collection_success(sagemaker_session): collection = Collection(sagemaker_session) collection_name = unique_name_from_base("test-collection") collection.create(collection_name) - model_groups = [] - model_groups.append(model_group_name) - add_response = collection.add_model_groups( - collection_name=collection_name, model_groups=model_groups - ) - collection_filter = [ - { - "Name": "resource-type", - "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], - }, - ] - collection_details = sagemaker_session.list_group_resources( - group=collection_name, filters=collection_filter - ) - - assert len(add_response["failure"]) == 0 - assert len(add_response["added_groups"]) == 1 - assert len(collection_details["Resources"]) == 1 - - remove_response = collection.remove_model_groups( - collection_name=collection_name, model_groups=model_groups - ) - collection_details = sagemaker_session.list_group_resources( - group=collection_name, filters=collection_filter - ) - assert len(remove_response["failure"]) == 0 - assert len(remove_response["removed_groups"]) == 1 - assert len(collection_details["Resources"]) == 0 - - delete_response = collection.delete([collection_name]) - assert len(delete_response["deleted_collections"]) == 1 - sagemaker_session.sagemaker_client.delete_model_package_group( - ModelPackageGroupName=model_group_name - ) + try: + model_groups = [] + model_groups.append(model_group_name) + add_response = collection.add_model_groups( + collection_name=collection_name, model_groups=model_groups + ) + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + + assert len(add_response["failure"]) == 0 + assert len(add_response["added_groups"]) == 1 + assert len(collection_details["Resources"]) == 1 + + remove_response = collection.remove_model_groups( + collection_name=collection_name, model_groups=model_groups + ) + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + assert len(remove_response["failure"]) == 0 + assert len(remove_response["removed_groups"]) == 1 + assert len(collection_details["Resources"]) == 0 + + finally: + delete_response = collection.delete([collection_name]) + assert len(delete_response["deleted_collections"]) == 1 + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) def test_move_model_groups_in_collection_success(sagemaker_session): @@ -116,56 +122,58 @@ def test_move_model_groups_in_collection_success(sagemaker_session): destination_collection_name = unique_name_from_base("test-collection-destination") collection.create(source_collection_name) collection.create(destination_collection_name) - model_groups = [] - model_groups.append(model_group_name) - add_response = collection.add_model_groups( - collection_name=source_collection_name, model_groups=model_groups - ) - collection_filter = [ - { - "Name": "resource-type", - "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], - }, - ] - collection_details = sagemaker_session.list_group_resources( - group=source_collection_name, filters=collection_filter - ) - - assert len(add_response["failure"]) == 0 - assert len(add_response["added_groups"]) == 1 - assert len(collection_details["Resources"]) == 1 - - move_response = collection.move_model_group( - source_collection_name=source_collection_name, - model_group=model_group_name, - destination_collection_name=destination_collection_name, - ) - - assert move_response["moved_success"] == model_group_name - - collection_details = sagemaker_session.list_group_resources( - group=destination_collection_name, filters=collection_filter - ) - - assert len(collection_details["Resources"]) == 1 - - collection_details = sagemaker_session.list_group_resources( - group=source_collection_name, filters=collection_filter - ) - assert len(collection_details["Resources"]) == 0 - - remove_response = collection.remove_model_groups( - collection_name=destination_collection_name, model_groups=model_groups - ) - - assert len(remove_response["failure"]) == 0 - assert len(remove_response["removed_groups"]) == 1 - - delete_response = collection.delete([source_collection_name, destination_collection_name]) - assert len(delete_response["deleted_collections"]) == 2 - sagemaker_session.sagemaker_client.delete_model_package_group( - ModelPackageGroupName=model_group_name - ) + try: + model_groups = [] + model_groups.append(model_group_name) + add_response = collection.add_model_groups( + collection_name=source_collection_name, model_groups=model_groups + ) + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=source_collection_name, filters=collection_filter + ) + + assert len(add_response["failure"]) == 0 + assert len(add_response["added_groups"]) == 1 + assert len(collection_details["Resources"]) == 1 + + move_response = collection.move_model_group( + source_collection_name=source_collection_name, + model_group=model_group_name, + destination_collection_name=destination_collection_name, + ) + + assert move_response["moved_success"] == model_group_name + + collection_details = sagemaker_session.list_group_resources( + group=destination_collection_name, filters=collection_filter + ) + + assert len(collection_details["Resources"]) == 1 + + collection_details = sagemaker_session.list_group_resources( + group=source_collection_name, filters=collection_filter + ) + assert len(collection_details["Resources"]) == 0 + + remove_response = collection.remove_model_groups( + collection_name=destination_collection_name, model_groups=model_groups + ) + + assert len(remove_response["failure"]) == 0 + assert len(remove_response["removed_groups"]) == 1 + + finally: + delete_response = collection.delete([source_collection_name, destination_collection_name]) + assert len(delete_response["deleted_collections"]) == 2 + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) def test_list_collection_success(sagemaker_session): @@ -176,23 +184,27 @@ def test_list_collection_success(sagemaker_session): collection = Collection(sagemaker_session) collection_name = unique_name_from_base("test-collection") collection.create(collection_name) - model_groups = [] - model_groups.append(model_group_name) - collection.add_model_groups(collection_name=collection_name, model_groups=model_groups) - child_collection_name = unique_name_from_base("test-collection") - collection.create(parent_collection_name=collection_name, collection_name=child_collection_name) - root_collections = collection.list_collection() - is_collection_found = False - for root_collection in root_collections: - if root_collection["Name"] == collection_name: - is_collection_found = True - assert is_collection_found - - collection_content = collection.list_collection(collection_name) - assert len(collection_content) == 2 - - collection.remove_model_groups(collection_name=collection_name, model_groups=model_groups) - collection.delete([child_collection_name, collection_name]) - sagemaker_session.sagemaker_client.delete_model_package_group( - ModelPackageGroupName=model_group_name - ) + try: + model_groups = [] + model_groups.append(model_group_name) + collection.add_model_groups(collection_name=collection_name, model_groups=model_groups) + child_collection_name = unique_name_from_base("test-collection") + collection.create( + parent_collection_name=collection_name, collection_name=child_collection_name + ) + root_collections = collection.list_collection() + is_collection_found = False + for root_collection in root_collections: + if root_collection["Name"] == collection_name: + is_collection_found = True + assert is_collection_found + + collection_content = collection.list_collection(collection_name) + assert len(collection_content) == 2 + + collection.remove_model_groups(collection_name=collection_name, model_groups=model_groups) + finally: + collection.delete([child_collection_name, collection_name]) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index 43db78527a..75f1807148 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -1645,9 +1645,11 @@ def test_create_dataset_with_feature_group_base( feature_store_session, feature_group, offline_store_s3_uri ) - with timeout(minutes=10) and cleanup_offline_store( - base, feature_store_session - ) and cleanup_offline_store(feature_group, feature_store_session): + with ( + timeout(minutes=10) + and cleanup_offline_store(base, feature_store_session) + and cleanup_offline_store(feature_group, feature_store_session) + ): feature_store = FeatureStore(sagemaker_session=feature_store_session) df, query_string = ( feature_store.create_dataset(base=base, output_path=offline_store_s3_uri) @@ -1832,9 +1834,11 @@ def test_create_dataset_with_feature_group_base_with_additional_params( feature_store_session, feature_group, offline_store_s3_uri ) - with timeout(minutes=10) and cleanup_offline_store( - base, feature_store_session - ) and cleanup_offline_store(feature_group, feature_store_session): + with ( + timeout(minutes=10) + and cleanup_offline_store(base, feature_store_session) + and cleanup_offline_store(feature_group, feature_store_session) + ): feature_store = FeatureStore(sagemaker_session=feature_store_session) df, query_string = ( feature_store.create_dataset(base=base, output_path=offline_store_s3_uri) diff --git a/tests/integ/test_horovod.py b/tests/integ/test_horovod.py index 2ddcdc92e0..78314c2ade 100644 --- a/tests/integ/test_horovod.py +++ b/tests/integ/test_horovod.py @@ -62,11 +62,8 @@ def test_hvd_gpu( tmpdir, **kwargs, ): - if ( - Version(tensorflow_training_latest_version) >= Version("2.12") - and kwargs["instance_type"] == "ml.p2.xlarge" - ): - pytest.skip("P2 instances have been deprecated for sagemaker jobs starting TensorFlow 2.12") + if kwargs["instance_type"] == "ml.p2.xlarge": + pytest.skip("Instance type ml.p2.xlarge has been deprecated") if Version(tensorflow_training_latest_version) >= Version("2.13"): pytest.skip("Horovod is deprecated in TensorFlow 2.13 and above") diff --git a/tests/integ/test_horovod_mx.py b/tests/integ/test_horovod_mx.py index 7bd6a641e0..a238966dd3 100644 --- a/tests/integ/test_horovod_mx.py +++ b/tests/integ/test_horovod_mx.py @@ -58,6 +58,9 @@ def test_hvd_gpu( tmpdir, **kwargs, ): + if kwargs["instance_type"] == "ml.p2.xlarge": + pytest.skip("Instance type ml.p2.xlarge has been deprecated") + _create_and_fit_estimator( mxnet_training_latest_version, mxnet_training_latest_py_version, diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index 0015efe3fd..0b2900bef7 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -15,7 +15,8 @@ import boto3 from botocore.config import Config -from sagemaker import Session +from sagemaker import Session, ModelPackage +from sagemaker.utils import unique_name_from_base CUSTOM_BUCKET_NAME = "this-bucket-should-not-exist" @@ -44,3 +45,62 @@ def test_sagemaker_session_does_not_create_bucket_on_init( s3 = boto3.resource("s3", region_name=boto_session.region_name) assert s3.Bucket(CUSTOM_BUCKET_NAME).creation_date is None + + +def test_sagemaker_session_to_return_most_recent_approved_model_package(sagemaker_session): + model_package_group_name = unique_name_from_base("test-model-package-group") + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( + model_package_group_name=model_package_group_name + ) + assert approved_model_package is None + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_package_group_name + ) + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( + model_package_group_name=model_package_group_name + ) + assert approved_model_package is None + source_uri = "dummy source uri" + model_package = sagemaker_session.sagemaker_client.create_model_package( + ModelPackageGroupName=model_package_group_name, SourceUri=source_uri + ) + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( + model_package_group_name=model_package_group_name + ) + assert approved_model_package is None + ModelPackage( + sagemaker_session=sagemaker_session, + model_package_arn=model_package["ModelPackageArn"], + ).update_approval_status(approval_status="Approved") + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( + model_package_group_name=model_package_group_name + ) + assert approved_model_package is not None + assert approved_model_package.model_package_arn == model_package.get("ModelPackageArn") + model_package_2 = sagemaker_session.sagemaker_client.create_model_package( + ModelPackageGroupName=model_package_group_name, SourceUri=source_uri + ) + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( + model_package_group_name=model_package_group_name + ) + assert approved_model_package is not None + assert approved_model_package.model_package_arn == model_package.get("ModelPackageArn") + ModelPackage( + sagemaker_session=sagemaker_session, + model_package_arn=model_package_2["ModelPackageArn"], + ).update_approval_status(approval_status="Approved") + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( + model_package_group_name=model_package_group_name + ) + assert approved_model_package is not None + assert approved_model_package.model_package_arn == model_package_2.get("ModelPackageArn") + + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package_2["ModelPackageArn"] + ) + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package["ModelPackageArn"] + ) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_package_group_name + ) diff --git a/tests/integ/test_spark_processing.py b/tests/integ/test_spark_processing.py index 25a4942d70..ac956be94e 100644 --- a/tests/integ/test_spark_processing.py +++ b/tests/integ/test_spark_processing.py @@ -35,7 +35,7 @@ SPARK_PATH = os.path.join(DATA_DIR, "spark") -@pytest.fixture(scope="module") +@pytest.fixture(scope="module", autouse=True) def build_jar(): jar_file_path = os.path.join(SPARK_PATH, "code", "java", "hello-java-spark") # compile java file @@ -69,9 +69,6 @@ def build_jar(): ".", ] ) - yield - subprocess.run(["rm", os.path.join(jar_file_path, "hello-spark-java.jar")]) - subprocess.run(["rm", os.path.join(jar_file_path, JAVA_FILE_PATH, "HelloJavaSparkApp.class")]) @pytest.fixture(scope="module") @@ -207,12 +204,10 @@ def configuration() -> list: def test_sagemaker_pyspark_v3( - spark_v3_py_processor, spark_v3_jar_processor, sagemaker_session, configuration, build_jar + spark_v3_py_processor, spark_v3_jar_processor, sagemaker_session, configuration ): test_sagemaker_pyspark_multinode(spark_v3_py_processor, sagemaker_session, configuration) - test_sagemaker_java_jar_multinode( - spark_v3_jar_processor, sagemaker_session, configuration, build_jar - ) + test_sagemaker_java_jar_multinode(spark_v3_jar_processor, sagemaker_session, configuration) def test_sagemaker_pyspark_multinode(spark_py_processor, sagemaker_session, configuration): @@ -280,9 +275,7 @@ def test_sagemaker_pyspark_multinode(spark_py_processor, sagemaker_session, conf assert len(output_contents) != 0 -def test_sagemaker_java_jar_multinode( - spark_jar_processor, sagemaker_session, configuration, build_jar -): +def test_sagemaker_java_jar_multinode(spark_jar_processor, sagemaker_session, configuration): """Test SparkJarProcessor using Java application jar""" bucket = spark_jar_processor.sagemaker_session.default_bucket() with open(os.path.join(SPARK_PATH, "files", "data.jsonl")) as data: diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index 8c99854d14..0d03aee8ea 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -19,7 +19,6 @@ import pytest from packaging.version import Version -from packaging.specifiers import SpecifierSet from sagemaker import KMeans, s3, get_execution_role from sagemaker.mxnet import MXNet @@ -556,7 +555,7 @@ def test_transform_mxnet_logs( def test_transform_tf_kms_network_isolation( sagemaker_session, cpu_instance_type, tmpdir, tf_full_version, tf_full_py_version ): - if Version(tf_full_version) in SpecifierSet("==2.16.*"): + if Version(tf_full_version) >= Version("2.16"): pytest.skip( "This test is failing in TensorFlow 2.16 beacuse of an upstream bug: " "https://github.com/tensorflow/io/issues/2039" diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py index 4c93e18939..5d32030580 100644 --- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py @@ -75,12 +75,12 @@ def test_constructor_node_should_be_modified(src, expected): ("sagemaker.predictor._NumpyDeserializer()", "deserializers.NumpyDeserializer()"), ("sagemaker.predictor._JsonDeserializer()", "deserializers.JSONDeserializer()"), ( - "sagemaker.amazon.common.numpy_to_record_serializer()", - "sagemaker.amazon.common.RecordSerializer()", + "sagemaker.serializers.numpy_to_record_serializer()", + "sagemaker.serializers.RecordSerializer()", ), ( - "sagemaker.amazon.common.record_deserializer()", - "sagemaker.amazon.common.RecordDeserializer()", + "sagemaker.deserializers.record_deserializer()", + "sagemaker.deserializers.RecordDeserializer()", ), ("_CsvSerializer()", "serializers.CSVSerializer()"), ("_JsonSerializer()", "serializers.JSONSerializer()"), @@ -265,20 +265,12 @@ def test_import_from_amazon_common_node_should_be_modified(import_statement, exp "import_statement, expected", [ ( - "from sagemaker.amazon.common import numpy_to_record_serializer", - "from sagemaker.amazon.common import RecordSerializer", + "from sagemaker.serializers import numpy_to_record_serializer", + "from sagemaker.serializers import RecordSerializer", ), ( - "from sagemaker.amazon.common import record_deserializer", - "from sagemaker.amazon.common import RecordDeserializer", - ), - ( - "from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer", - "from sagemaker.amazon.common import RecordSerializer, RecordDeserializer", - ), - ( - "from sagemaker.amazon.common import write_spmatrix_to_sparse_tensor, numpy_to_record_serializer", - "from sagemaker.amazon.common import write_spmatrix_to_sparse_tensor, RecordSerializer", + "from sagemaker.deserializers import record_deserializer", + "from sagemaker.deserializers import RecordDeserializer", ), ], ) diff --git a/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_feature_processor_lineage.py b/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_feature_processor_lineage.py index 118800dd0f..f149823b2f 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_feature_processor_lineage.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_feature_processor_lineage.py @@ -113,69 +113,85 @@ def test_create_lineage_when_no_lineage_exists_with_fg_only(): transformation_code=TRANSFORMATION_CODE_INPUT_1, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - FeatureGroupLineageEntityHandler, - "retrieve_feature_group_context_arns", - side_effect=[ - FEATURE_GROUP_INPUT[0], - FEATURE_GROUP_INPUT[1], - FEATURE_GROUP_INPUT[0], - ], - ) as retrieve_feature_group_context_arns_method, patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[ - RAW_DATA_INPUT_ARTIFACTS[0], - RAW_DATA_INPUT_ARTIFACTS[1], - RAW_DATA_INPUT_ARTIFACTS[2], - RAW_DATA_INPUT_ARTIFACTS[3], - ], - ) as retrieve_raw_data_artifact_method, patch.object( - S3LineageEntityHandler, - "create_transformation_code_artifact", - return_value=TRANSFORMATION_CODE_ARTIFACT_1, - ) as create_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - side_effect=RESOURCE_NOT_FOUND_EXCEPTION, - ) as load_pipeline_context_method, patch.object( - PipelineLineageEntityHandler, - "create_pipeline_context", - return_value=PIPELINE_CONTEXT, - ), patch.object( - PipelineVersionLineageEntityHandler, - "create_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ), patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - [], - generate_pipeline_version_upstream_transformation_code(), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - PipelineLineageEntityHandler, - "update_pipeline_context", - ) as update_pipeline_context_method, patch.object( - LineageAssociationHandler, "add_upstream_feature_group_data_associations" - ) as add_upstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_downstream_feature_group_data_associations" - ) as add_downstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_raw_data_associations" - ) as add_upstream_raw_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_transformation_code_associations" - ) as add_upstream_transformation_code_associations_method, patch.object( - LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" - ) as add_pipeline_and_pipeline_version_association_method: + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method, + patch.object( + PipelineLineageEntityHandler, + "create_pipeline_context", + return_value=PIPELINE_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + [], + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + ): lineage_handler.create_lineage() retrieve_feature_group_context_arns_method.assert_has_calls( @@ -259,75 +275,92 @@ def test_create_lineage_when_no_lineage_exists_with_raw_data_only(): transformation_code=TRANSFORMATION_CODE_INPUT_1, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - FeatureGroupLineageEntityHandler, - "retrieve_feature_group_context_arns", - side_effect=[ - FEATURE_GROUP_INPUT[0], - FEATURE_GROUP_INPUT[1], - FEATURE_GROUP_INPUT[0], - ], - ) as retrieve_feature_group_context_arns_method, patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[ - RAW_DATA_INPUT_ARTIFACTS[0], - RAW_DATA_INPUT_ARTIFACTS[1], - RAW_DATA_INPUT_ARTIFACTS[2], - RAW_DATA_INPUT_ARTIFACTS[3], - ], - ) as retrieve_raw_data_artifact_method, patch.object( - S3LineageEntityHandler, - "create_transformation_code_artifact", - return_value=TRANSFORMATION_CODE_ARTIFACT_1, - ) as create_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - side_effect=RESOURCE_NOT_FOUND_EXCEPTION, - ) as load_pipeline_context_method, patch.object( - PipelineLineageEntityHandler, - "create_pipeline_context", - return_value=PIPELINE_CONTEXT, - ), patch.object( - PipelineVersionLineageEntityHandler, - "create_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ), patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - [], - generate_pipeline_version_upstream_transformation_code(), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - PipelineLineageEntityHandler, - "update_pipeline_context", - ) as update_pipeline_context_method, patch.object( - LineageAssociationHandler, "add_upstream_feature_group_data_associations" - ) as add_upstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_downstream_feature_group_data_associations" - ) as add_downstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_raw_data_associations" - ) as add_upstream_raw_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_transformation_code_associations" - ) as add_upstream_transformation_code_associations_method, patch.object( - LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" - ) as add_pipeline_and_pipeline_version_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method, + patch.object( + PipelineLineageEntityHandler, + "create_pipeline_context", + return_value=PIPELINE_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + [], + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_lineage(TAGS) retrieve_feature_group_context_arns_method.assert_called_once_with( @@ -408,75 +441,92 @@ def test_create_lineage_when_no_lineage_exists_with_fg_and_raw_data_with_tags(): transformation_code=TRANSFORMATION_CODE_INPUT_1, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - FeatureGroupLineageEntityHandler, - "retrieve_feature_group_context_arns", - side_effect=[ - FEATURE_GROUP_INPUT[0], - FEATURE_GROUP_INPUT[1], - FEATURE_GROUP_INPUT[0], - ], - ) as retrieve_feature_group_context_arns_method, patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[ - RAW_DATA_INPUT_ARTIFACTS[0], - RAW_DATA_INPUT_ARTIFACTS[1], - RAW_DATA_INPUT_ARTIFACTS[2], - RAW_DATA_INPUT_ARTIFACTS[3], - ], - ) as retrieve_raw_data_artifact_method, patch.object( - S3LineageEntityHandler, - "create_transformation_code_artifact", - return_value=TRANSFORMATION_CODE_ARTIFACT_1, - ) as create_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - side_effect=RESOURCE_NOT_FOUND_EXCEPTION, - ) as load_pipeline_context_method, patch.object( - PipelineLineageEntityHandler, - "create_pipeline_context", - return_value=PIPELINE_CONTEXT, - ), patch.object( - PipelineVersionLineageEntityHandler, - "create_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ), patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - [], - generate_pipeline_version_upstream_transformation_code(), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - PipelineLineageEntityHandler, - "update_pipeline_context", - ) as update_pipeline_context_method, patch.object( - LineageAssociationHandler, "add_upstream_feature_group_data_associations" - ) as add_upstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_downstream_feature_group_data_associations" - ) as add_downstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_raw_data_associations" - ) as add_upstream_raw_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_transformation_code_associations" - ) as add_upstream_transformation_code_associations_method, patch.object( - LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" - ) as add_pipeline_and_pipeline_version_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method, + patch.object( + PipelineLineageEntityHandler, + "create_pipeline_context", + return_value=PIPELINE_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + [], + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_lineage(TAGS) retrieve_feature_group_context_arns_method.assert_has_calls( @@ -569,75 +619,92 @@ def test_create_lineage_when_no_lineage_exists_with_no_transformation_code(): output=FEATURE_GROUP_DATA_SOURCE[0].name, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - FeatureGroupLineageEntityHandler, - "retrieve_feature_group_context_arns", - side_effect=[ - FEATURE_GROUP_INPUT[0], - FEATURE_GROUP_INPUT[1], - FEATURE_GROUP_INPUT[0], - ], - ) as retrieve_feature_group_context_arns_method, patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[ - RAW_DATA_INPUT_ARTIFACTS[0], - RAW_DATA_INPUT_ARTIFACTS[1], - RAW_DATA_INPUT_ARTIFACTS[2], - RAW_DATA_INPUT_ARTIFACTS[3], - ], - ) as retrieve_raw_data_artifact_method, patch.object( - S3LineageEntityHandler, - "create_transformation_code_artifact", - return_value=None, - ) as create_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - side_effect=RESOURCE_NOT_FOUND_EXCEPTION, - ) as load_pipeline_context_method, patch.object( - PipelineLineageEntityHandler, - "create_pipeline_context", - return_value=PIPELINE_CONTEXT, - ), patch.object( - PipelineVersionLineageEntityHandler, - "create_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ), patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - [], - generate_pipeline_version_upstream_transformation_code(), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - PipelineLineageEntityHandler, - "update_pipeline_context", - ) as update_pipeline_context_method, patch.object( - LineageAssociationHandler, "add_upstream_feature_group_data_associations" - ) as add_upstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_downstream_feature_group_data_associations" - ) as add_downstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_raw_data_associations" - ) as add_upstream_raw_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_transformation_code_associations" - ) as add_upstream_transformation_code_associations_method, patch.object( - LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" - ) as add_pipeline_and_pipeline_version_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=None, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method, + patch.object( + PipelineLineageEntityHandler, + "create_pipeline_context", + return_value=PIPELINE_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + [], + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_lineage(TAGS) retrieve_feature_group_context_arns_method.assert_has_calls( @@ -728,78 +795,96 @@ def test_create_lineage_when_already_exist_with_no_version_change(): transformation_code=TRANSFORMATION_CODE_INPUT_1, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - FeatureGroupLineageEntityHandler, - "retrieve_feature_group_context_arns", - side_effect=[ - FEATURE_GROUP_INPUT[0], - FEATURE_GROUP_INPUT[1], - FEATURE_GROUP_INPUT[0], - ], - ) as retrieve_feature_group_context_arns_method, patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[ - RAW_DATA_INPUT_ARTIFACTS[0], - RAW_DATA_INPUT_ARTIFACTS[1], - RAW_DATA_INPUT_ARTIFACTS[2], - RAW_DATA_INPUT_ARTIFACTS[3], - ], - ) as retrieve_raw_data_artifact_method, patch.object( - S3LineageEntityHandler, - "create_transformation_code_artifact", - return_value=TRANSFORMATION_CODE_ARTIFACT_1, - ) as create_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=PIPELINE_CONTEXT, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - generate_pipeline_version_upstream_raw_data_list(), - generate_pipeline_version_upstream_transformation_code(), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - S3LineageEntityHandler, - "load_artifact_from_arn", - return_value=transformation_code_1, - ) as load_artifact_from_arn_method, patch.object( - S3LineageEntityHandler, - "update_transformation_code_artifact", - ) as update_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "update_pipeline_context", - ) as update_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "create_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as create_pipeline_version_context_method, patch.object( - LineageAssociationHandler, "add_upstream_feature_group_data_associations" - ) as add_upstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_downstream_feature_group_data_associations" - ) as add_downstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_raw_data_associations" - ) as add_upstream_raw_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_transformation_code_associations" - ) as add_upstream_transformation_code_associations_method, patch.object( - LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" - ) as add_pipeline_and_pipeline_version_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as create_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_lineage(TAGS) retrieve_feature_group_context_arns_method.assert_has_calls( @@ -925,73 +1010,91 @@ def test_create_lineage_when_already_exist_with_changed_raw_data(): transformation_code=TRANSFORMATION_CODE_INPUT_1, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - FeatureGroupLineageEntityHandler, - "retrieve_feature_group_context_arns", - side_effect=[ - FEATURE_GROUP_INPUT[0], - FEATURE_GROUP_INPUT[1], - FEATURE_GROUP_INPUT[0], - ], - ) as retrieve_feature_group_context_arns_method, patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1]], - ) as retrieve_raw_data_artifact_method, patch.object( - S3LineageEntityHandler, - "create_transformation_code_artifact", - return_value=TRANSFORMATION_CODE_ARTIFACT_1, - ) as create_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=pipeline_context, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - generate_pipeline_version_upstream_raw_data_list(), - generate_pipeline_version_upstream_transformation_code(), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - S3LineageEntityHandler, - "load_artifact_from_arn", - return_value=transformation_code_1, - ) as load_artifact_from_arn_method, patch.object( - S3LineageEntityHandler, - "update_transformation_code_artifact", - ) as update_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "update_pipeline_context", - ) as update_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "create_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ), patch.object( - LineageAssociationHandler, "add_upstream_feature_group_data_associations" - ) as add_upstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_downstream_feature_group_data_associations" - ) as add_downstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_raw_data_associations" - ) as add_upstream_raw_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_transformation_code_associations" - ) as add_upstream_transformation_code_associations_method, patch.object( - LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" - ) as add_pipeline_and_pipeline_version_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1]], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_lineage(TAGS) retrieve_feature_group_context_arns_method.assert_has_calls( @@ -1140,74 +1243,92 @@ def test_create_lineage_when_already_exist_with_changed_input_fg(): transformation_code=TRANSFORMATION_CODE_INPUT_1, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - FeatureGroupLineageEntityHandler, - "retrieve_feature_group_context_arns", - side_effect=[FEATURE_GROUP_INPUT[0], FEATURE_GROUP_INPUT[0]], - ) as retrieve_feature_group_context_arns_method, patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[ - RAW_DATA_INPUT_ARTIFACTS[0], - RAW_DATA_INPUT_ARTIFACTS[1], - RAW_DATA_INPUT_ARTIFACTS[2], - RAW_DATA_INPUT_ARTIFACTS[3], - ], - ) as retrieve_raw_data_artifact_method, patch.object( - S3LineageEntityHandler, - "create_transformation_code_artifact", - return_value=TRANSFORMATION_CODE_ARTIFACT_1, - ) as create_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=pipeline_context, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - generate_pipeline_version_upstream_raw_data_list(), - generate_pipeline_version_upstream_transformation_code(), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - S3LineageEntityHandler, - "load_artifact_from_arn", - return_value=transformation_code_1, - ) as load_artifact_from_arn_method, patch.object( - S3LineageEntityHandler, - "update_transformation_code_artifact", - ) as update_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "update_pipeline_context", - ) as update_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "create_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ), patch.object( - LineageAssociationHandler, "add_upstream_feature_group_data_associations" - ) as add_upstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_downstream_feature_group_data_associations" - ) as add_downstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_raw_data_associations" - ) as add_upstream_raw_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_transformation_code_associations" - ) as add_upstream_transformation_code_associations_method, patch.object( - LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" - ) as add_pipeline_and_pipeline_version_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[FEATURE_GROUP_INPUT[0], FEATURE_GROUP_INPUT[0]], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_lineage(TAGS) retrieve_feature_group_context_arns_method.assert_has_calls( @@ -1354,78 +1475,96 @@ def test_create_lineage_when_already_exist_with_changed_output_fg(): transformation_code=TRANSFORMATION_CODE_INPUT_1, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - FeatureGroupLineageEntityHandler, - "retrieve_feature_group_context_arns", - side_effect=[ - FEATURE_GROUP_INPUT[0], - FEATURE_GROUP_INPUT[1], - FEATURE_GROUP_INPUT[1], - ], - ) as retrieve_feature_group_context_arns_method, patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[ - RAW_DATA_INPUT_ARTIFACTS[0], - RAW_DATA_INPUT_ARTIFACTS[1], - RAW_DATA_INPUT_ARTIFACTS[2], - RAW_DATA_INPUT_ARTIFACTS[3], - ], - ) as retrieve_raw_data_artifact_method, patch.object( - S3LineageEntityHandler, - "create_transformation_code_artifact", - return_value=TRANSFORMATION_CODE_ARTIFACT_1, - ) as create_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=pipeline_context, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - generate_pipeline_version_upstream_raw_data_list(), - generate_pipeline_version_upstream_transformation_code(), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - S3LineageEntityHandler, - "load_artifact_from_arn", - return_value=transformation_code_1, - ) as load_artifact_from_arn_method, patch.object( - S3LineageEntityHandler, - "update_transformation_code_artifact", - ) as update_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "update_pipeline_context", - ) as update_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "create_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ), patch.object( - LineageAssociationHandler, "add_upstream_feature_group_data_associations" - ) as add_upstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_downstream_feature_group_data_associations" - ) as add_downstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_raw_data_associations" - ) as add_upstream_raw_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_transformation_code_associations" - ) as add_upstream_transformation_code_associations_method, patch.object( - LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" - ) as add_pipeline_and_pipeline_version_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[1], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_lineage(TAGS) retrieve_feature_group_context_arns_method.assert_has_calls( @@ -1576,78 +1715,96 @@ def test_create_lineage_when_already_exist_with_changed_transformation_code(): transformation_code=TRANSFORMATION_CODE_INPUT_2, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - FeatureGroupLineageEntityHandler, - "retrieve_feature_group_context_arns", - side_effect=[ - FEATURE_GROUP_INPUT[0], - FEATURE_GROUP_INPUT[1], - FEATURE_GROUP_INPUT[0], - ], - ) as retrieve_feature_group_context_arns_method, patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[ - RAW_DATA_INPUT_ARTIFACTS[0], - RAW_DATA_INPUT_ARTIFACTS[1], - RAW_DATA_INPUT_ARTIFACTS[2], - RAW_DATA_INPUT_ARTIFACTS[3], - ], - ) as retrieve_raw_data_artifact_method, patch.object( - S3LineageEntityHandler, - "create_transformation_code_artifact", - return_value=TRANSFORMATION_CODE_ARTIFACT_2, - ) as create_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=pipeline_context, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - generate_pipeline_version_upstream_raw_data_list(), - generate_pipeline_version_upstream_transformation_code(), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - S3LineageEntityHandler, - "load_artifact_from_arn", - return_value=transformation_code_1, - ) as load_artifact_from_arn_method, patch.object( - S3LineageEntityHandler, - "update_transformation_code_artifact", - ) as update_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "update_pipeline_context", - ) as update_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "create_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ), patch.object( - LineageAssociationHandler, "add_upstream_feature_group_data_associations" - ) as add_upstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_downstream_feature_group_data_associations" - ) as add_downstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_raw_data_associations" - ) as add_upstream_raw_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_transformation_code_associations" - ) as add_upstream_transformation_code_associations_method, patch.object( - LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" - ) as add_pipeline_and_pipeline_version_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_2, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_lineage(TAGS) retrieve_feature_group_context_arns_method.assert_has_calls( @@ -1778,78 +1935,96 @@ def test_create_lineage_when_already_exist_with_last_transformation_code_as_none transformation_code=TRANSFORMATION_CODE_INPUT_2, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - FeatureGroupLineageEntityHandler, - "retrieve_feature_group_context_arns", - side_effect=[ - FEATURE_GROUP_INPUT[0], - FEATURE_GROUP_INPUT[1], - FEATURE_GROUP_INPUT[0], - ], - ) as retrieve_feature_group_context_arns_method, patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[ - RAW_DATA_INPUT_ARTIFACTS[0], - RAW_DATA_INPUT_ARTIFACTS[1], - RAW_DATA_INPUT_ARTIFACTS[2], - RAW_DATA_INPUT_ARTIFACTS[3], - ], - ) as retrieve_raw_data_artifact_method, patch.object( - S3LineageEntityHandler, - "create_transformation_code_artifact", - return_value=TRANSFORMATION_CODE_ARTIFACT_2, - ) as create_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=pipeline_context, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - generate_pipeline_version_upstream_raw_data_list(), - generate_pipeline_version_upstream_transformation_code(), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - S3LineageEntityHandler, - "load_artifact_from_arn", - return_value=transformation_code_1, - ) as load_artifact_from_arn_method, patch.object( - S3LineageEntityHandler, - "update_transformation_code_artifact", - ) as update_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "update_pipeline_context", - ) as update_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "create_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ), patch.object( - LineageAssociationHandler, "add_upstream_feature_group_data_associations" - ) as add_upstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_downstream_feature_group_data_associations" - ) as add_downstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_raw_data_associations" - ) as add_upstream_raw_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_transformation_code_associations" - ) as add_upstream_transformation_code_associations_method, patch.object( - LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" - ) as add_pipeline_and_pipeline_version_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_2, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_lineage(TAGS) retrieve_feature_group_context_arns_method.assert_has_calls( @@ -1968,77 +2143,95 @@ def test_create_lineage_when_already_exist_with_all_previous_transformation_code transformation_code=TRANSFORMATION_CODE_INPUT_2, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - FeatureGroupLineageEntityHandler, - "retrieve_feature_group_context_arns", - side_effect=[ - FEATURE_GROUP_INPUT[0], - FEATURE_GROUP_INPUT[1], - FEATURE_GROUP_INPUT[0], - ], - ) as retrieve_feature_group_context_arns_method, patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[ - RAW_DATA_INPUT_ARTIFACTS[0], - RAW_DATA_INPUT_ARTIFACTS[1], - RAW_DATA_INPUT_ARTIFACTS[2], - RAW_DATA_INPUT_ARTIFACTS[3], - ], - ) as retrieve_raw_data_artifact_method, patch.object( - S3LineageEntityHandler, - "create_transformation_code_artifact", - return_value=TRANSFORMATION_CODE_ARTIFACT_2, - ) as create_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=pipeline_context, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - generate_pipeline_version_upstream_raw_data_list(), - iter([]), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - S3LineageEntityHandler, - "load_artifact_from_arn", - ) as load_artifact_from_arn_method, patch.object( - S3LineageEntityHandler, - "update_transformation_code_artifact", - ) as update_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "update_pipeline_context", - ) as update_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "create_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ), patch.object( - LineageAssociationHandler, "add_upstream_feature_group_data_associations" - ) as add_upstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_downstream_feature_group_data_associations" - ) as add_downstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_raw_data_associations" - ) as add_upstream_raw_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_transformation_code_associations" - ) as add_upstream_transformation_code_associations_method, patch.object( - LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" - ) as add_pipeline_and_pipeline_version_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_2, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + iter([]), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_lineage(TAGS) retrieve_feature_group_context_arns_method.assert_has_calls( @@ -2154,78 +2347,96 @@ def test_create_lineage_when_already_exist_with_removed_transformation_code(): output=FEATURE_GROUP_DATA_SOURCE[0].name, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - FeatureGroupLineageEntityHandler, - "retrieve_feature_group_context_arns", - side_effect=[ - FEATURE_GROUP_INPUT[0], - FEATURE_GROUP_INPUT[1], - FEATURE_GROUP_INPUT[0], - ], - ) as retrieve_feature_group_context_arns_method, patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[ - RAW_DATA_INPUT_ARTIFACTS[0], - RAW_DATA_INPUT_ARTIFACTS[1], - RAW_DATA_INPUT_ARTIFACTS[2], - RAW_DATA_INPUT_ARTIFACTS[3], - ], - ) as retrieve_raw_data_artifact_method, patch.object( - S3LineageEntityHandler, - "create_transformation_code_artifact", - return_value=None, - ) as create_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=pipeline_context, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - generate_pipeline_version_upstream_raw_data_list(), - generate_pipeline_version_upstream_transformation_code(), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - S3LineageEntityHandler, - "load_artifact_from_arn", - return_value=transformation_code_1, - ) as load_artifact_from_arn_method, patch.object( - S3LineageEntityHandler, - "update_transformation_code_artifact", - ) as update_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "update_pipeline_context", - ) as update_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "create_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ), patch.object( - LineageAssociationHandler, "add_upstream_feature_group_data_associations" - ) as add_upstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_downstream_feature_group_data_associations" - ) as add_downstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_raw_data_associations" - ) as add_upstream_raw_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_transformation_code_associations" - ) as add_upstream_transformation_code_associations_method, patch.object( - LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" - ) as add_pipeline_and_pipeline_version_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=None, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_lineage(TAGS) retrieve_feature_group_context_arns_method.assert_has_calls( @@ -2370,15 +2581,18 @@ def test_get_pipeline_lineage_names_when_lineage_exists(): transformation_code=TRANSFORMATION_CODE_INPUT_1, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=PIPELINE_CONTEXT, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method: + with ( + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + ): return_value = lineage_handler.get_pipeline_lineage_names() assert return_value == dict( @@ -2416,28 +2630,34 @@ def test_create_schedule_lineage(): pipeline=PIPELINE, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=PIPELINE_CONTEXT, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - S3LineageEntityHandler, - "retrieve_pipeline_schedule_artifact", - return_value=SCHEDULE_ARTIFACT_RESULT, - ) as retrieve_pipeline_schedule_artifact_method, patch.object( - LineageAssociationHandler, - "add_upstream_schedule_associations", - ) as add_upstream_schedule_associations_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + S3LineageEntityHandler, + "retrieve_pipeline_schedule_artifact", + return_value=SCHEDULE_ARTIFACT_RESULT, + ) as retrieve_pipeline_schedule_artifact_method, + patch.object( + LineageAssociationHandler, + "add_upstream_schedule_associations", + ) as add_upstream_schedule_associations_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_schedule_lineage( pipeline_name=PIPELINE_NAME, schedule_arn=SCHEDULE_ARN, @@ -2487,28 +2707,34 @@ def test_create_trigger_lineage(): pipeline=PIPELINE, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=PIPELINE_CONTEXT, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - S3LineageEntityHandler, - "retrieve_pipeline_trigger_artifact", - return_value=PIPELINE_TRIGGER_ARTIFACT, - ) as retrieve_pipeline_trigger_artifact_method, patch.object( - LineageAssociationHandler, - "_add_association", - ) as add_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + S3LineageEntityHandler, + "retrieve_pipeline_trigger_artifact", + return_value=PIPELINE_TRIGGER_ARTIFACT, + ) as retrieve_pipeline_trigger_artifact_method, + patch.object( + LineageAssociationHandler, + "_add_association", + ) as add_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_trigger_lineage( pipeline_name=PIPELINE_NAME, trigger_arn=TRIGGER_ARN, @@ -2564,56 +2790,68 @@ def test_upsert_tags_for_lineage_resources(): ) lineage_handler.sagemaker_session.boto_session = Mock() lineage_handler.sagemaker_session.sagemaker_client = Mock() - with patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[ - RAW_DATA_INPUT_ARTIFACTS[0], - RAW_DATA_INPUT_ARTIFACTS[1], - RAW_DATA_INPUT_ARTIFACTS[2], - RAW_DATA_INPUT_ARTIFACTS[3], - ], - ) as retrieve_raw_data_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=pipeline_context, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - generate_pipeline_version_upstream_raw_data_list(), - iter([]), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - S3LineageEntityHandler, "load_artifact_from_arn", return_value=ARTIFACT_RESULT - ) as load_artifact_from_arn_method, patch.object( - S3LineageEntityHandler, "_load_artifact_from_s3_uri", return_value=ARTIFACT_SUMMARY - ) as load_artifact_from_s3_uri_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags, patch.object( - Context, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as context_set_tags, patch.object( - EventBridgeSchedulerHelper, "describe_schedule", return_value=dict(Arn="schedule_arn") - ) as get_event_bridge_schedule, patch.object( - EventBridgeRuleHelper, "describe_rule", return_value=dict(Arn="rule_arn") - ) as get_event_bridge_rule: + with ( + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + iter([]), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, "load_artifact_from_arn", return_value=ARTIFACT_RESULT + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, "_load_artifact_from_s3_uri", return_value=ARTIFACT_SUMMARY + ) as load_artifact_from_s3_uri_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + patch.object( + Context, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as context_set_tags, + patch.object( + EventBridgeSchedulerHelper, "describe_schedule", return_value=dict(Arn="schedule_arn") + ) as get_event_bridge_schedule, + patch.object( + EventBridgeRuleHelper, "describe_rule", return_value=dict(Arn="rule_arn") + ) as get_event_bridge_rule, + ): lineage_handler.upsert_tags_for_lineage_resources(TAGS) retrieve_raw_data_artifact_method.assert_has_calls( diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py b/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py index 57f4a54f78..7b35174940 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py @@ -907,7 +907,9 @@ def test_remote_decorator_fields_consistency(get_execution_role, session): "use_spot_instances", "max_wait_time_in_seconds", "custom_file_filter", + "disable_output_compression", "use_torchrun", + "use_mpirun", "nproc_per_node", } diff --git a/tests/unit/sagemaker/huggingface/test_llm_utils.py b/tests/unit/sagemaker/huggingface/test_llm_utils.py index 675a6fd885..9bb1b451a1 100644 --- a/tests/unit/sagemaker/huggingface/test_llm_utils.py +++ b/tests/unit/sagemaker/huggingface/test_llm_utils.py @@ -65,7 +65,7 @@ def test_huggingface_model_metadata_unauthorized_exception(self, mock_urllib): "Trying to access a gated/private HuggingFace model without valid credentials. " "Please provide a HUGGING_FACE_HUB_TOKEN in env_vars" ) - self.assertEquals(expected_error_msg, str(context.exception)) + self.assertEqual(expected_error_msg, str(context.exception)) @patch("sagemaker.huggingface.llm_utils.urllib") def test_huggingface_model_metadata_general_exception(self, mock_urllib): @@ -76,7 +76,7 @@ def test_huggingface_model_metadata_general_exception(self, mock_urllib): expected_error_msg = ( f"Did not find model metadata for the following HuggingFace Model ID {MOCK_HF_ID}" ) - self.assertEquals(expected_error_msg, str(context.exception)) + self.assertEqual(expected_error_msg, str(context.exception)) @patch("huggingface_hub.snapshot_download") def test_download_huggingface_model_metadata(self, mock_snapshot_download): diff --git a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py index 28525a390c..084c2d1438 100644 --- a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py +++ b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import pytest +from packaging.version import parse from sagemaker.huggingface import get_huggingface_llm_image_uri from tests.unit.sagemaker.image_uris import expected_uris, conftest @@ -46,6 +47,8 @@ "2.0.2": "2.3.0-tgi2.0.2-gpu-py310-cu121-ubuntu22.04", "2.2.0": "2.3.0-tgi2.2.0-gpu-py310-cu121-ubuntu22.04-v2.0", "2.3.1": "2.4.0-tgi2.3.1-gpu-py311-cu124-ubuntu22.04", + "2.4.0": "2.4.0-tgi2.4.0-gpu-py311-cu124-ubuntu22.04-v2.2", + "3.0.1": "2.4.0-tgi3.0.1-gpu-py311-cu124-ubuntu22.04-v2.1", }, "inf2": { "0.0.16": "1.13.1-optimum0.0.16-neuronx-py310-ubuntu22.04", @@ -58,6 +61,7 @@ "0.0.23": "2.1.2-optimum0.0.23-neuronx-py310-ubuntu22.04", "0.0.24": "2.1.2-optimum0.0.24-neuronx-py310-ubuntu22.04", "0.0.25": "2.1.2-optimum0.0.25-neuronx-py310-ubuntu22.04", + "0.0.27": "2.1.2-optimum0.0.27-neuronx-py310-ubuntu22.04", }, } @@ -69,10 +73,31 @@ def test_huggingface_uris(load_config): VERSIONS = load_config["inference"]["versions"] device = load_config["inference"]["processors"][0] backend = "huggingface-neuronx" if device == "inf2" else "huggingface" + + # Fail if device is not in mapping + if device not in HF_VERSIONS_MAPPING: + raise ValueError(f"Device {device} not found in HF_VERSIONS_MAPPING") + + # Get highest version for the device + highest_version = max(HF_VERSIONS_MAPPING[device].keys(), key=lambda x: parse(x)) + for version in VERSIONS: ACCOUNTS = load_config["inference"]["versions"][version]["registries"] for region in ACCOUNTS.keys(): uri = get_huggingface_llm_image_uri(backend, region=region, version=version) + + # Skip only if test version is higher than highest known version. + # There's now automation to add new TGI releases to image_uri_config directory + # that doesn't involve a human raising a PR. + if parse(version) > parse(highest_version): + print( + f"Skipping version check for {version} as there is " + "automation that now updates the image_uri_config " + "without a human raising a PR. Tests will pass for " + f"versions higher than {highest_version} that are not in HF_VERSIONS_MAPPING." + ) + continue + expected = expected_uris.huggingface_llm_framework_uri( "huggingface-pytorch-tgi-inference", ACCOUNTS[region], diff --git a/tests/unit/sagemaker/image_uris/test_smp_v2.py b/tests/unit/sagemaker/image_uris/test_smp_v2.py index b1297822f7..3177384e7e 100644 --- a/tests/unit/sagemaker/image_uris/test_smp_v2.py +++ b/tests/unit/sagemaker/image_uris/test_smp_v2.py @@ -36,15 +36,18 @@ def test_smp_v2(load_config): for region in ACCOUNTS.keys(): for instance_type in CONTAINER_VERSIONS.keys(): cuda_vers = CONTAINER_VERSIONS[instance_type] - if ( - "2.1" in version - or "2.2" in version - or "2.3" in version - or "2.4" in version + supported_smp_pt_versions_cu124 = ("2.5",) + supported_smp_pt_versions_cu121 = ("2.1", "2.2", "2.3", "2.4") + if any( + pt_version in version for pt_version in supported_smp_pt_versions_cu124 + ): + cuda_vers = "cu124" + elif any( + pt_version in version for pt_version in supported_smp_pt_versions_cu121 ): cuda_vers = "cu121" - if "2.3.1" == version or "2.4.1" == version: + if version in ("2.3.1", "2.4.1", "2.5.1"): py_version = "py311" uri = image_uris.get_training_image_uri( diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 59f38bd189..ae02c597da 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -3059,7 +3059,7 @@ "g4": { "regional_properties": {"image_uri": "$gpu_image_uri"}, "properties": { - "artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" + "training_artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" }, }, "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, @@ -3135,7 +3135,7 @@ }, "p9": { "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": {"artifact_key": "do/re/mi"}, + "properties": {"training_artifact_key": "do/re/mi"}, }, "m2": { "regional_properties": {"image_uri": "$cpu_image_uri"}, @@ -3214,13 +3214,13 @@ "ml.p9.12xlarge": { "properties": { "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, - "artifact_key": "you/not/entertained", + "training_artifact_key": "you/not/entertained", } }, "g6": { "properties": { "environment_variables": {"BLAH": "4"}, - "artifact_key": "path/to/training/artifact.tar.gz", + "training_artifact_key": "path/to/training/artifact.tar.gz", "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/", } }, @@ -5046,7 +5046,7 @@ "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, "m5": { "regional_properties": {"image_uri": "$cpu_ecr_uri_1"}, - "properties": {"artifact_key": "hello-world-1"}, + "properties": {"training_artifact_key": "hello-world-1"}, }, "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, @@ -14360,7 +14360,7 @@ "jmespath==1.0.1", "jsonschema==4.17.3", "multiprocess==0.70.14", - "numpy==1.24.3", + "numpy==1.26.4", "oscrypto==1.3.0", "packaging==23.1", "pandas==2.0.2", @@ -15553,6 +15553,8 @@ }, "inference_enable_network_isolation": True, "training_enable_network_isolation": True, + "default_training_dataset_uri": None, + "default_training_dataset_key": "training-datasets/tf_flowers/", "resource_name_base": "pt-ic-mobilenet-v2", "hosting_eula_key": None, "hosting_model_package_arns": {}, @@ -15988,6 +15990,18 @@ "spec_key": "community_models_specs/tensorflow-ic-" "imagenet-inception-v3-classification-4/specs_v3.0.0.json", }, + { + "model_id": "meta-textgeneration-llama-2-7b", + "version": "4.9.0", + "min_version": "2.49.0", + "spec_key": "community_models/meta-textgeneration-llama-2-7b/specs_v4.9.0.json", + }, + { + "model_id": "meta-textgeneration-llama-2-7b", + "version": "4.13.0", + "min_version": "2.49.0", + "spec_key": "community_models/meta-textgeneration-llama-2-7b/specs_v4.13.0.json", + }, ] BASE_PROPRIETARY_HEADER = { @@ -17234,13 +17248,13 @@ "g4dn": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/g4dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/g4dn/v1.0.0/", # noqa: E501 }, }, "g5": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/g5/v1.0.0/", # noqa: E501 }, }, "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, @@ -17249,13 +17263,13 @@ "p3dn": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/p3dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/p3dn/v1.0.0/", # noqa: E501 }, }, "p4d": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/p4d/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/p4d/v1.0.0/", # noqa: E501 }, }, "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, @@ -17391,7 +17405,7 @@ "texttable==1.6.7", "tokenize-rt==5.1.0", "tokenizers==0.13.3", - "torch==2.2.0", + "torch>=2.6.0", "transformers==4.33.3", "triton==2.2.0", "typing-extensions==4.8.0", diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 1fd2a47aca..4a64b413f4 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -392,23 +392,6 @@ def test_gated_model_s3_uri( mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session - with pytest.raises(ValueError) as e: - JumpStartEstimator( - model_id=model_id, - environment={ - "accept_eula": "false", - "what am i": "doing", - "SageMakerGatedModelS3Uri": "none of your business", - }, - ) - assert str(e.value) == ( - "Need to define ‘accept_eula'='true' within Environment. " - "Model 'meta-textgeneration-llama-2-7b-f' requires accepting end-user " - "license agreement (EULA). See " - "https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com/fmhMetadata/eula/llamaEula.txt" - " for terms of use." - ) - mock_estimator_init.reset_mock() estimator = JumpStartEstimator(model_id=model_id, environment={"accept_eula": "true"}) @@ -510,6 +493,151 @@ def test_gated_model_s3_uri( ], ) + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_gated_model_s3_uri_with_eula_in_fit( + self, + mock_estimator_deploy: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_estimator_init: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session_estimator: mock.Mock, + mock_session_model: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, + mock_timestamp: mock.Mock, + ): + mock_estimator_deploy.return_value = default_predictor + + mock_timestamp.return_value = "8675309" + + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS + + model_id, _ = "js-gated-artifact-trainable-model", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + mock_session_estimator.return_value = sagemaker_session + mock_session_model.return_value = sagemaker_session + + mock_estimator_init.reset_mock() + + estimator = JumpStartEstimator(model_id=model_id) + + mock_estimator_init.assert_called_once_with( + instance_type="ml.g5.12xlarge", + instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-" + "pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "meta/transfer_learning/textgeneration/v1.0.6/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={ + "int8_quantization": "False", + "enable_fsdp": "True", + "epoch": "1", + "learning_rate": "0.0001", + "lora_r": "8", + "lora_alpha": "32", + "lora_dropout": "0.05", + "instruction_tuned": "False", + "chat_dataset": "True", + "add_input_output_demarcation_key": "True", + "per_device_train_batch_size": "1", + "per_device_eval_batch_size": "1", + "max_train_samples": "-1", + "max_val_samples": "-1", + "seed": "10", + "max_input_length": "-1", + "validation_split_ratio": "0.2", + "train_data_split_seed": "0", + "preprocessing_num_workers": "None", + }, + metric_definitions=[ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "eval_epoch_loss=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:eval-ppl", + "Regex": "eval_ppl=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "train_epoch_loss=([0-9\\.]+)", + }, + ], + role=execution_role, + sagemaker_session=sagemaker_session, + max_run=360000, + enable_network_isolation=True, + encrypt_inter_container_traffic=True, + environment={ + "SageMakerGatedModelS3Uri": "s3://sagemaker-repository-pdx/" + "model-data-model-package_llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + }, + tags=[ + { + "Key": "sagemaker-sdk:jumpstart-model-id", + "Value": "js-gated-artifact-trainable-model", + }, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "2.0.4"}, + ], + ) + + channels = { + "training": f"s3://{get_jumpstart_content_bucket(region)}/" + f"some-training-dataset-doesn't-matter", + } + + estimator.fit(channels, accept_eula=True) + + mock_estimator_fit.assert_called_once_with( + inputs=channels, + wait=True, + job_name="meta-textgeneration-llama-2-7b-f-8675309", + ) + + assert hasattr(estimator, "model_access_config") + assert hasattr(estimator, "hub_access_config") + + assert estimator.model_access_config == {"AcceptEula": True} + + estimator.deploy() + + mock_estimator_deploy.assert_called_once_with( + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + predictor_cls=Predictor, + endpoint_name="meta-textgeneration-llama-2-7b-f-8675309", + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118", + wait=True, + model_data_download_timeout=3600, + container_startup_health_check_timeout=3600, + role=execution_role, + enable_network_isolation=True, + model_name="meta-textgeneration-llama-2-7b-f-8675309", + use_compiled_model=False, + tags=[ + { + "Key": "sagemaker-sdk:jumpstart-model-id", + "Value": "js-gated-artifact-trainable-model", + }, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "2.0.4"}, + ], + ) + @mock.patch( "sagemaker.jumpstart.artifacts.environment_variables.get_jumpstart_gated_content_bucket" ) @@ -1218,7 +1346,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): and reach out to JumpStart team.""" init_args_to_skip: Set[str] = set(["kwargs"]) - fit_args_to_skip: Set[str] = set() + fit_args_to_skip: Set[str] = set(["accept_eula"]) deploy_args_to_skip: Set[str] = set(["kwargs"]) parent_class_init = Estimator.__init__ @@ -1243,8 +1371,8 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): js_class_fit = JumpStartEstimator.fit js_class_fit_args = set(signature(js_class_fit).parameters.keys()) - assert js_class_fit_args - parent_class_fit_args == set() - assert parent_class_fit_args - js_class_fit_args == fit_args_to_skip + assert js_class_fit_args - parent_class_fit_args == fit_args_to_skip + assert parent_class_fit_args - js_class_fit_args == set() model_class_init = Model.__init__ model_class_init_args = set(signature(model_class_init).parameters.keys()) diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py index 073921d5ba..39eca166ee 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py @@ -123,16 +123,16 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_retrieve_model_init_kwargs.return_value = {} - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_estimator_init.call_args[1].get("role"), config_role) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_estimator_init.call_args[1].get("role"), config_role) assert "enable_network_isolation" not in mock_estimator_init.call_args[1] assert "encrypt_inter_container_traffic" not in mock_estimator_init.call_args[1] estimator.deploy() - self.assertEquals(mock_get_sagemaker_config_value.call_count, 3) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 3) - self.assertEquals(mock_estimator_deploy.call_args[1].get("role"), config_inference_role) + self.assertEqual(mock_estimator_deploy.call_args[1].get("role"), config_inference_role) assert "enable_network_isolation" not in mock_estimator_deploy.call_args[1] @@ -181,13 +181,13 @@ def test_without_arg_overwrites_with_kwarg_collisions_with_config( model_id=model_id, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 3) - self.assertEquals(mock_estimator_init.call_args[1].get("role"), config_role) - self.assertEquals( + self.assertEqual(mock_get_sagemaker_config_value.call_count, 3) + self.assertEqual(mock_estimator_init.call_args[1].get("role"), config_role) + self.assertEqual( mock_estimator_init.call_args[1].get("enable_network_isolation"), config_enable_network_isolation, ) - self.assertEquals( + self.assertEqual( mock_estimator_init.call_args[1].get("encrypt_inter_container_traffic"), config_intercontainer_encryption, ) @@ -200,11 +200,11 @@ def test_without_arg_overwrites_with_kwarg_collisions_with_config( estimator.deploy() - self.assertEquals(mock_get_sagemaker_config_value.call_count, 6) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 6) - self.assertEquals(mock_estimator_deploy.call_args[1].get("role"), config_inference_role) + self.assertEqual(mock_estimator_deploy.call_args[1].get("role"), config_inference_role) - self.assertEquals( + self.assertEqual( mock_estimator_deploy.call_args[1].get("enable_network_isolation"), config_inference_enable_network_isolation, ) @@ -257,13 +257,13 @@ def test_with_arg_overwrites_with_kwarg_collisions_with_config( encrypt_inter_container_traffic=override_encrypt_inter_container_traffic, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_estimator_init.call_args[1].get("role"), override_role) - self.assertEquals( + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_estimator_init.call_args[1].get("role"), override_role) + self.assertEqual( mock_estimator_init.call_args[1].get("enable_network_isolation"), override_enable_network_isolation, ) - self.assertEquals( + self.assertEqual( mock_estimator_init.call_args[1].get("encrypt_inter_container_traffic"), override_encrypt_inter_container_traffic, ) @@ -280,13 +280,13 @@ def test_with_arg_overwrites_with_kwarg_collisions_with_config( enable_network_isolation=override_inference_enable_network_isolation, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 3) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 3) - self.assertEquals( + self.assertEqual( mock_estimator_deploy.call_args[1].get("role"), mock_inference_override_role ) - self.assertEquals( + self.assertEqual( mock_estimator_deploy.call_args[1].get("enable_network_isolation"), override_inference_enable_network_isolation, ) @@ -336,13 +336,13 @@ def test_with_arg_overwrites_without_kwarg_collisions_with_config( encrypt_inter_container_traffic=override_encrypt_inter_container_traffic, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_estimator_init.call_args[1].get("role"), override_role) - self.assertEquals( + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_estimator_init.call_args[1].get("role"), override_role) + self.assertEqual( mock_estimator_init.call_args[1].get("enable_network_isolation"), override_enable_network_isolation, ) - self.assertEquals( + self.assertEqual( mock_estimator_init.call_args[1].get("encrypt_inter_container_traffic"), override_encrypt_inter_container_traffic, ) @@ -355,13 +355,13 @@ def test_with_arg_overwrites_without_kwarg_collisions_with_config( enable_network_isolation=override_inference_enable_network_isolation, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 3) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 3) - self.assertEquals( + self.assertEqual( mock_estimator_deploy.call_args[1].get("role"), mock_inference_override_role ) - self.assertEquals( + self.assertEqual( mock_estimator_deploy.call_args[1].get("enable_network_isolation"), override_inference_enable_network_isolation, ) @@ -412,8 +412,8 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( model_id=model_id, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_estimator_init.call_args[1].get("role"), execution_role) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_estimator_init.call_args[1].get("role"), execution_role) assert "enable_network_isolation" not in mock_estimator_init.call_args[1] assert "encrypt_inter_container_traffic" not in mock_estimator_init.call_args[1] @@ -421,9 +421,9 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_retrieve_model_init_kwargs.return_value = {} - self.assertEquals(mock_get_sagemaker_config_value.call_count, 3) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 3) - self.assertEquals(mock_estimator_deploy.call_args[1].get("role"), execution_role) + self.assertEqual(mock_estimator_deploy.call_args[1].get("role"), execution_role) assert "enable_network_isolation" not in mock_estimator_deploy.call_args[1] @@ -475,13 +475,13 @@ def test_without_arg_overwrites_with_kwarg_collisions_without_config( model_id=model_id, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 3) - self.assertEquals(mock_estimator_init.call_args[1].get("role"), execution_role) - self.assertEquals( + self.assertEqual(mock_get_sagemaker_config_value.call_count, 3) + self.assertEqual(mock_estimator_init.call_args[1].get("role"), execution_role) + self.assertEqual( mock_estimator_init.call_args[1].get("enable_network_isolation"), metadata_enable_network_isolation, ) - self.assertEquals( + self.assertEqual( mock_estimator_init.call_args[1].get("encrypt_inter_container_traffic"), metadata_intercontainer_encryption, ) @@ -492,11 +492,11 @@ def test_without_arg_overwrites_with_kwarg_collisions_without_config( estimator.deploy() - self.assertEquals(mock_get_sagemaker_config_value.call_count, 6) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 6) - self.assertEquals(mock_estimator_deploy.call_args[1].get("role"), execution_role) + self.assertEqual(mock_estimator_deploy.call_args[1].get("role"), execution_role) - self.assertEquals( + self.assertEqual( mock_estimator_deploy.call_args[1].get("enable_network_isolation"), metadata_inference_enable_network_isolation, ) @@ -548,13 +548,13 @@ def test_with_arg_overwrites_with_kwarg_collisions_without_config( encrypt_inter_container_traffic=override_encrypt_inter_container_traffic, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_estimator_init.call_args[1].get("role"), override_role) - self.assertEquals( + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_estimator_init.call_args[1].get("role"), override_role) + self.assertEqual( mock_estimator_init.call_args[1].get("enable_network_isolation"), override_enable_network_isolation, ) - self.assertEquals( + self.assertEqual( mock_estimator_init.call_args[1].get("encrypt_inter_container_traffic"), override_encrypt_inter_container_traffic, ) @@ -568,11 +568,11 @@ def test_with_arg_overwrites_with_kwarg_collisions_without_config( enable_network_isolation=override_inference_enable_network_isolation, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 3) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 3) - self.assertEquals(mock_estimator_deploy.call_args[1].get("role"), override_inference_role) + self.assertEqual(mock_estimator_deploy.call_args[1].get("role"), override_inference_role) - self.assertEquals( + self.assertEqual( mock_estimator_deploy.call_args[1].get("enable_network_isolation"), override_inference_enable_network_isolation, ) @@ -618,13 +618,13 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( enable_network_isolation=override_enable_network_isolation, encrypt_inter_container_traffic=override_encrypt_inter_container_traffic, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_estimator_init.call_args[1].get("role"), override_role) - self.assertEquals( + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_estimator_init.call_args[1].get("role"), override_role) + self.assertEqual( mock_estimator_init.call_args[1].get("enable_network_isolation"), override_enable_network_isolation, ) - self.assertEquals( + self.assertEqual( mock_estimator_init.call_args[1].get("encrypt_inter_container_traffic"), override_encrypt_inter_container_traffic, ) @@ -634,11 +634,11 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( enable_network_isolation=override_enable_network_isolation, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 3) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 3) - self.assertEquals(mock_estimator_deploy.call_args[1].get("role"), override_inference_role) + self.assertEqual(mock_estimator_deploy.call_args[1].get("role"), override_inference_role) - self.assertEquals( + self.assertEqual( mock_estimator_deploy.call_args[1].get("enable_network_isolation"), override_enable_network_isolation, ) diff --git a/tests/unit/sagemaker/jumpstart/hub/test_hub.py b/tests/unit/sagemaker/jumpstart/hub/test_hub.py index 8522b33bc3..29efb6b31f 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_hub.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_hub.py @@ -16,7 +16,6 @@ import pytest from mock import Mock from sagemaker.jumpstart.hub.hub import Hub -from sagemaker.jumpstart.hub.types import S3ObjectLocation REGION = "us-east-1" @@ -60,48 +59,34 @@ def test_instantiates(sagemaker_session): @pytest.mark.parametrize( - ("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"), + ("hub_name,hub_description,,hub_display_name,hub_search_keywords,tags"), [ - pytest.param("MockHub1", "this is my sagemaker hub", None, None, None, None), + pytest.param("MockHub1", "this is my sagemaker hub", None, None, None), pytest.param( "MockHub2", "this is my sagemaker hub two", - None, "DisplayMockHub2", ["mock", "hub", "123"], [{"Key": "tag-key-1", "Value": "tag-value-1"}], ), ], ) -@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location") def test_create_with_no_bucket_name( - mock_generate_hub_storage_location, sagemaker_session, hub_name, hub_description, - hub_bucket_name, hub_display_name, hub_search_keywords, tags, ): - storage_location = S3ObjectLocation( - "sagemaker-hubs-us-east-1-123456789123", f"{hub_name}-{FAKE_TIME.timestamp()}" - ) - mock_generate_hub_storage_location.return_value = storage_location create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub) - sagemaker_session.describe_hub.return_value = { - "S3StorageConfig": {"S3OutputPath": f"s3://{hub_bucket_name}/{storage_location.key}"} - } hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session) request = { "hub_name": hub_name, "hub_description": hub_description, "hub_display_name": hub_display_name, "hub_search_keywords": hub_search_keywords, - "s3_storage_config": { - "S3OutputPath": f"s3://sagemaker-hubs-us-east-1-123456789123/{storage_location.key}" - }, "tags": tags, } response = hub.create( @@ -128,9 +113,9 @@ def test_create_with_no_bucket_name( ), ], ) -@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location") +@patch("sagemaker.jumpstart.hub.hub.datetime") def test_create_with_bucket_name( - mock_generate_hub_storage_location, + mock_datetime, sagemaker_session, hub_name, hub_description, @@ -139,8 +124,8 @@ def test_create_with_bucket_name( hub_search_keywords, tags, ): - storage_location = S3ObjectLocation(hub_bucket_name, f"{hub_name}-{FAKE_TIME.timestamp()}") - mock_generate_hub_storage_location.return_value = storage_location + mock_datetime.now.return_value = FAKE_TIME + create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub) hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session, bucket_name=hub_bucket_name) @@ -149,7 +134,9 @@ def test_create_with_bucket_name( "hub_description": hub_description, "hub_display_name": hub_display_name, "hub_search_keywords": hub_search_keywords, - "s3_storage_config": {"S3OutputPath": f"s3://mock-bucket-123/{storage_location.key}"}, + "s3_storage_config": { + "S3OutputPath": f"s3://mock-bucket-123/{hub_name}-{FAKE_TIME.timestamp()}" + }, "tags": tags, } response = hub.create( @@ -192,6 +179,39 @@ def test_describe_model_success(mock_describe_hub_content_response, sagemaker_se ) +@patch("sagemaker.jumpstart.hub.interfaces.DescribeHubContentResponse.from_json") +def test_describe_model_one_thrown_error(mock_describe_hub_content_response, sagemaker_session): + mock_describe_hub_content_response.return_value = Mock() + mock_list_hub_content_versions = sagemaker_session.list_hub_content_versions + mock_list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0"}, + {"HubContentVersion": "2.0"}, + {"HubContentVersion": "3.0"}, + ] + } + mock_describe_hub_content = sagemaker_session.describe_hub_content + mock_describe_hub_content.side_effect = [ + Exception("Some exception"), + {"HubContentName": "test-model", "HubContentVersion": "3.0"}, + ] + + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + + with patch("sagemaker.jumpstart.hub.utils.get_hub_model_version") as mock_get_hub_model_version: + mock_get_hub_model_version.return_value = "3.0" + + hub.describe_model("test-model") + + mock_describe_hub_content.asssert_called_times(2) + mock_describe_hub_content.assert_called_with( + hub_name=HUB_NAME, + hub_content_name="test-model", + hub_content_version="3.0", + hub_content_type="Model", + ) + + def test_create_hub_content_reference(sagemaker_session): hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) model_name = "mock-model-one-huggingface" diff --git a/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py index 11798bc854..ebd90d98d2 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py @@ -923,15 +923,13 @@ def test_hub_content_document_from_json_obj(): "g4dn": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/g4dn/v1.0.0/train-" - "huggingface-llm-gemma-2b-instruct.tar.gz", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/g4dn/v1.0.0/", # noqa: E501 }, }, "g5": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/g5/v1.0.0/train-" - "huggingface-llm-gemma-2b-instruct.tar.gz", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/g5/v1.0.0/", # noqa: E501 }, }, "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, @@ -940,15 +938,13 @@ def test_hub_content_document_from_json_obj(): "p3dn": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/p3dn/v1.0.0/train-" - "huggingface-llm-gemma-2b-instruct.tar.gz", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/p3dn/v1.0.0/", # noqa: E501 }, }, "p4d": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/p4d/v1.0.0/train-" - "huggingface-llm-gemma-2b-instruct.tar.gz", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/p4d/v1.0.0/", # noqa: E501 }, }, "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py index 22bc527b18..5745a7f79c 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -14,7 +14,10 @@ from unittest.mock import patch, Mock from sagemaker.jumpstart.types import HubArnExtractedInfo -from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.constants import ( + JUMPSTART_DEFAULT_REGION_NAME, + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) from sagemaker.jumpstart.hub import parser_utils, utils @@ -80,6 +83,17 @@ def test_construct_hub_arn_from_name(): ) +def test_construct_hub_arn_from_name_with_session_none(): + hub_name = "my-cool-hub" + account_id = DEFAULT_JUMPSTART_SAGEMAKER_SESSION.account_id() + boto_region_name = DEFAULT_JUMPSTART_SAGEMAKER_SESSION.boto_region_name + + assert ( + utils.construct_hub_arn_from_name(hub_name=hub_name, session=None) + == f"arn:aws:sagemaker:{boto_region_name}:{account_id}:hub/{hub_name}" + ) + + def test_construct_hub_model_arn_from_inputs(): model_name, version = "pytorch-ic-imagenet-v2", "1.0.2" hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" @@ -96,6 +110,23 @@ def test_construct_hub_model_arn_from_inputs(): ) +def test_construct_hub_model_reference_arn_from_inputs(): + model_name, version = "pytorch-ic-imagenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" + hub_content_arn_prefix = "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub" + + assert ( + utils.construct_hub_model_reference_arn_from_inputs(hub_arn, model_name, version) + == hub_content_arn_prefix + "/ModelReference/pytorch-ic-imagenet-v2/1.0.2" + ) + + version = "*" + assert ( + utils.construct_hub_model_reference_arn_from_inputs(hub_arn, model_name, version) + == hub_content_arn_prefix + "/ModelReference/pytorch-ic-imagenet-v2/*" + ) + + def test_generate_hub_arn_for_init_kwargs(): hub_name = "my-hub-name" hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" @@ -142,30 +173,6 @@ def test_generate_hub_arn_for_init_kwargs(): assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn -def test_create_hub_bucket_if_it_does_not_exist_hub_arn(): - mock_sagemaker_session = Mock() - mock_sagemaker_session.account_id.return_value = "123456789123" - mock_sagemaker_session.client("sts").get_caller_identity.return_value = { - "Account": "123456789123" - } - hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" - # Mock custom session with custom values - mock_custom_session = Mock() - mock_custom_session.account_id.return_value = "000000000000" - mock_custom_session.boto_region_name = "us-east-2" - mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None - mock_sagemaker_session.boto_region_name = "us-east-1" - - bucket_name = "sagemaker-hubs-us-east-1-123456789123" - created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( - sagemaker_session=mock_sagemaker_session - ) - - mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() - assert created_hub_bucket_name == bucket_name - assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn - - def test_is_gated_bucket(): assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True @@ -176,23 +183,6 @@ def test_is_gated_bucket(): assert utils.is_gated_bucket("") is False -def test_create_hub_bucket_if_it_does_not_exist(): - mock_sagemaker_session = Mock() - mock_sagemaker_session.account_id.return_value = "123456789123" - mock_sagemaker_session.client("sts").get_caller_identity.return_value = { - "Account": "123456789123" - } - mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None - mock_sagemaker_session.boto_region_name = "us-east-1" - bucket_name = "sagemaker-hubs-us-east-1-123456789123" - created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( - sagemaker_session=mock_sagemaker_session - ) - - mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() - assert created_hub_bucket_name == bucket_name - - @patch("sagemaker.session.Session") def test_get_hub_model_version_success(mock_session): hub_name = "test_hub" diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index be961828f4..d9b126f651 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -794,7 +794,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): and reach out to JumpStart team.""" init_args_to_skip: Set[str] = set(["model_reference_arn"]) - deploy_args_to_skip: Set[str] = set(["kwargs", "model_reference_arn"]) + deploy_args_to_skip: Set[str] = set(["kwargs", "model_reference_arn", "update_endpoint"]) deploy_args_removed_at_deploy_time: Set[str] = set(["model_access_configs"]) parent_class_init = Model.__init__ diff --git a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py index 2be4bde7e4..a0299ebb1a 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py @@ -99,9 +99,9 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( model_id=model_id, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_model_init.call_args[1].get("role"), config_role) + self.assertEqual(mock_model_init.call_args[1].get("role"), config_role) assert "enable_network_isolation" not in mock_model_init.call_args[1] @@ -147,10 +147,10 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( role=override_role, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_model_init.call_args[1].get("role"), override_role) - self.assertEquals( + self.assertEqual(mock_model_init.call_args[1].get("role"), override_role) + self.assertEqual( mock_model_init.call_args[1].get("enable_network_isolation"), override_enable_network_isolation, ) @@ -197,10 +197,10 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( model_id=model_id, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 2) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 2) - self.assertEquals(mock_model_init.call_args[1].get("role"), config_role) - self.assertEquals( + self.assertEqual(mock_model_init.call_args[1].get("role"), config_role) + self.assertEqual( mock_model_init.call_args[1].get("enable_network_isolation"), config_enable_network_isolation, ) @@ -249,10 +249,10 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( enable_network_isolation=override_enable_network_isolation, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_model_init.call_args[1].get("role"), override_role) - self.assertEquals( + self.assertEqual(mock_model_init.call_args[1].get("role"), override_role) + self.assertEqual( mock_model_init.call_args[1].get("enable_network_isolation"), override_enable_network_isolation, ) @@ -299,10 +299,10 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( model_id=model_id, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 2) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 2) - self.assertEquals(mock_model_init.call_args[1].get("role"), execution_role) - self.assertEquals( + self.assertEqual(mock_model_init.call_args[1].get("role"), execution_role) + self.assertEqual( mock_model_init.call_args[1].get("enable_network_isolation"), metadata_enable_network_isolation, ) @@ -350,10 +350,10 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( enable_network_isolation=override_enable_network_isolation, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_model_init.call_args[1].get("role"), override_role) - self.assertEquals( + self.assertEqual(mock_model_init.call_args[1].get("role"), override_role) + self.assertEqual( mock_model_init.call_args[1].get("enable_network_isolation"), override_enable_network_isolation, ) @@ -398,9 +398,9 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( model_id=model_id, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_model_init.call_args[1].get("role"), execution_role) + self.assertEqual(mock_model_init.call_args[1].get("role"), execution_role) assert "enable_network_isolation" not in mock_model_init.call_args[1] @mock.patch( @@ -445,10 +445,10 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( enable_network_isolation=override_enable_network_isolation, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_model_init.call_args[1].get("role"), override_role) - self.assertEquals( + self.assertEqual(mock_model_init.call_args[1].get("role"), override_role) + self.assertEqual( mock_model_init.call_args[1].get("enable_network_isolation"), override_enable_network_isolation, ) diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py index e687a1c4ac..75aa93a920 100644 --- a/tests/unit/sagemaker/jumpstart/test_artifacts.py +++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py @@ -176,7 +176,7 @@ def test_retrieve_training_artifact_key(self): "image_uri": "$alias_ecr_uri_1", }, "properties": { - "artifact_key": "in/the/way", + "training_artifact_key": "in/the/way", }, } }, diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index da20debc6a..17996f4f15 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -22,10 +22,14 @@ from mock.mock import MagicMock import pytest from mock import patch +from packaging.version import Version + +from sagemaker.jumpstart import utils from sagemaker.jumpstart.cache import ( JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JumpStartModelsCache, ) from sagemaker.jumpstart.constants import ( @@ -33,6 +37,7 @@ ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, ) from sagemaker.jumpstart.types import ( + JumpStartCachedContentValue, JumpStartModelHeader, JumpStartModelSpecs, JumpStartVersionedModelId, @@ -53,6 +58,25 @@ from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region") +@patch( + "sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket" +) +@patch("boto3.client") +def test_jumpstart_cache_init(mock_boto3_client): + cache = JumpStartModelsCache() + assert cache._region == "dummy-region" + assert cache.s3_bucket_name == "dummy-bucket" + assert cache._manifest_file_s3_key == JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY + assert cache._proprietary_manifest_s3_key == JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY + assert cache._sagemaker_session == DEFAULT_JUMPSTART_SAGEMAKER_SESSION + mock_boto3_client.assert_called_once_with("s3", region_name="dummy-region") + + # Some callers override the session to None, should still be set to default + cache = JumpStartModelsCache(sagemaker_session=None) + assert cache._sagemaker_session == DEFAULT_JUMPSTART_SAGEMAKER_SESSION + + @patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) @patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") def test_jumpstart_cache_get_header(): @@ -160,6 +184,30 @@ def test_jumpstart_cache_get_header(): semantic_version_str="1.0.*", ) + assert JumpStartModelHeader( + { + "model_id": "meta-textgeneration-llama-2-7b", + "version": "4.13.0", + "min_version": "2.49.0", + "spec_key": "community_models/meta-textgeneration-llama-2-7b/specs_v4.13.0.json", + } + ) == cache.get_header( + model_id="meta-textgeneration-llama-2-7b", + semantic_version_str="*", + ) + + assert JumpStartModelHeader( + { + "model_id": "meta-textgeneration-llama-2-7b", + "version": "4.13.0", + "min_version": "2.49.0", + "spec_key": "community_models/meta-textgeneration-llama-2-7b/specs_v4.13.0.json", + } + ) == cache.get_header( + model_id="meta-textgeneration-llama-2-7b", + semantic_version_str="4.*", + ) + assert JumpStartModelHeader( { "model_id": "ai21-summarization", @@ -1119,3 +1167,124 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( ), ] ) + + +@patch.object(JumpStartModelsCache, "_retrieval_function") +def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights( + retrieval_function: Mock, +): + sm_version = Version(utils.get_sagemaker_version()) + new_sm_version = Version(str(sm_version.major + 1) + ".0.0") + print(str(new_sm_version)) + versions = ["1.0.0", "2.9.1", "2.16.0"] + manifest = [ + { + "model_id": "test-model", + "version": version, + "min_version": "2.49.0", + "spec_key": "spec_key", + } + for version in versions + ] + + manifest.append( + { + "model_id": "test-model", + "version": "3.0.0", + "min_version": str(new_sm_version), + "spec_key": "spec_key", + } + ) + + manifest_dict = {} + for header in manifest: + header_obj = JumpStartModelHeader(header) + manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = ( + header_obj + ) + retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict) + key = JumpStartVersionedModelId("test-model", "*") + + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None) + + assert_key = JumpStartVersionedModelId("test-model", "2.16.0") + + assert result == assert_key + + +@patch.object(JumpStartModelsCache, "_retrieval_function") +def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights( + retrieval_function: Mock, +): + sm_version = Version(utils.get_sagemaker_version()) + new_sm_version = Version(str(sm_version.major + 1) + ".0.0") + print(str(new_sm_version)) + versions = ["1.0.0", "2.9.1", "2.16.0"] + manifest = [ + { + "model_id": "test-model", + "version": version, + "min_version": "2.49.0", + "spec_key": "spec_key", + } + for version in versions + ] + + manifest.append( + { + "model_id": "test-model", + "version": "3.0.0", + "min_version": str(new_sm_version), + "spec_key": "spec_key", + } + ) + + manifest_dict = {} + for header in manifest: + header_obj = JumpStartModelHeader(header) + manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = ( + header_obj + ) + retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict) + key = JumpStartVersionedModelId("test-model", "*") + + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + result = cache._get_proprietary_manifest_key_from_model_id(key=key, value=None) + + assert_key = JumpStartVersionedModelId("test-model", "2.16.0") + + assert result == assert_key + + +@patch.object(JumpStartModelsCache, "_retrieval_function") +def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(retrieval_function: Mock): + sm_version = Version(utils.get_sagemaker_version()) + new_sm_version = Version(str(sm_version.major + 1) + ".0.0") + print(str(new_sm_version)) + versions = ["abc", "2.9.1", "2.16.0"] + manifest = [ + { + "model_id": "test-model", + "version": version, + "min_version": "2.49.0", + "spec_key": "spec_key", + } + for version in versions + ] + + manifest_dict = {} + for header in manifest: + header_obj = JumpStartModelHeader(header) + manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = ( + header_obj + ) + retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict) + key = JumpStartVersionedModelId("test-model", "*") + + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None) + + assert_key = JumpStartVersionedModelId("test-model", "abc") + + assert result == assert_key diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 3efa8c8c81..0b5ef63947 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -117,7 +117,7 @@ "g4": { "regional_properties": {"image_uri": "$gpu_image_uri"}, "properties": { - "artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" + "training_artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" }, }, "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, @@ -193,7 +193,7 @@ }, "p9": { "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": {"artifact_key": "do/re/mi"}, + "properties": {"training_artifact_key": "do/re/mi"}, }, "m2": { "regional_properties": {"image_uri": "$cpu_image_uri"}, @@ -272,13 +272,13 @@ "ml.p9.12xlarge": { "properties": { "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, - "artifact_key": "you/not/entertained", + "training_artifact_key": "you/not/entertained", } }, "g6": { "properties": { "environment_variables": {"BLAH": "4"}, - "artifact_key": "path/to/training/artifact.tar.gz", + "training_artifact_key": "path/to/training/artifact.tar.gz", "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/", } }, @@ -378,6 +378,7 @@ def test_jumpstart_model_specs(): specs1.training_script_key == "source-directory-tarballs/pytorch/transfer_learning/ic/v2.3.0/sourcedir.tar.gz" ) + assert specs1.default_training_dataset_key == "training-datasets/tf_flowers/" assert specs1.hyperparameters == [ JumpStartHyperparameter( { @@ -952,27 +953,35 @@ def test_jumpstart_hosting_prepacked_artifact_key_instance_variants(): def test_jumpstart_training_artifact_key_instance_variants(): assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key(instance_type="ml.g6.xlarge") + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( + instance_type="ml.g6.xlarge" + ) == "path/to/training/artifact.tar.gz" ) assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key(instance_type="ml.g4.9xlarge") + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( + instance_type="ml.g4.9xlarge" + ) == "path/to/prepacked/training/artifact/prefix/number2/" ) assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key(instance_type="ml.p9.9xlarge") + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( + instance_type="ml.p9.9xlarge" + ) == "do/re/mi" ) assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key(instance_type="ml.p9.12xlarge") + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( + instance_type="ml.p9.12xlarge" + ) == "you/not/entertained" ) assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key( + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( instance_type="ml.g9dsfsdfs.12xlarge" ) is None diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 7cf8fdc9b6..de9be1d51d 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -13,10 +13,9 @@ from __future__ import absolute_import import os from unittest import TestCase -from unittest.mock import call - +from unittest.mock import call, mock_open, Mock, patch +import json from botocore.exceptions import ClientError -from mock.mock import Mock, patch import pytest import boto3 import random @@ -24,6 +23,7 @@ from sagemaker import session from sagemaker.jumpstart import utils from sagemaker.jumpstart.constants import ( + _load_region_config, DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING, ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE, @@ -38,6 +38,7 @@ JUMPSTART_RESOURCE_BASE_NAME, NEO_DEFAULT_REGION_NAME, JumpStartScriptScope, + JUMPSTART_LAUNCHED_REGIONS, ) from functools import partial from sagemaker.jumpstart.enums import JumpStartTag, MIMEType, JumpStartModelType @@ -49,6 +50,7 @@ JumpStartBenchmarkStat, JumpStartModelHeader, JumpStartVersionedModelId, + JumpStartLaunchedRegionInfo, ) from tests.unit.sagemaker.jumpstart.utils import ( get_base_spec_with_prototype_configs, @@ -1386,7 +1388,7 @@ def test_no_model_id_no_version_found(self): mock_sagemaker_session.list_tags = mock_list_tags mock_list_tags.return_value = [{"Key": "blah", "Value": "blah1"}] - self.assertEquals( + self.assertEqual( utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), (None, None, None, None), ) @@ -1401,7 +1403,7 @@ def test_model_id_no_version_found(self): {"Key": JumpStartTag.MODEL_ID, "Value": "model_id"}, ] - self.assertEquals( + self.assertEqual( utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), ("model_id", None, None, None), ) @@ -1416,7 +1418,7 @@ def test_no_model_id_version_found(self): {"Key": JumpStartTag.MODEL_VERSION, "Value": "model_version"}, ] - self.assertEquals( + self.assertEqual( utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), (None, "model_version", None, None), ) @@ -1428,7 +1430,7 @@ def test_no_config_name_found(self): mock_sagemaker_session.list_tags = mock_list_tags mock_list_tags.return_value = [{"Key": "blah", "Value": "blah1"}] - self.assertEquals( + self.assertEqual( utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), (None, None, None, None), ) @@ -1443,7 +1445,7 @@ def test_inference_config_name_found(self): {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name"}, ] - self.assertEquals( + self.assertEqual( utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), (None, None, "config_name", None), ) @@ -1458,7 +1460,7 @@ def test_training_config_name_found(self): {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "config_name"}, ] - self.assertEquals( + self.assertEqual( utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), (None, None, None, "config_name"), ) @@ -1474,7 +1476,7 @@ def test_both_config_name_found(self): {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "training_config_name"}, ] - self.assertEquals( + self.assertEqual( utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), (None, None, "inference_config_name", "training_config_name"), ) @@ -1490,7 +1492,7 @@ def test_model_id_version_found(self): {"Key": JumpStartTag.MODEL_VERSION, "Value": "model_version"}, ] - self.assertEquals( + self.assertEqual( utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), ("model_id", "model_version", None, None), ) @@ -1508,7 +1510,7 @@ def test_multiple_model_id_versions_found(self): {"Key": JumpStartTag.MODEL_VERSION, "Value": "model_version_2"}, ] - self.assertEquals( + self.assertEqual( utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), (None, None, None, None), ) @@ -1526,7 +1528,7 @@ def test_multiple_model_id_versions_found_aliases_consistent(self): {"Key": random.choice(EXTRA_MODEL_VERSION_TAGS), "Value": "model_version_1"}, ] - self.assertEquals( + self.assertEqual( utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), ("model_id_1", "model_version_1", None, None), ) @@ -1544,7 +1546,7 @@ def test_multiple_model_id_versions_found_aliases_inconsistent(self): {"Key": random.choice(EXTRA_MODEL_VERSION_TAGS), "Value": "model_version_2"}, ] - self.assertEquals( + self.assertEqual( utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), (None, None, None, None), ) @@ -1562,13 +1564,116 @@ def test_multiple_config_names_found_aliases_inconsistent(self): {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name_2"}, ] - self.assertEquals( + self.assertEqual( utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), ("model_id_1", "model_version_1", None, None), ) mock_list_tags.assert_called_once_with("some-arn") +class TestJumpStartLaunchedRegions(TestCase): + def test_regions_not_empty(self): + self.assertTrue(len(JUMPSTART_LAUNCHED_REGIONS) > 0) + + +class TestLoadRegionConfig(TestCase): + def setUp(self): + # Sample valid config that matches the expected structure + self.valid_config = { + "us-east-1": { + "content_bucket": "jumpstart-cache-prod-us-east-1", + "gated_content_bucket": "jumpstart-private-cache-prod-us-east-1", + "neo_content_bucket": "jumpstart-neo-cache-prod-us-east-1", + }, + "us-west-2": { + "content_bucket": "jumpstart-cache-prod-us-west-2", + }, + } + self.config_json = json.dumps(self.valid_config) + + @patch("builtins.open", new_callable=mock_open) + def test_successful_config_load(self, mock_file): + # Setup mock to return valid config + mock_file.return_value.__enter__().read.return_value = self.config_json + + result = _load_region_config("dummy/path") + + # Verify the returned dictionary contains JumpStartLaunchedRegionInfo objects + self.assertTrue(all(isinstance(region, JumpStartLaunchedRegionInfo) for region in result)) + + for region in result: + if region.region_name == "us-east-1": + self.assertEqual(region.region_name, "us-east-1") + self.assertEqual(region.content_bucket, "jumpstart-cache-prod-us-east-1") + self.assertEqual( + region.gated_content_bucket, "jumpstart-private-cache-prod-us-east-1" + ) + self.assertEqual(region.neo_content_bucket, "jumpstart-neo-cache-prod-us-east-1") + + elif region.region_name == "us-west-2": + self.assertEqual(region.region_name, "us-west-2") + self.assertEqual(region.content_bucket, "jumpstart-cache-prod-us-west-2") + self.assertIsNone(region.gated_content_bucket) + self.assertIsNone(region.neo_content_bucket) + else: + raise AssertionError(f"Unexpected region name found: {region.region_name}") + + @patch("builtins.open", new_callable=mock_open) + def test_missing_required_field(self, mock_file): + # Config missing required content_bucket field + invalid_config = { + "us-east-1": { + "gated_content_bucket": "XXXXXXXXXXX", + "neo_content_bucket": "some-other-bucket", + } + } + mock_file.return_value.__enter__().read.return_value = json.dumps(invalid_config) + + # Should return empty dict due to exception handling + result = _load_region_config("dummy/path") + self.assertEqual(result, set()) + + @patch("builtins.open") + def test_file_not_found(self, mock_file): + # Simulate file not found + mock_file.side_effect = FileNotFoundError() + + # Should return empty dict due to exception handling + result = _load_region_config("dummy/path") + self.assertEqual(result, set()) + + @patch("builtins.open", new_callable=mock_open) + def test_invalid_json(self, mock_file): + # Setup mock to return invalid JSON + mock_file.return_value.__enter__().read.return_value = "invalid json content" + + # Should return empty dict due to exception handling + result = _load_region_config("dummy/path") + self.assertEqual(result, set()) + + @patch("builtins.open", new_callable=mock_open) + def test_empty_config(self, mock_file): + # Setup mock to return empty JSON object + mock_file.return_value.__enter__().read.return_value = "{}" + + result = _load_region_config("dummy/path") + self.assertEqual(result, set()) + + @patch("sagemaker.jumpstart.constants.JUMPSTART_LOGGER") + @patch("builtins.open") + def test_logging_on_error(self, mock_file, mock_logger): + + # Simulate an error + mock_file.side_effect = Exception("Test error") + + result = _load_region_config("dummy/path") + + self.assertEqual(result, set()) + + # Verify error was logged + mock_logger.error.assert_called_once() + + class TestJumpStartLogger(TestCase): @patch.dict("os.environ", {}) @patch("logging.StreamHandler.emit") @@ -2144,6 +2249,22 @@ def test_has_instance_rate_stat(stats, expected): assert utils.has_instance_rate_stat(stats) is expected +def test_get_latest_version(): + assert utils.get_latest_version(["2.9.1", "2.16.0", "1.0.0"]) == "2.16.0" + + +def test_get_latest_version_empty_list_is_none(): + assert utils.get_latest_version([]) is None + + +def test_get_latest_version_none_is_none(): + assert utils.get_latest_version(None) is None + + +def test_get_latest_version_with_invalid_sem_ver(): + assert utils.get_latest_version(["2.9.1", "2.16.0", "1.0.0", "abc"]) == "abc" + + @pytest.mark.parametrize( "data, expected", [(None, []), ([], []), (get_base_deployment_configs_metadata(), get_base_deployment_configs())], diff --git a/tests/unit/sagemaker/local/test_local_entities.py b/tests/unit/sagemaker/local/test_local_entities.py index 6a026c316b..74a361cf73 100644 --- a/tests/unit/sagemaker/local/test_local_entities.py +++ b/tests/unit/sagemaker/local/test_local_entities.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import re import os import pytest @@ -290,10 +291,10 @@ def test_start_local_pipeline_with_wrong_parameter_type(sagemaker_local_session) local_pipeline = sagemaker.local.entities._LocalPipeline(pipeline) with pytest.raises(ClientError) as error: local_pipeline.start(PipelineParameters={"MyStr": True}) - assert ( - f"Unexpected type for parameter '{parameter.name}'. Expected " - f"{parameter.parameter_type.python_type} but found {type(True)}." in str(error.value) + expected_error_pattern = ( + r"Unexpected type for parameter 'MyStr'\. Expected .* but found \." ) + assert re.search(expected_error_pattern, str(error.value)) def test_start_local_pipeline_with_empty_parameter_string_value( diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 6bfb28f684..4167ca62c3 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -23,6 +23,7 @@ from sagemaker.serverless import ServerlessInferenceConfig from sagemaker.explainer import ExplainerConfig from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements +from sagemaker.enums import EndpointType from tests.unit.sagemaker.inference_recommender.constants import ( DESCRIBE_COMPILATION_JOB_RESPONSE, DESCRIBE_MODEL_PACKAGE_RESPONSE, @@ -130,6 +131,7 @@ def test_deploy(name_from_base, prepare_container_def, production_variant, sagem model_data_download_timeout=None, container_startup_health_check_timeout=None, routing_config=None, + inference_ami_version=None, ) sagemaker_session.create_model.assert_called_with( @@ -192,6 +194,7 @@ def test_deploy_accelerator_type( model_data_download_timeout=None, container_startup_health_check_timeout=None, routing_config=None, + inference_ami_version=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( @@ -519,6 +522,7 @@ def test_deploy_serverless_inference(production_variant, create_sagemaker_model, model_data_download_timeout=None, container_startup_health_check_timeout=None, routing_config=None, + inference_ami_version=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( @@ -956,6 +960,7 @@ def test_deploy_customized_volume_size_and_timeout( model_data_download_timeout=model_data_download_timeout_sec, container_startup_health_check_timeout=startup_health_check_timeout_sec, routing_config=None, + inference_ami_version=None, ) sagemaker_session.create_model.assert_called_with( @@ -1006,6 +1011,7 @@ def test_deploy_with_resources(sagemaker_session, name_from_base, production_var model_data_download_timeout=None, container_startup_health_check_timeout=None, routing_config=None, + inference_ami_version=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( name=name_from_base(MODEL_NAME), @@ -1046,3 +1052,143 @@ def test_deploy_with_name_and_resources(sagemaker_session): async_inference_config_dict=None, live_logging=False, ) + + +@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME) +@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) +def test_deploy_with_update_endpoint(production_variant, name_from_base, sagemaker_session): + model = Model( + MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session + ) + + # Mock the create_endpoint_config to return a specific config name + endpoint_config_name = "test-config-name" + sagemaker_session.create_endpoint_config.return_value = endpoint_config_name + + # Test update_endpoint=True scenario + endpoint_name = "existing-endpoint" + model.deploy( + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, + endpoint_name=endpoint_name, + update_endpoint=True, + ) + + # Verify create_endpoint_config is called with correct parameters + sagemaker_session.create_endpoint_config.assert_called_with( + name=MODEL_NAME, + model_name=MODEL_NAME, + initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + accelerator_type=None, + tags=None, + kms_key=None, + data_capture_config_dict=None, + volume_size=None, + model_data_download_timeout=None, + container_startup_health_check_timeout=None, + explainer_config_dict=None, + async_inference_config_dict=None, + serverless_inference_config=None, + routing_config=None, + inference_ami_version=None, + ) + + # Verify update_endpoint is called with correct parameters + sagemaker_session.update_endpoint.assert_called_with(endpoint_name, endpoint_config_name) + + # Test update_endpoint with serverless config + serverless_inference_config = ServerlessInferenceConfig() + serverless_inference_config_dict = { + "MemorySizeInMB": 2048, + "MaxConcurrency": 5, + } + model.deploy( + endpoint_name=endpoint_name, + update_endpoint=True, + serverless_inference_config=serverless_inference_config, + ) + + sagemaker_session.create_endpoint_config.assert_called_with( + name=MODEL_NAME, + model_name=MODEL_NAME, + initial_instance_count=None, + instance_type=None, + accelerator_type=None, + tags=None, + kms_key=None, + data_capture_config_dict=None, + volume_size=None, + model_data_download_timeout=None, + container_startup_health_check_timeout=None, + explainer_config_dict=None, + async_inference_config_dict=None, + serverless_inference_config=serverless_inference_config_dict, + routing_config=None, + inference_ami_version=None, + ) + + # Verify update_endpoint is called with the new config + sagemaker_session.update_endpoint.assert_called_with(endpoint_name, endpoint_config_name) + + # Test update_endpoint with async inference config + async_inference_config = AsyncInferenceConfig( + output_path="s3://bucket/output", failure_path="s3://bucket/failure" + ) + async_inference_config_dict = { + "OutputConfig": { + "S3OutputPath": "s3://bucket/output", + "S3FailurePath": "s3://bucket/failure", + }, + } + model.deploy( + endpoint_name=endpoint_name, + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, + update_endpoint=True, + async_inference_config=async_inference_config, + ) + + sagemaker_session.create_endpoint_config.assert_called_with( + name=MODEL_NAME, + model_name=MODEL_NAME, + initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + accelerator_type=None, + tags=None, + kms_key=None, + data_capture_config_dict=None, + volume_size=None, + model_data_download_timeout=None, + container_startup_health_check_timeout=None, + explainer_config_dict=None, + async_inference_config_dict=async_inference_config_dict, + serverless_inference_config=None, + routing_config=None, + inference_ami_version=None, + ) + + # Verify update_endpoint is called with the new config + sagemaker_session.update_endpoint.assert_called_with(endpoint_name, endpoint_config_name) + + +@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) +def test_deploy_with_update_endpoint_inference_component(production_variant, sagemaker_session): + model = Model( + MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session + ) + + # Test that updating endpoint with inference component raises error + with pytest.raises( + ValueError, match="Currently update_endpoint is supported for single model endpoints" + ): + model.deploy( + endpoint_name="test-endpoint", + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, + update_endpoint=True, + resources=RESOURCES, + endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, + ) diff --git a/tests/unit/sagemaker/model/test_framework_model.py b/tests/unit/sagemaker/model/test_framework_model.py index d41dd6f821..432d90bd37 100644 --- a/tests/unit/sagemaker/model/test_framework_model.py +++ b/tests/unit/sagemaker/model/test_framework_model.py @@ -511,6 +511,20 @@ def test_is_repack_with_code_location(repack_model, sagemaker_session): assert not model.is_repack() +@patch("sagemaker.utils.repack_model") +def test_is_repack_with_none_type(repack_model, sagemaker_session): + """Test is_repack() returns a boolean value when source_dir and entry_point are None""" + + model = FrameworkModel( + role=ROLE, + sagemaker_session=sagemaker_session, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + + assert model.is_repack() is False + + @patch("sagemaker.git_utils.git_clone_repo") @patch("sagemaker.model.fw_utils.tar_and_upload_dir") def test_is_repack_with_git_config(tar_and_upload_dir, git_clone_repo, sagemaker_session): diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index 9175613662..3d498dfc59 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -1046,6 +1046,20 @@ def test_is_repack_with_code_location(repack_model, sagemaker_session): assert model.is_repack() +@patch("sagemaker.utils.repack_model") +def test_is_repack_with_none_type(repack_model, sagemaker_session): + """Test is_repack() returns a boolean value when source_dir and entry_point are None""" + + model = Model( + role=ROLE, + sagemaker_session=sagemaker_session, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + + assert model.is_repack() is False + + @patch("sagemaker.git_utils.git_clone_repo") @patch("sagemaker.model.fw_utils.tar_and_upload_dir") def test_is_repack_with_git_config(tar_and_upload_dir, git_clone_repo, sagemaker_session): diff --git a/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py b/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py index 30d6dfdf6c..fe4fa08825 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py @@ -21,12 +21,10 @@ from sagemaker.modules.train.container_drivers.scripts.environment import ( set_env, - log_key_value, log_env_variables, - mask_sensitive_info, HIDDEN_VALUE, ) -from sagemaker.modules.train.container_drivers.utils import safe_serialize, safe_deserialize +from sagemaker.modules.train.container_drivers.common.utils import safe_serialize, safe_deserialize RESOURCE_CONFIG = dict( current_host="algo-1", @@ -75,6 +73,15 @@ }, } +SOURCE_CODE = { + "source_dir": "code", + "entry_script": "train.py", +} + +DISTRIBUTED_CONFIG = { + "process_count_per_node": 2, +} + OUTPUT_FILE = os.path.join(os.path.dirname(__file__), "sm_training.env") # flake8: noqa @@ -89,6 +96,10 @@ export SM_LOG_LEVEL='20' export SM_MASTER_ADDR='algo-1' export SM_MASTER_PORT='7777' +export SM_SOURCE_DIR='/opt/ml/input/data/code' +export SM_ENTRY_SCRIPT='train.py' +export SM_DISTRIBUTED_DRIVER_DIR='/opt/ml/input/data/sm_drivers/distributed_drivers' +export SM_DISTRIBUTED_CONFIG='{"process_count_per_node": 2}' export SM_CHANNEL_TRAIN='/opt/ml/input/data/train' export SM_CHANNEL_VALIDATION='/opt/ml/input/data/validation' export SM_CHANNELS='["train", "validation"]' @@ -112,6 +123,14 @@ """ +@patch( + "sagemaker.modules.train.container_drivers.scripts.environment.read_source_code_json", + return_value=SOURCE_CODE, +) +@patch( + "sagemaker.modules.train.container_drivers.scripts.environment.read_distributed_json", + return_value=DISTRIBUTED_CONFIG, +) @patch("sagemaker.modules.train.container_drivers.scripts.environment.num_cpus", return_value=8) @patch("sagemaker.modules.train.container_drivers.scripts.environment.num_gpus", return_value=0) @patch("sagemaker.modules.train.container_drivers.scripts.environment.num_neurons", return_value=0) @@ -124,7 +143,13 @@ side_effect=safe_deserialize, ) def test_set_env( - mock_safe_deserialize, mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons + mock_safe_deserialize, + mock_safe_serialize, + mock_num_neurons, + mock_num_gpus, + mock_num_cpus, + mock_read_distributed_json, + mock_read_source_code_json, ): with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): set_env( @@ -137,6 +162,8 @@ def test_set_env( mock_num_cpus.assert_called_once() mock_num_gpus.assert_called_once() mock_num_neurons.assert_called_once() + mock_read_distributed_json.assert_called_once() + mock_read_source_code_json.assert_called_once() with open(OUTPUT_FILE, "r") as f: env_file = f.read().strip() diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py index a1a84da1ab..bf51db8285 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py @@ -15,13 +15,14 @@ import os import sys +import json from unittest.mock import patch, MagicMock sys.modules["utils"] = MagicMock() sys.modules["mpi_utils"] = MagicMock() -from sagemaker.modules.train.container_drivers import mpi_driver # noqa: E402 +from sagemaker.modules.train.container_drivers.distributed_drivers import mpi_driver # noqa: E402 DUMMY_MPI_COMMAND = [ @@ -40,12 +41,7 @@ "script.py", ] -DUMMY_SOURCE_CODE = { - "source_code": "source_code", - "entry_script": "script.py", -} DUMMY_DISTRIBUTED = { - "_type": "mpi", "process_count_per_node": 2, "mpi_additional_options": [ "--verbose", @@ -62,17 +58,28 @@ "SM_HOSTS": '["algo-1", "algo-2"]', "SM_MASTER_ADDR": "algo-1", "SM_HOST_COUNT": "2", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "/opt/ml/input/data/code/script.py", }, ) -@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_distributed_json") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_source_code_json") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.write_env_vars_to_file") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.start_sshd_daemon") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_master_node") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_worker_node") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.hyperparameters_to_cli_args") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.get_mpirun_command") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.execute_commands") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file" +) +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.start_sshd_daemon") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command" +) +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.execute_commands") def test_mpi_driver_worker( mock_execute_commands, mock_get_mpirun_command, @@ -81,12 +88,8 @@ def test_mpi_driver_worker( mock_bootstrap_master_node, mock_start_sshd_daemon, mock_write_env_vars_to_file, - mock_read_source_code_json, - mock_read_distributed_json, ): mock_hyperparameters_to_cli_args.return_value = [] - mock_read_source_code_json.return_value = DUMMY_SOURCE_CODE - mock_read_distributed_json.return_value = DUMMY_DISTRIBUTED mpi_driver.main() @@ -106,19 +109,32 @@ def test_mpi_driver_worker( "SM_HOSTS": '["algo-1", "algo-2"]', "SM_MASTER_ADDR": "algo-1", "SM_HOST_COUNT": "2", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "script.py", }, ) -@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_distributed_json") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_source_code_json") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.write_env_vars_to_file") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.start_sshd_daemon") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_master_node") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_worker_node") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.get_process_count") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.hyperparameters_to_cli_args") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.get_mpirun_command") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.execute_commands") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.write_status_file_to_workers") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file" +) +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.start_sshd_daemon") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node" +) +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_process_count") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command" +) +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.execute_commands") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_status_file_to_workers" +) def test_mpi_driver_master( mock_write_status_file_to_workers, mock_execute_commands, @@ -129,12 +145,8 @@ def test_mpi_driver_master( mock_bootstrap_master_node, mock_start_sshd_daemon, mock_write_env_vars_to_file, - mock_read_source_code_config_json, - mock_read_distributed_json, ): mock_hyperparameters_to_cli_args.return_value = [] - mock_read_source_code_config_json.return_value = DUMMY_SOURCE_CODE - mock_read_distributed_json.return_value = DUMMY_DISTRIBUTED mock_get_mpirun_command.return_value = DUMMY_MPI_COMMAND mock_get_process_count.return_value = 2 mock_execute_commands.return_value = (0, "") diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py new file mode 100644 index 0000000000..35208d708a --- /dev/null +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py @@ -0,0 +1,113 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""MPI Utils Unit Tests.""" +from __future__ import absolute_import + +import subprocess +from unittest.mock import Mock, patch + +import paramiko +import pytest + +# Mock the utils module before importing mpi_utils +mock_utils = Mock() +mock_utils.logger = Mock() +mock_utils.SM_EFA_NCCL_INSTANCES = [] +mock_utils.SM_EFA_RDMA_INSTANCES = [] +mock_utils.get_python_executable = Mock(return_value="/usr/bin/python") + +with patch.dict("sys.modules", {"utils": mock_utils}): + from sagemaker.modules.train.container_drivers.distributed_drivers.mpi_utils import ( + CustomHostKeyPolicy, + _can_connect, + write_status_file_to_workers, + ) + +TEST_HOST = "algo-1" +TEST_WORKER = "algo-2" +TEST_STATUS_FILE = "/tmp/test-status" + + +def test_custom_host_key_policy_valid_hostname(): + """Test CustomHostKeyPolicy accepts algo- prefixed hostnames.""" + policy = CustomHostKeyPolicy() + mock_client = Mock() + mock_key = Mock() + mock_key.get_name.return_value = "ssh-rsa" + + policy.missing_host_key(mock_client, "algo-1", mock_key) + + mock_client.get_host_keys.assert_called_once() + mock_client.get_host_keys().add.assert_called_once_with("algo-1", "ssh-rsa", mock_key) + + +def test_custom_host_key_policy_invalid_hostname(): + """Test CustomHostKeyPolicy rejects non-algo prefixed hostnames.""" + policy = CustomHostKeyPolicy() + mock_client = Mock() + mock_key = Mock() + + with pytest.raises(paramiko.SSHException) as exc_info: + policy.missing_host_key(mock_client, "invalid-1", mock_key) + + assert "Unknown host key for invalid-1" in str(exc_info.value) + mock_client.get_host_keys.assert_not_called() + + +@patch("paramiko.SSHClient") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_utils.logger") +def test_can_connect_success(mock_logger, mock_ssh_client): + """Test successful SSH connection.""" + mock_client = Mock() + mock_ssh_client.return_value.__enter__.return_value = mock_client + mock_client.connect.return_value = None # Successful connection + + result = _can_connect(TEST_HOST) + + assert result is True + mock_client.load_system_host_keys.assert_called_once() + mock_client.set_missing_host_key_policy.assert_called_once() + mock_client.connect.assert_called_once_with(TEST_HOST, port=22) + + +@patch("paramiko.SSHClient") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_utils.logger") +def test_can_connect_failure(mock_logger, mock_ssh_client): + """Test SSH connection failure.""" + mock_client = Mock() + mock_ssh_client.return_value.__enter__.return_value = mock_client + mock_client.connect.side_effect = paramiko.SSHException("Connection failed") + + result = _can_connect(TEST_HOST) + + assert result is False + mock_client.load_system_host_keys.assert_called_once() + mock_client.set_missing_host_key_policy.assert_called_once() + mock_client.connect.assert_called_once_with(TEST_HOST, port=22) + + +@patch("subprocess.run") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_utils.logger") +def test_write_status_file_to_workers_failure(mock_logger, mock_run): + """Test failed status file writing to workers with retry timeout.""" + mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") + + with pytest.raises(TimeoutError) as exc_info: + write_status_file_to_workers([TEST_WORKER], TEST_STATUS_FILE) + + assert f"Timed out waiting for {TEST_WORKER}" in str(exc_info.value) + assert mock_run.call_count > 1 # Verifies that retries occurred + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py b/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py index 4cff07a0c0..2568346158 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py @@ -15,38 +15,38 @@ import os import sys +import json from unittest.mock import patch, MagicMock sys.modules["utils"] = MagicMock() -from sagemaker.modules.train.container_drivers import torchrun_driver # noqa: E402 - -DUMMY_SOURCE_CODE = { - "source_code": "source_code", - "entry_script": "script.py", -} +from sagemaker.modules.train.container_drivers.distributed_drivers import ( # noqa: E402 + torchrun_driver, +) -DUMMY_distributed = {"_type": "torchrun", "process_count_per_node": 2} +DUMMY_DISTRIBUTED = {"process_count_per_node": 2} @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_python_executable", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_python_executable", return_value="python3", ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0) + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.pytorch_version", + return_value=(2, 0), ) def test_get_base_pytorch_command_torchrun(mock_pytorch_version, mock_get_python_executable): assert torchrun_driver.get_base_pytorch_command() == ["torchrun"] @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_python_executable", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_python_executable", return_value="python3", ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(1, 8) + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.pytorch_version", + return_value=(1, 8), ) def test_get_base_pytorch_command_torch_distributed_launch( mock_pytorch_version, mock_get_python_executable @@ -62,38 +62,29 @@ def test_get_base_pytorch_command_torch_distributed_launch( "SM_CURRENT_INSTANCE_TYPE": "ml.p4d.24xlarge", "SM_NETWORK_INTERFACE_NAME": "eth0", "SM_HOST_COUNT": "1", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "script.py", }, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.USER_CODE_PATH", - "/opt/ml/input/data/code", -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_process_count", return_value=2 + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_process_count", + return_value=2, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0) + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.pytorch_version", + return_value=(2, 0), ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_base_pytorch_command", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_base_pytorch_command", return_value=["torchrun"], ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_source_code_json", - return_value=DUMMY_SOURCE_CODE, -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_distributed_json", - return_value=DUMMY_distributed, -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.hyperparameters_to_cli_args", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.hyperparameters_to_cli_args", return_value=[], ) def test_create_commands_single_node( mock_hyperparameters_to_cli_args, - mock_read_distributed_json, - mock_read_source_code_json, mock_get_base_pytorch_command, mock_pytorch_version, mock_get_process_count, @@ -102,7 +93,7 @@ def test_create_commands_single_node( "torchrun", "--nnodes=1", "--nproc_per_node=2", - "/opt/ml/input/data/code/script.py", + "script.py", ] command = torchrun_driver.create_commands() @@ -118,38 +109,29 @@ def test_create_commands_single_node( "SM_MASTER_ADDR": "algo-1", "SM_MASTER_PORT": "7777", "SM_CURRENT_HOST_RANK": "0", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "script.py", }, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.USER_CODE_PATH", - "/opt/ml/input/data/code", -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_process_count", return_value=2 + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_process_count", + return_value=2, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0) + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.pytorch_version", + return_value=(2, 0), ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_base_pytorch_command", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_base_pytorch_command", return_value=["torchrun"], ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_source_code_json", - return_value=DUMMY_SOURCE_CODE, -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_distributed_json", - return_value=DUMMY_distributed, -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.hyperparameters_to_cli_args", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.hyperparameters_to_cli_args", return_value=[], ) def test_create_commands_multi_node( mock_hyperparameters_to_cli_args, - mock_read_distributed_json, - mock_read_source_code_json, mock_get_base_pytorch_command, mock_pytorch_version, mock_get_process_count, @@ -161,7 +143,7 @@ def test_create_commands_multi_node( "--master_addr=algo-1", "--master_port=7777", "--node_rank=0", - "/opt/ml/input/data/code/script.py", + "script.py", ] command = torchrun_driver.create_commands() diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py b/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py index aba97996b0..beff06e8d8 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py @@ -12,11 +12,13 @@ # language governing permissions and limitations under the License. """Container Utils Unit Tests.""" from __future__ import absolute_import +import os -from sagemaker.modules.train.container_drivers.utils import ( +from sagemaker.modules.train.container_drivers.common.utils import ( safe_deserialize, safe_serialize, hyperparameters_to_cli_args, + get_process_count, ) SM_HPS = { @@ -119,3 +121,18 @@ def test_safe_serialize_empty_data(): assert safe_serialize("") == "" assert safe_serialize([]) == "[]" assert safe_serialize({}) == "{}" + + +def test_get_process_count(): + assert get_process_count() == 1 + assert get_process_count(2) == 2 + os.environ["SM_NUM_GPUS"] = "4" + assert get_process_count() == 4 + os.environ["SM_NUM_GPUS"] = "0" + os.environ["SM_NUM_NEURONS"] = "8" + assert get_process_count() == 8 + os.environ["SM_NUM_NEURONS"] = "0" + assert get_process_count() == 1 + del os.environ["SM_NUM_GPUS"] + del os.environ["SM_NUM_NEURONS"] + assert get_process_count() == 1 diff --git a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py index 66eafab4f0..f5f7ceb083 100644 --- a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py +++ b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py @@ -26,6 +26,7 @@ _load_recipes_cfg, _configure_gpu_args, _configure_trainium_args, + _get_trainining_recipe_gpu_model_name_and_script, ) from sagemaker.modules.utils import _run_clone_command_silent from sagemaker.modules.configs import Compute @@ -178,3 +179,37 @@ def test_get_args_from_recipe_compute( assert mock_gpu_args.call_count == 0 assert mock_trainium_args.call_count == 0 assert args is None + + @pytest.mark.parametrize( + "test_case", + [ + { + "model_type": "llama_v3", + "script": "llama_pretrain.py", + "model_base_name": "llama_v3", + }, + { + "model_type": "mistral", + "script": "mistral_pretrain.py", + "model_base_name": "mistral", + }, + { + "model_type": "deepseek_llamav3", + "script": "deepseek_pretrain.py", + "model_base_name": "deepseek", + }, + { + "model_type": "deepseek_qwenv2", + "script": "deepseek_pretrain.py", + "model_base_name": "deepseek", + }, + ], + ) + def test_get_trainining_recipe_gpu_model_name_and_script(test_case): + model_type = test_case["model_type"] + script = test_case["script"] + model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script( + model_type, script + ) + assert model_base_name == test_case["model_base_name"] + assert script == test_case["script"] diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 093da20ab8..13530a3983 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -17,8 +17,10 @@ import tempfile import json import os +import yaml import pytest -from unittest.mock import patch, MagicMock, ANY +from pydantic import ValidationError +from unittest.mock import patch, MagicMock, ANY, mock_open from sagemaker import image_uris from sagemaker_core.main.resources import TrainingJob @@ -65,7 +67,7 @@ ) from sagemaker.modules.distributed import Torchrun, SMP, MPI from sagemaker.modules.train.sm_recipes.utils import _load_recipes_cfg -from sagemaker.modules.templates import EXEUCTE_TORCHRUN_DRIVER, EXECUTE_MPI_DRIVER +from sagemaker.modules.templates import EXEUCTE_DISTRIBUTED_DRIVER from tests.unit import DATA_DIR DEFAULT_BASE_NAME = "dummy-image-job" @@ -410,7 +412,9 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_ { "source_code": DEFAULT_SOURCE_CODE, "distributed": Torchrun(), - "expected_template": EXEUCTE_TORCHRUN_DRIVER, + "expected_template": EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name="Torchrun", driver_script="torchrun_driver.py" + ), "expected_hyperparameters": {}, }, { @@ -423,7 +427,9 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_ tensor_parallel_degree=5, ) ), - "expected_template": EXEUCTE_TORCHRUN_DRIVER, + "expected_template": EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name="Torchrun", driver_script="torchrun_driver.py" + ), "expected_hyperparameters": { "mp_parameters": json.dumps( { @@ -438,9 +444,11 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_ { "source_code": DEFAULT_SOURCE_CODE, "distributed": MPI( - custom_mpi_options=["-x", "VAR1", "-x", "VAR2"], + mpi_additional_options=["-x", "VAR1", "-x", "VAR2"], + ), + "expected_template": EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name="MPI", driver_script="mpi_driver.py" ), - "expected_template": EXECUTE_MPI_DRIVER, "expected_hyperparameters": {}, }, ], @@ -497,21 +505,15 @@ def test_train_with_distributed_config( assert os.path.exists(expected_runner_json_path) with open(expected_runner_json_path, "r") as f: runner_json_content = f.read() - assert test_case["distributed"].model_dump(exclude_none=True) == ( - json.loads(runner_json_content) - ) + assert test_case["distributed"].model_dump() == (json.loads(runner_json_content)) assert os.path.exists(expected_source_code_json_path) with open(expected_source_code_json_path, "r") as f: source_code_json_content = f.read() - assert test_case["source_code"].model_dump(exclude_none=True) == ( - json.loads(source_code_json_content) - ) + assert test_case["source_code"].model_dump() == (json.loads(source_code_json_content)) assert os.path.exists(expected_source_code_json_path) with open(expected_source_code_json_path, "r") as f: source_code_json_content = f.read() - assert test_case["source_code"].model_dump(exclude_none=True) == ( - json.loads(source_code_json_content) - ) + assert test_case["source_code"].model_dump() == (json.loads(source_code_json_content)) finally: shutil.rmtree(tmp_dir.name) assert not os.path.exists(tmp_dir.name) @@ -1047,15 +1049,139 @@ def mock_upload_data(path, bucket, key_prefix): model_trainer.train() - assert mock_local_container.train.called_once_with( + mock_local_container.assert_called_once_with( training_job_name=unique_name, instance_type=compute.instance_type, instance_count=compute.instance_count, image=training_image, container_root=local_container_root, sagemaker_session=modules_session, - container_entry_point=DEFAULT_ENTRYPOINT, + container_entrypoint=DEFAULT_ENTRYPOINT, container_arguments=DEFAULT_ARGUMENTS, + input_data_config=ANY, hyper_parameters=hyperparameters, environment=environment, ) + + +def test_safe_configs(): + # Test extra fails + with pytest.raises(ValueError): + SourceCode(entry_point="train.py") + # Test invalid type fails + with pytest.raises(ValueError): + SourceCode(entry_script=1) + + +@patch("sagemaker.modules.train.model_trainer.TemporaryDirectory") +def test_destructor_cleanup(mock_tmp_dir, modules_session): + + with pytest.raises(ValidationError): + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute="test", + ) + mock_tmp_dir.cleanup.assert_not_called() + + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + ) + model_trainer._temp_recipe_train_dir = mock_tmp_dir + mock_tmp_dir.assert_not_called() + del model_trainer + mock_tmp_dir.cleanup.assert_called_once() + + +@patch("os.path.exists") +def test_hyperparameters_valid_json(mock_exists, modules_session): + mock_exists.return_value = True + expected_hyperparameters = {"param1": "value1", "param2": 2} + mock_file_open = mock_open(read_data=json.dumps(expected_hyperparameters)) + + with patch("builtins.open", mock_file_open): + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.json", + ) + assert model_trainer.hyperparameters == expected_hyperparameters + mock_file_open.assert_called_once_with("hyperparameters.json", "r") + mock_exists.assert_called_once_with("hyperparameters.json") + + +@patch("os.path.exists") +def test_hyperparameters_valid_yaml(mock_exists, modules_session): + mock_exists.return_value = True + expected_hyperparameters = {"param1": "value1", "param2": 2} + mock_file_open = mock_open(read_data=yaml.dump(expected_hyperparameters)) + + with patch("builtins.open", mock_file_open): + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.yaml", + ) + assert model_trainer.hyperparameters == expected_hyperparameters + mock_file_open.assert_called_once_with("hyperparameters.yaml", "r") + mock_exists.assert_called_once_with("hyperparameters.yaml") + + +def test_hyperparameters_not_exist(modules_session): + with pytest.raises(ValueError): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="nonexistent.json", + ) + + +@patch("os.path.exists") +def test_hyperparameters_invalid(mock_exists, modules_session): + mock_exists.return_value = True + + # YAML contents must be a valid mapping + mock_file_open = mock_open(read_data="- item1\n- item2") + with patch("builtins.open", mock_file_open): + with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.yaml", + ) + + # YAML contents must be a valid mapping + mock_file_open = mock_open(read_data="invalid") + with patch("builtins.open", mock_file_open): + with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.yaml", + ) + + # Must be valid YAML + mock_file_open = mock_open(read_data="* invalid") + with patch("builtins.open", mock_file_open): + with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.yaml", + ) diff --git a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py index 53119e532a..026e1a2d54 100644 --- a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py +++ b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py @@ -568,11 +568,12 @@ def test_clarify_model_monitor(): # The subclass should has monitoring_type() defined # noinspection PyAbstractClass - class DummyClarifyModelMonitoir(ClarifyModelMonitor): + class DummyClarifyModelMonitor(ClarifyModelMonitor): + _TEST_CLASS = True pass with pytest.raises(TypeError): - DummyClarifyModelMonitoir.monitoring_type() + DummyClarifyModelMonitor.monitoring_type() def test_clarify_model_monitor_invalid_update(clarify_model_monitors): @@ -593,6 +594,8 @@ def test_clarify_model_monitor_invalid_attach(sagemaker_session): ) # attach, invalid monitoring type for clarify_model_monitor_cls in ClarifyModelMonitor.__subclasses__(): + if hasattr(clarify_model_monitor_cls, "_TEST_CLASS"): + continue with pytest.raises(TypeError): clarify_model_monitor_cls.attach(SCHEDULE_NAME, sagemaker_session) diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py b/tests/unit/sagemaker/monitor/test_model_monitoring.py index d31b9f8527..b338885491 100644 --- a/tests/unit/sagemaker/monitor/test_model_monitoring.py +++ b/tests/unit/sagemaker/monitor/test_model_monitoring.py @@ -73,6 +73,7 @@ LINFINITY_METHOD = "LInfinity" CRON_DAILY = CronExpressionGenerator.daily() +CRON_NOW = CronExpressionGenerator.now() BASELINING_JOB_NAME = "baselining-job" BASELINE_DATASET_PATH = "/my/local/path/baseline.csv" PREPROCESSOR_PATH = "/my/local/path/preprocessor.py" @@ -1136,6 +1137,36 @@ def _test_data_quality_monitor_update_schedule(data_quality_monitor, sagemaker_s sagemaker_session.sagemaker_client.delete_data_quality_job_definition.assert_not_called() sagemaker_session.sagemaker_client.create_data_quality_job_definition.assert_not_called() + # update schedule + sagemaker_session.describe_monitoring_schedule = MagicMock() + sagemaker_session.sagemaker_client.describe_data_quality_job_definition = MagicMock() + sagemaker_session.sagemaker_client.create_data_quality_job_definition = MagicMock() + + # Test updating monitoring schedule with schedule_cron_expression set to NOW + sagemaker_session.sagemaker_client.update_monitoring_schedule = Mock() + data_quality_monitor.update_monitoring_schedule( + data_analysis_start_time="-PT24H", + data_analysis_end_time="-PT0H", + schedule_cron_expression=CRON_NOW, + ) + + sagemaker_session.sagemaker_client.update_monitoring_schedule.assert_called_once_with( + MonitoringScheduleName=data_quality_monitor.monitoring_schedule_name, + MonitoringScheduleConfig={ + "MonitoringJobDefinitionName": data_quality_monitor.job_definition_name, + "MonitoringType": DefaultModelMonitor.monitoring_type(), + "ScheduleConfig": { + "ScheduleExpression": CRON_NOW, + "DataAnalysisStartTime": "-PT24H", + "DataAnalysisEndTime": "-PT0H", + }, + }, + ) + + # A new data quality job definition should be created + sagemaker_session.sagemaker_client.describe_data_quality_job_definition.assert_called_once() + sagemaker_session.sagemaker_client.create_data_quality_job_definition.assert_called_once() + # update one property of job definition time.sleep( 0.001 diff --git a/tests/unit/sagemaker/remote_function/runtime_environment/test_mpi_utils.py b/tests/unit/sagemaker/remote_function/runtime_environment/test_mpi_utils.py new file mode 100644 index 0000000000..aa983141ae --- /dev/null +++ b/tests/unit/sagemaker/remote_function/runtime_environment/test_mpi_utils.py @@ -0,0 +1,125 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""MPI Utils Unit Tests.""" +from __future__ import absolute_import + +import os +from mock import patch + +import sagemaker.remote_function.runtime_environment.mpi_utils_remote as mpi_utils_remote # noqa: E402 + + +@patch.dict( + os.environ, + { + "SM_MASTER_ADDR": "algo-1", + "SM_CURRENT_HOST": "algo-1", + "SM_HOSTS": '["algo-1", "algo-2"]', + }, +) +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") +def test_mpi_utils_main_job_start( + mock_start_sshd_daemon, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, +): + + mpi_utils_remote.main() + + mock_start_sshd_daemon.assert_called_once() + mock_bootstrap_worker_node.assert_not_called() + mock_bootstrap_master_node.assert_called_once() + + +@patch.dict( + os.environ, + { + "SM_MASTER_ADDR": "algo-1", + "SM_CURRENT_HOST": "algo-2", + "SM_HOSTS": '["algo-1", "algo-2"]', + }, +) +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") +def test_mpi_utils_worker_job_start( + mock_start_sshd_daemon, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, +): + + mpi_utils_remote.main() + + mock_start_sshd_daemon.assert_called_once() + mock_bootstrap_worker_node.assert_called_once() + mock_bootstrap_master_node.assert_not_called() + + +@patch.dict( + os.environ, + { + "SM_MASTER_ADDR": "algo-1", + "SM_CURRENT_HOST": "algo-1", + "SM_HOSTS": '["algo-1", "algo-2"]', + }, +) +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") +@patch( + "sagemaker.remote_function.runtime_environment.mpi_utils_remote.write_status_file_to_workers" +) +def test_mpi_utils_main_job_end( + mock_write_status_file_to_workers, + mock_start_sshd_daemon, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, +): + + mpi_utils_remote.main(["--job_ended", "1"]) + + mock_start_sshd_daemon.assert_not_called() + mock_bootstrap_worker_node.assert_not_called() + mock_bootstrap_master_node.assert_not_called() + mock_write_status_file_to_workers.assert_called_once() + + +@patch.dict( + os.environ, + { + "SM_MASTER_ADDR": "algo-1", + "SM_CURRENT_HOST": "algo-2", + "SM_HOSTS": '["algo-1", "algo-2"]', + }, +) +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") +@patch( + "sagemaker.remote_function.runtime_environment.mpi_utils_remote.write_status_file_to_workers" +) +def test_mpi_utils_worker_job_end( + mock_write_status_file_to_workers, + mock_start_sshd_daemon, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, +): + + mpi_utils_remote.main(["--job_ended", "1"]) + + mock_start_sshd_daemon.assert_not_called() + mock_bootstrap_worker_node.assert_not_called() + mock_bootstrap_master_node.assert_not_called() + mock_write_status_file_to_workers.assert_not_called() diff --git a/tests/unit/sagemaker/remote_function/test_client.py b/tests/unit/sagemaker/remote_function/test_client.py index 536bfdfca7..de8758bfad 100644 --- a/tests/unit/sagemaker/remote_function/test_client.py +++ b/tests/unit/sagemaker/remote_function/test_client.py @@ -1504,7 +1504,9 @@ def test_consistency_between_remote_and_step_decorator(): "s3_kms_key", "s3_root_uri", "sagemaker_session", + "disable_output_compression", "use_torchrun", + "use_mpirun", "nproc_per_node", ] diff --git a/tests/unit/sagemaker/remote_function/test_job.py b/tests/unit/sagemaker/remote_function/test_job.py index 888c634bfe..f153b5b2ca 100644 --- a/tests/unit/sagemaker/remote_function/test_job.py +++ b/tests/unit/sagemaker/remote_function/test_job.py @@ -15,6 +15,7 @@ import os import sys +import tempfile import pytest from mock import patch, Mock, ANY, mock_open from mock.mock import MagicMock @@ -49,6 +50,11 @@ _prepare_dependencies_and_pre_execution_scripts, ) +from sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment import ( + set_env, + safe_serialize, +) + REGION = "us-west-2" TRAINING_JOB_ARN = "training-job-arn" @@ -68,6 +74,178 @@ EXPECTED_OUTPUT_URI = S3_URI + "/output" EXPECTED_DEPENDENCIES_URI = S3_URI + "/additional_dependencies/requirements.txt" +# flake8: noqa +EXPECTED_ENV_SINGLE_NODE_CPU = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.t3.xlarge' +export SM_HOSTS='["algo-1"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='1' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='4' +export SM_NUM_GPUS='0' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.t3.xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.t3.xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='4' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.t3.xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 4, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 4, "num_gpus": 0, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.t3.xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.t3.xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}}' +""" + +# flake8: noqa +EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.g5.12xlarge' +export SM_HOSTS='["algo-1"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='1' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='48' +export SM_NUM_GPUS='4' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='4' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 4, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 48, "num_gpus": 4, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}}' +export NCCL_SOCKET_IFNAME='eth0' +export NCCL_PROTO='simple' +""" + +# flake8: noqa +EXPECTED_ENV_MULTI_NODE_MULTI_GPUS = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.g5.2xlarge' +export SM_HOSTS='["algo-1", "algo-2", "algo-3", "algo-4"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='4' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='8' +export SM_NUM_GPUS='1' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.2xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.2xlarge", "hosts": ["algo-4", "algo-2", "algo-1", "algo-3"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='1' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.2xlarge", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "host_count": 4, "nproc_per_node": 1, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 8, "num_gpus": 1, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.2xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.2xlarge", "hosts": ["algo-4", "algo-2", "algo-1", "algo-3"]}], "network_interface_name": "eth0"}}' +export NCCL_SOCKET_IFNAME='eth0' +export NCCL_PROTO='simple' +""" + +# flake8: noqa +EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS_MPIRUN = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.g5.12xlarge' +export SM_HOSTS='["algo-1"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='1' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='48' +export SM_NUM_GPUS='4' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='4' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 4, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 48, "num_gpus": 4, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}}' +export MASTER_ADDR='algo-1' +export MASTER_PORT='7777' +export SM_HOSTS_LIST='algo-1:4' +export SM_FI_PROVIDER='' +export SM_NCCL_PROTO='' +export SM_FI_EFA_USE_DEVICE_RDMA='' +""" + +# flake8: noqa +EXPECTED_ENV_MULTI_NODE_MULTI_GPUS_MPIRUN = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.g5.2xlarge' +export SM_HOSTS='["algo-1", "algo-2", "algo-3", "algo-4"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='4' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='8' +export SM_NUM_GPUS='1' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.2xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.2xlarge", "hosts": ["algo-4", "algo-2", "algo-1", "algo-3"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='1' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.2xlarge", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "host_count": 4, "nproc_per_node": 1, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 8, "num_gpus": 1, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.2xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.2xlarge", "hosts": ["algo-4", "algo-2", "algo-1", "algo-3"]}], "network_interface_name": "eth0"}}' +export MASTER_ADDR='algo-1' +export MASTER_PORT='7777' +export SM_HOSTS_LIST='algo-1:1,algo-2:1,algo-3:1,algo-4:1' +export SM_FI_PROVIDER='' +export SM_NCCL_PROTO='' +export SM_FI_EFA_USE_DEVICE_RDMA='' +""" + +# flake8: noqa +EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS_MPIRUN_WITH_NPROC_PER_NODE = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.g5.12xlarge' +export SM_HOSTS='["algo-1"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='1' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='48' +export SM_NUM_GPUS='4' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='2' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 2, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 48, "num_gpus": 4, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}}' +export MASTER_ADDR='algo-1' +export MASTER_PORT='7777' +export SM_HOSTS_LIST='algo-1:2' +export SM_FI_PROVIDER='' +export SM_NCCL_PROTO='' +export SM_FI_EFA_USE_DEVICE_RDMA='' +""" + DESCRIBE_TRAINING_JOB_RESPONSE = { "TrainingJobArn": TRAINING_JOB_ARN, "TrainingJobStatus": "{}", @@ -112,8 +290,8 @@ def mock_get_current_run(): return current_run -def describe_training_job_response(job_status): - return { +def describe_training_job_response(job_status, disable_output_compression=False): + job_response = { "TrainingJobArn": TRAINING_JOB_ARN, "TrainingJobStatus": job_status, "ResourceConfig": { @@ -121,15 +299,38 @@ def describe_training_job_response(job_status): "InstanceType": "ml.c4.xlarge", "VolumeSizeInGB": 30, }, - "OutputDataConfig": {"S3OutputPath": "s3://sagemaker-123/image_uri/output"}, } + if disable_output_compression: + output_config = { + "S3OutputPath": "s3://sagemaker-123/image_uri/output", + "CompressionType": "NONE", + } + else: + output_config = { + "S3OutputPath": "s3://sagemaker-123/image_uri/output", + "CompressionType": "NONE", + } + + job_response["OutputDataConfig"] = output_config + + return job_response + COMPLETED_TRAINING_JOB = describe_training_job_response("Completed") INPROGRESS_TRAINING_JOB = describe_training_job_response("InProgress") CANCELLED_TRAINING_JOB = describe_training_job_response("Stopped") FAILED_TRAINING_JOB = describe_training_job_response("Failed") +COMPLETED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION = describe_training_job_response( + "Completed", True +) +INPROGRESS_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION = describe_training_job_response( + "InProgress", True +) +CANCELLED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION = describe_training_job_response("Stopped", True) +FAILED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION = describe_training_job_response("Failed", True) + def mock_session(): session = Mock() @@ -376,8 +577,6 @@ def test_start( s3_base_uri=f"{S3_URI}/{job.job_name}", hmac_key=HMAC_KEY, s3_kms_key=None, - use_torchrun=False, - nproc_per_node=1, ) mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) @@ -392,7 +591,7 @@ def test_start( s3_kms_key=None, sagemaker_session=session(), use_torchrun=False, - nproc_per_node=1, + use_mpirun=False, ) mock_dependency_upload.assert_called_once_with( @@ -510,8 +709,6 @@ def test_start_with_checkpoint_location( s3_base_uri=f"{S3_URI}/{job.job_name}", hmac_key=HMAC_KEY, s3_kms_key=None, - use_torchrun=False, - nproc_per_node=1, ) mock_stored_function().save.assert_called_once_with( @@ -665,8 +862,6 @@ def test_start_with_complete_job_settings( s3_base_uri=f"{S3_URI}/{job.job_name}", hmac_key=HMAC_KEY, s3_kms_key=KMS_KEY_ARN, - use_torchrun=False, - nproc_per_node=1, ) local_dependencies_path = mock_runtime_manager().snapshot() @@ -679,7 +874,7 @@ def test_start_with_complete_job_settings( s3_kms_key=job_settings.s3_kms_key, sagemaker_session=session(), use_torchrun=False, - nproc_per_node=1, + use_mpirun=False, ) mock_user_workspace_upload.assert_called_once_with( @@ -838,8 +1033,6 @@ def test_get_train_args_under_pipeline_context( step_name=MOCKED_PIPELINE_CONFIG.step_name, func_step_s3_dir=MOCKED_PIPELINE_CONFIG.pipeline_build_time, ), - use_torchrun=False, - nproc_per_node=1, ) mock_stored_function.save_pipeline_step_function.assert_called_once_with(mocked_serialized_data) @@ -853,7 +1046,7 @@ def test_get_train_args_under_pipeline_context( s3_kms_key=job_settings.s3_kms_key, sagemaker_session=session(), use_torchrun=False, - nproc_per_node=1, + use_mpirun=False, ) mock_user_workspace_upload.assert_called_once_with( @@ -1029,7 +1222,7 @@ def test_start_with_spark( s3_kms_key=None, sagemaker_session=session(), use_torchrun=False, - nproc_per_node=1, + use_mpirun=False, ) session().sagemaker_client.create_training_job.assert_called_once_with( @@ -1132,6 +1325,27 @@ def test_describe(session, *args): session().sagemaker_client.describe_training_job.assert_called_once() +@patch("sagemaker.remote_function.job._prepare_and_upload_runtime_scripts") +@patch("sagemaker.remote_function.job._prepare_and_upload_workspace") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_describe_disable_output_compression(session, *args): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + instance_type="ml.m5.large", + disable_output_compression=True, + ) + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + job.describe() + assert job.describe() == COMPLETED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION + + session().sagemaker_client.describe_training_job.assert_called_once() + + @patch("sagemaker.remote_function.job._prepare_and_upload_runtime_scripts") @patch("sagemaker.remote_function.job._prepare_and_upload_workspace") @patch("sagemaker.remote_function.job.StoredFunction") @@ -1184,13 +1398,11 @@ def test_prepare_and_upload_runtime_scripts(session, mock_copy, mock_s3_upload): s3_base_uri=S3_URI, s3_kms_key=KMS_KEY_ARN, sagemaker_session=session(), - use_torchrun=False, - nproc_per_node=1, ) assert s3_path == mock_s3_upload.return_value - assert mock_copy.call_count == 2 + assert mock_copy.call_count == 3 mock_s3_upload.assert_called_once() @@ -1210,7 +1422,7 @@ def test_prepare_and_upload_runtime_scripts_under_pipeline_context( ) # Bootstrap scripts are uploaded on the first call assert s3_path == mock_s3_upload.return_value - assert mock_copy.call_count == 2 + assert mock_copy.call_count == 3 mock_s3_upload.assert_called_once() mock_copy.reset_mock() @@ -1619,3 +1831,848 @@ def test_extend_spark_config_to_request( } ], ) + + +@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) +@patch("secrets.token_hex", return_value=HMAC_KEY) +@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") +@patch( + "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" +) +@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_start_with_torchrun_single_node( + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, +): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + include_local_workdir=True, + instance_type="ml.g5.12xlarge", + encrypt_inter_container_traffic=True, + use_torchrun=True, + use_mpirun=False, + ) + + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + assert job.job_name.startswith("job-function") + + mock_stored_function.assert_called_once_with( + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, + ) + + mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) + + local_dependencies_path = mock_runtime_manager().snapshot() + mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() + + mock_script_upload.assert_called_once_with( + spark_config=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + use_torchrun=True, + use_mpirun=False, + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path=local_dependencies_path, + include_local_workdir=True, + pre_execution_commands=None, + pre_execution_script_local_path=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + custom_file_filter=None, + ) + + session().sagemaker_client.create_training_job.assert_called_once_with( + TrainingJobName=job.job_name, + RoleArn=ROLE_ARN, + StoppingCondition={"MaxRuntimeInSeconds": 86400}, + RetryStrategy={"MaximumRetryAttempts": 1}, + InputDataConfig=[ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": mock_script_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + dict( + ChannelName=REMOTE_FUNCTION_WORKSPACE, + DataSource={ + "S3DataSource": { + "S3Uri": mock_dependency_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + ], + OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}"}, + AlgorithmSpecification=dict( + TrainingImage=IMAGE, + TrainingInputMode="File", + ContainerEntrypoint=[ + "/bin/bash", + "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh", + ], + ContainerArguments=[ + "--s3_base_uri", + f"{S3_URI}/{job.job_name}", + "--region", + TEST_REGION, + "--client_python_version", + mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, + "--dependency_settings", + '{"dependency_file": null}', + "--distribution", + "torchrun", + "--run_in_context", + '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', + ], + ), + ResourceConfig=dict( + VolumeSizeInGB=30, + InstanceCount=1, + InstanceType="ml.g5.12xlarge", + KeepAlivePeriodInSeconds=0, + ), + EnableNetworkIsolation=False, + EnableInterContainerTrafficEncryption=True, + EnableManagedSpotTraining=False, + Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + ) + + +@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) +@patch("secrets.token_hex", return_value=HMAC_KEY) +@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") +@patch( + "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" +) +@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_start_with_torchrun_multi_node( + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, +): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + include_local_workdir=True, + instance_count=2, + instance_type="ml.g5.2xlarge", + encrypt_inter_container_traffic=True, + use_torchrun=True, + use_mpirun=False, + ) + + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + assert job.job_name.startswith("job-function") + + mock_stored_function.assert_called_once_with( + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, + ) + + mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) + + local_dependencies_path = mock_runtime_manager().snapshot() + mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() + + mock_script_upload.assert_called_once_with( + spark_config=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + use_torchrun=True, + use_mpirun=False, + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path=local_dependencies_path, + include_local_workdir=True, + pre_execution_commands=None, + pre_execution_script_local_path=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + custom_file_filter=None, + ) + + session().sagemaker_client.create_training_job.assert_called_once_with( + TrainingJobName=job.job_name, + RoleArn=ROLE_ARN, + StoppingCondition={"MaxRuntimeInSeconds": 86400}, + RetryStrategy={"MaximumRetryAttempts": 1}, + InputDataConfig=[ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": mock_script_upload.return_value, + "S3DataType": "S3Prefix", + "S3DataDistributionType": "FullyReplicated", + } + }, + ), + dict( + ChannelName=REMOTE_FUNCTION_WORKSPACE, + DataSource={ + "S3DataSource": { + "S3Uri": mock_dependency_upload.return_value, + "S3DataType": "S3Prefix", + "S3DataDistributionType": "FullyReplicated", + } + }, + ), + ], + OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}"}, + AlgorithmSpecification=dict( + TrainingImage=IMAGE, + TrainingInputMode="File", + ContainerEntrypoint=[ + "/bin/bash", + "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh", + ], + ContainerArguments=[ + "--s3_base_uri", + f"{S3_URI}/{job.job_name}", + "--region", + TEST_REGION, + "--client_python_version", + mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, + "--dependency_settings", + '{"dependency_file": null}', + "--distribution", + "torchrun", + "--run_in_context", + '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', + ], + ), + ResourceConfig=dict( + VolumeSizeInGB=30, + InstanceCount=2, + InstanceType="ml.g5.2xlarge", + KeepAlivePeriodInSeconds=0, + ), + EnableNetworkIsolation=False, + EnableInterContainerTrafficEncryption=True, + EnableManagedSpotTraining=False, + Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + ) + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=4, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_single_node_cpu( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + with tempfile.NamedTemporaryFile() as f: + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1"], + current_group_name="homogeneousCluster", + current_instance_type="ml.t3.xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.t3.xlarge", + hosts=["algo-1"], + ) + ], + network_interface_name="eth0", + ), + distribution=None, + output_file=f.name, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(f.name, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV_SINGLE_NODE_CPU) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=48, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=4, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_single_node_multi_gpu( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + with tempfile.NamedTemporaryFile() as f: + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.12xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.12xlarge", + hosts=["algo-1"], + ) + ], + network_interface_name="eth0", + ), + distribution="torchrun", + output_file=f.name, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(f.name, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=8, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=1, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_multi_node_multi_gpu( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + with tempfile.NamedTemporaryFile() as f: + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1", "algo-2", "algo-3", "algo-4"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.2xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.2xlarge", + hosts=["algo-4", "algo-2", "algo-1", "algo-3"], + ) + ], + network_interface_name="eth0", + ), + distribution="torchrun", + output_file=f.name, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(f.name, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV_MULTI_NODE_MULTI_GPUS) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=48, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=4, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_single_node_multi_gpu_mpirun( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + with tempfile.NamedTemporaryFile() as f: + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.12xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.12xlarge", + hosts=["algo-1"], + ) + ], + network_interface_name="eth0", + ), + distribution="mpirun", + output_file=f.name, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(f.name, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS_MPIRUN) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=8, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=1, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_multi_node_multi_gpu_mpirun( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + with tempfile.NamedTemporaryFile() as f: + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1", "algo-2", "algo-3", "algo-4"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.2xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.2xlarge", + hosts=["algo-4", "algo-2", "algo-1", "algo-3"], + ) + ], + network_interface_name="eth0", + ), + distribution="mpirun", + output_file=f.name, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(f.name, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV_MULTI_NODE_MULTI_GPUS_MPIRUN) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + + +@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) +@patch("secrets.token_hex", return_value=HMAC_KEY) +@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") +@patch( + "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" +) +@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_start_with_torchrun_single_node_with_nproc_per_node( + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, +): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + include_local_workdir=True, + instance_type="ml.g5.12xlarge", + encrypt_inter_container_traffic=True, + use_torchrun=True, + use_mpirun=False, + nproc_per_node=2, + ) + + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + assert job.job_name.startswith("job-function") + + mock_stored_function.assert_called_once_with( + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, + ) + + mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) + + local_dependencies_path = mock_runtime_manager().snapshot() + mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() + + mock_script_upload.assert_called_once_with( + spark_config=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + use_torchrun=True, + use_mpirun=False, + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path=local_dependencies_path, + include_local_workdir=True, + pre_execution_commands=None, + pre_execution_script_local_path=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + custom_file_filter=None, + ) + + session().sagemaker_client.create_training_job.assert_called_once_with( + TrainingJobName=job.job_name, + RoleArn=ROLE_ARN, + StoppingCondition={"MaxRuntimeInSeconds": 86400}, + RetryStrategy={"MaximumRetryAttempts": 1}, + InputDataConfig=[ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": mock_script_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + dict( + ChannelName=REMOTE_FUNCTION_WORKSPACE, + DataSource={ + "S3DataSource": { + "S3Uri": mock_dependency_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + ], + OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}"}, + AlgorithmSpecification=dict( + TrainingImage=IMAGE, + TrainingInputMode="File", + ContainerEntrypoint=[ + "/bin/bash", + "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh", + ], + ContainerArguments=[ + "--s3_base_uri", + f"{S3_URI}/{job.job_name}", + "--region", + TEST_REGION, + "--client_python_version", + mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, + "--dependency_settings", + '{"dependency_file": null}', + "--distribution", + "torchrun", + "--user_nproc_per_node", + "2", + "--run_in_context", + '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', + ], + ), + ResourceConfig=dict( + VolumeSizeInGB=30, + InstanceCount=1, + InstanceType="ml.g5.12xlarge", + KeepAlivePeriodInSeconds=0, + ), + EnableNetworkIsolation=False, + EnableInterContainerTrafficEncryption=True, + EnableManagedSpotTraining=False, + Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + ) + + +@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) +@patch("secrets.token_hex", return_value=HMAC_KEY) +@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") +@patch( + "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" +) +@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_start_with_mpirun_single_node_with_nproc_per_node( + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, +): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + include_local_workdir=True, + instance_type="ml.g5.12xlarge", + encrypt_inter_container_traffic=True, + use_torchrun=False, + use_mpirun=True, + nproc_per_node=2, + ) + + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + assert job.job_name.startswith("job-function") + + mock_stored_function.assert_called_once_with( + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, + ) + + mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) + + local_dependencies_path = mock_runtime_manager().snapshot() + mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() + + mock_script_upload.assert_called_once_with( + spark_config=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + use_torchrun=False, + use_mpirun=True, + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path=local_dependencies_path, + include_local_workdir=True, + pre_execution_commands=None, + pre_execution_script_local_path=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + custom_file_filter=None, + ) + + session().sagemaker_client.create_training_job.assert_called_once_with( + TrainingJobName=job.job_name, + RoleArn=ROLE_ARN, + StoppingCondition={"MaxRuntimeInSeconds": 86400}, + RetryStrategy={"MaximumRetryAttempts": 1}, + InputDataConfig=[ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": mock_script_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + dict( + ChannelName=REMOTE_FUNCTION_WORKSPACE, + DataSource={ + "S3DataSource": { + "S3Uri": mock_dependency_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + ], + OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}"}, + AlgorithmSpecification=dict( + TrainingImage=IMAGE, + TrainingInputMode="File", + ContainerEntrypoint=[ + "/bin/bash", + "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh", + ], + ContainerArguments=[ + "--s3_base_uri", + f"{S3_URI}/{job.job_name}", + "--region", + TEST_REGION, + "--client_python_version", + mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, + "--dependency_settings", + '{"dependency_file": null}', + "--distribution", + "mpirun", + "--user_nproc_per_node", + "2", + "--run_in_context", + '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', + ], + ), + ResourceConfig=dict( + VolumeSizeInGB=30, + InstanceCount=1, + InstanceType="ml.g5.12xlarge", + KeepAlivePeriodInSeconds=0, + ), + EnableNetworkIsolation=False, + EnableInterContainerTrafficEncryption=True, + EnableManagedSpotTraining=False, + Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + ) + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=48, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=4, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_single_node_multi_gpu_mpirun_with_nproc_per_node( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + with tempfile.NamedTemporaryFile() as f: + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.12xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.12xlarge", + hosts=["algo-1"], + ) + ], + network_interface_name="eth0", + ), + distribution="mpirun", + user_nproc_per_node=2, + output_file=f.name, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(f.name, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines( + EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS_MPIRUN_WITH_NPROC_PER_NODE + ) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + + +def _remove_extra_lines(string): + """Removes extra blank lines from a string.""" + return "\n".join([line for line in string.splitlines() if line.strip()]) diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index b6bd69e304..415d7eab5b 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -75,7 +75,7 @@ "-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" ) mock_djl_image_uri = ( - "123456789712.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.24.0-neuronx-sdk2.14.1" + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124" ) mock_model_data = { @@ -1166,6 +1166,9 @@ def test_optimize_quantize_for_jumpstart( mock_pysdk_model.image_uri = mock_tgi_image_uri mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0] + mock_pysdk_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } sample_input = { "inputs": "The diamondback terrapin or simply terrapin is a species " @@ -1201,6 +1204,10 @@ def test_optimize_quantize_for_jumpstart( ) self.assertIsNotNone(out_put) + self.assertEqual( + out_put["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124", + ) @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) @@ -1287,6 +1294,9 @@ def test_optimize_quantize_and_compile_for_jumpstart( mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0] mock_pysdk_model.config_name = "config_name" mock_pysdk_model._metadata_configs = {"config_name": mock_metadata_config} + mock_pysdk_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } sample_input = { "inputs": "The diamondback terrapin or simply terrapin is a species " @@ -1319,6 +1329,8 @@ def test_optimize_quantize_and_compile_for_jumpstart( ) self.assertIsNotNone(out_put) + self.assertIsNone(out_put["OptimizationConfigs"][1]["ModelCompilationConfig"].get("Image")) + self.assertIsNone(out_put["OptimizationConfigs"][0]["ModelQuantizationConfig"].get("Image")) @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) @@ -1633,13 +1645,17 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations( mock_serve_settings, mock_telemetry, ): - mock_sagemaker_session = Mock() + mock_sagemaker_session = MagicMock() + mock_sagemaker_session.sagemaker_client.create_optimization_job = MagicMock() mock_sagemaker_session.wait_for_optimization_job.side_effect = ( lambda *args: mock_optimization_job_response ) mock_lmi_js_model = MagicMock() mock_lmi_js_model.image_uri = mock_djl_image_uri + mock_lmi_js_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } mock_lmi_js_model.env = { "SAGEMAKER_PROGRAM": "inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -1671,6 +1687,13 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations( output_path="s3://bucket/code/", ) + assert ( + mock_sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][ + "OptimizationConfigs" + ][0]["ModelQuantizationConfig"]["Image"] + == "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124" + ) + assert mock_lmi_js_model.set_deployment_config.call_args_list[0].kwargs == { "instance_type": "ml.g5.24xlarge", "config_name": "lmi", @@ -1711,13 +1734,17 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations_no_over mock_serve_settings, mock_telemetry, ): - mock_sagemaker_session = Mock() + mock_sagemaker_session = MagicMock() + mock_sagemaker_session.sagemaker_client.create_optimization_job = MagicMock() mock_sagemaker_session.wait_for_optimization_job.side_effect = ( lambda *args: mock_optimization_job_response ) mock_lmi_js_model = MagicMock() mock_lmi_js_model.image_uri = mock_djl_image_uri + mock_lmi_js_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi27.0.0-cu124" + } mock_lmi_js_model.env = { "SAGEMAKER_PROGRAM": "inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -1748,6 +1775,13 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations_no_over output_path="s3://bucket/code/", ) + assert ( + mock_sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][ + "OptimizationConfigs" + ][0]["ModelQuantizationConfig"]["Image"] + == "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi27.0.0-cu124" + ) + assert mock_lmi_js_model.set_deployment_config.call_args_list[0].kwargs == { "instance_type": "ml.g5.24xlarge", "config_name": "lmi", @@ -1763,3 +1797,163 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations_no_over "OPTION_TENSOR_PARALLEL_DEGREE": "8", "OPTION_QUANTIZE": "fp8", # should be added to the env } + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel") + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_fine_tuned_model", + return_value=False, + ) + def test_optimize_on_js_model_test_image_defaulting_scenarios( + self, + mock_is_fine_tuned, + mock_is_jumpstart_model, + mock_js_model, + mock_is_gated_model, + mock_serve_settings, + mock_telemetry, + ): + + mock_lmi_js_model = MagicMock() + mock_lmi_js_model.image_uri = mock_djl_image_uri + mock_lmi_js_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-1-70b-instruct", + schema_builder=SchemaBuilder("test", "test"), + sagemaker_session=MagicMock(), + ) + model_builder.pysdk_model = mock_lmi_js_model + + # assert lmi version is upgraded to hardcoded default + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } + } + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124", + ) + + # assert lmi version is left as is + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi21.0.0-cu124" + } + } + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi21.0.0-cu124", + ) + + # assert lmi version is upgraded to the highest provided version + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelShardingConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } + }, + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124" + } + }, + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelShardingConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124", + ) + self.assertEqual( + optimization_args["OptimizationConfigs"][1]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124", + ) + + # assert lmi version is upgraded to the highest provided version and sets empty image config + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124" + } + }, + {"ModelShardingConfig": {}}, + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124", + ) + self.assertEqual( + optimization_args["OptimizationConfigs"][1]["ModelShardingConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124", + ) + + # assert lmi version is left as is on minor version bump + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.1.0-cu124" + } + } + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.1.0-cu124", + ) + + # assert lmi version is left as is on patch version bump + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.0.1-cu124" + } + } + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.0.1-cu124", + ) diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 7355fe4f38..6661c6e2bf 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -3270,7 +3270,7 @@ def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation( mock_pysdk_model = Mock() mock_pysdk_model.model_data = None - mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-1-8B-Instruct"} + mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-2-8B-Instruct"} sample_input = {"inputs": "dummy prompt", "parameters": {}} @@ -3279,7 +3279,7 @@ def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation( dummy_schema_builder = SchemaBuilder(sample_input, sample_output) model_builder = ModelBuilder( - model="meta-llama/Meta-Llama-3-1-8B-Instruct", + model="meta-llama/Meta-Llama-3-2-8B-Instruct", schema_builder=dummy_schema_builder, env_vars={"HF_TOKEN": "token"}, model_metadata={ @@ -3293,7 +3293,7 @@ def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation( self.assertRaisesRegex( ValueError, - "Compilation is not supported for Llama-3.1 with a GPU instance.", + "Compilation is not supported for models greater than Llama-3.0 with a GPU instance.", lambda: model_builder.optimize( job_name="job_name-123", instance_type="ml.g5.24xlarge", @@ -3733,6 +3733,9 @@ def test_optimize_sharding_with_override_for_js( pysdk_model.env = {"key": "val"} pysdk_model._enable_network_isolation = True pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None + pysdk_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } mock_build_for_jumpstart.side_effect = lambda **kwargs: pysdk_model mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( @@ -3803,8 +3806,9 @@ def test_optimize_sharding_with_override_for_js( OptimizationConfigs=[ { "ModelShardingConfig": { - "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} - } + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124", + "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"}, + }, } ], OutputConfig={ @@ -4037,14 +4041,30 @@ def test_neuron_configurations_rule_set(self): @pytest.mark.parametrize( "test_case", [ + # Real-time deployment without update { "input_args": {"endpoint_name": "test"}, "call_params": { "instance_type": "ml.g5.2xlarge", "initial_instance_count": 1, "endpoint_name": "test", + "update_endpoint": False, + }, + }, + # Real-time deployment with update + { + "input_args": { + "endpoint_name": "existing-endpoint", + "update_endpoint": True, + }, + "call_params": { + "instance_type": "ml.g5.2xlarge", + "initial_instance_count": 1, + "endpoint_name": "existing-endpoint", + "update_endpoint": True, }, }, + # Serverless deployment without update { "input_args": { "endpoint_name": "test", @@ -4053,8 +4073,23 @@ def test_neuron_configurations_rule_set(self): "call_params": { "serverless_inference_config": ServerlessInferenceConfig(), "endpoint_name": "test", + "update_endpoint": False, }, }, + # Serverless deployment with update + { + "input_args": { + "endpoint_name": "existing-endpoint", + "inference_config": ServerlessInferenceConfig(), + "update_endpoint": True, + }, + "call_params": { + "serverless_inference_config": ServerlessInferenceConfig(), + "endpoint_name": "existing-endpoint", + "update_endpoint": True, + }, + }, + # Async deployment without update { "input_args": { "endpoint_name": "test", @@ -4065,10 +4100,30 @@ def test_neuron_configurations_rule_set(self): "instance_type": "ml.g5.2xlarge", "initial_instance_count": 1, "endpoint_name": "test", + "update_endpoint": False, }, }, + # Async deployment with update { - "input_args": {"endpoint_name": "test", "inference_config": RESOURCE_REQUIREMENTS}, + "input_args": { + "endpoint_name": "existing-endpoint", + "inference_config": AsyncInferenceConfig(output_path="op-path"), + "update_endpoint": True, + }, + "call_params": { + "async_inference_config": AsyncInferenceConfig(output_path="op-path"), + "instance_type": "ml.g5.2xlarge", + "initial_instance_count": 1, + "endpoint_name": "existing-endpoint", + "update_endpoint": True, + }, + }, + # Multi-Model deployment (update_endpoint not supported) + { + "input_args": { + "endpoint_name": "test", + "inference_config": RESOURCE_REQUIREMENTS, + }, "call_params": { "resources": RESOURCE_REQUIREMENTS, "role": "role-arn", @@ -4076,8 +4131,10 @@ def test_neuron_configurations_rule_set(self): "instance_type": "ml.g5.2xlarge", "mode": Mode.SAGEMAKER_ENDPOINT, "endpoint_type": EndpointType.INFERENCE_COMPONENT_BASED, + "update_endpoint": False, }, }, + # Batch transform { "input_args": { "inference_config": BatchTransformInferenceConfig( @@ -4092,7 +4149,16 @@ def test_neuron_configurations_rule_set(self): "id": "Batch", }, ], - ids=["Real Time", "Serverless", "Async", "Multi-Model", "Batch"], + ids=[ + "Real Time", + "Real Time Update", + "Serverless", + "Serverless Update", + "Async", + "Async Update", + "Multi-Model", + "Batch", + ], ) @patch("sagemaker.serve.builder.model_builder.unique_name_from_base") def test_deploy(mock_unique_name_from_base, test_case): @@ -4115,3 +4181,20 @@ def test_deploy(mock_unique_name_from_base, test_case): diff = deepdiff.DeepDiff(kwargs, test_case["call_params"]) assert diff == {} + + +def test_deploy_multi_model_update_error(): + model_builder = ModelBuilder( + model="meta-llama/Meta-Llama-3-8B-Instruct", + env_vars={"HUGGING_FACE_HUB_TOKEN": "token"}, + role_arn="role-arn", + instance_type="ml.g5.2xlarge", + ) + setattr(model_builder, "built_model", MagicMock()) + + with pytest.raises( + ValueError, match="Currently update_endpoint is supported for single model endpoints" + ): + model_builder.deploy( + endpoint_name="test", inference_config=RESOURCE_REQUIREMENTS, update_endpoint=True + ) diff --git a/tests/unit/sagemaker/serve/detector/test_dependency_manager.py b/tests/unit/sagemaker/serve/detector/test_dependency_manager.py index 491968dd25..52e9822e57 100644 --- a/tests/unit/sagemaker/serve/detector/test_dependency_manager.py +++ b/tests/unit/sagemaker/serve/detector/test_dependency_manager.py @@ -21,7 +21,7 @@ DEPENDENCY_LIST = [ "requests==2.26.0", - "numpy>=1.20.0", + "numpy==1.26.4", "pandas<=1.3.3", "matplotlib<3.5.0", "scikit-learn>0.24.1", @@ -34,7 +34,7 @@ EXPECTED_DEPENDENCY_MAP = { "requests": "==2.26.0", - "numpy": ">=1.20.0", + "numpy": "==1.26.4", "pandas": "<=1.3.3", "matplotlib": "<3.5.0", "scikit-learn": ">0.24.1", diff --git a/tests/unit/sagemaker/serve/detector/test_pickle_dependencies.py b/tests/unit/sagemaker/serve/detector/test_pickle_dependencies.py index 34cab8a526..ced9555fc5 100644 --- a/tests/unit/sagemaker/serve/detector/test_pickle_dependencies.py +++ b/tests/unit/sagemaker/serve/detector/test_pickle_dependencies.py @@ -93,13 +93,14 @@ def create_mock_modules(name, doc, file): # happy case def test_generate_requirements_exact_match(monkeypatch): - with patch("cloudpickle.load"), patch("tqdm.tqdm"), patch( - "sagemaker.serve.detector.pickle_dependencies.subprocess.run" - ) as subprocess_run, patch( - "sagemaker.serve.detector.pickle_dependencies.subprocess.Popen" - ) as subprocess_popen, patch( - "builtins.open" - ) as mocked_open, monkeypatch.context() as m: + with ( + patch("cloudpickle.load"), + patch("tqdm.tqdm"), + patch("sagemaker.serve.detector.pickle_dependencies.subprocess.run") as subprocess_run, + patch("sagemaker.serve.detector.pickle_dependencies.subprocess.Popen") as subprocess_popen, + patch("builtins.open") as mocked_open, + monkeypatch.context() as m, + ): mock_run_stdout = MagicMock() mock_run_stdout.stdout = json.dumps(INSTALLED_PKG_JSON).encode("utf-8") subprocess_run.return_value = mock_run_stdout @@ -147,13 +148,14 @@ def test_generate_requirements_exact_match(monkeypatch): def test_generate_requirements_txt_pruning_unused_packages(monkeypatch): - with patch("cloudpickle.load"), patch("tqdm.tqdm"), patch( - "sagemaker.serve.detector.pickle_dependencies.subprocess.run" - ) as subprocess_run, patch( - "sagemaker.serve.detector.pickle_dependencies.subprocess.Popen" - ) as subprocess_popen, patch( - "builtins.open" - ) as mocked_open, monkeypatch.context() as m: + with ( + patch("cloudpickle.load"), + patch("tqdm.tqdm"), + patch("sagemaker.serve.detector.pickle_dependencies.subprocess.run") as subprocess_run, + patch("sagemaker.serve.detector.pickle_dependencies.subprocess.Popen") as subprocess_popen, + patch("builtins.open") as mocked_open, + monkeypatch.context() as m, + ): mock_run_stdout = MagicMock() mock_run_stdout.stdout = json.dumps(INSTALLED_PKG_JSON_UNUSED).encode("utf-8") subprocess_run.return_value = mock_run_stdout @@ -201,13 +203,14 @@ def test_generate_requirements_txt_pruning_unused_packages(monkeypatch): def test_generate_requirements_txt_no_currently_used_packages(monkeypatch): - with patch("cloudpickle.load"), patch("tqdm.tqdm"), patch( - "sagemaker.serve.detector.pickle_dependencies.subprocess.run" - ) as subprocess_run, patch( - "sagemaker.serve.detector.pickle_dependencies.subprocess.Popen" - ) as subprocess_popen, patch( - "builtins.open" - ) as mocked_open, monkeypatch.context() as m: + with ( + patch("cloudpickle.load"), + patch("tqdm.tqdm"), + patch("sagemaker.serve.detector.pickle_dependencies.subprocess.run") as subprocess_run, + patch("sagemaker.serve.detector.pickle_dependencies.subprocess.Popen") as subprocess_popen, + patch("builtins.open") as mocked_open, + monkeypatch.context() as m, + ): mock_run_stdout = MagicMock() mock_run_stdout.stdout = json.dumps([]).encode("utf-8") subprocess_run.return_value = mock_run_stdout diff --git a/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py b/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py index 183d15d13e..aa99e1971c 100644 --- a/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py +++ b/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py @@ -52,8 +52,8 @@ def test_create_dir_structure_from_new(self, mock_path, mock_disk_usage, mock_di mock_disk_space.assert_called_once_with(mock_model_path) mock_disk_usage.assert_called_once() - self.assertEquals(ret_model_path, mock_model_path) - self.assertEquals(ret_code_dir, mock_code_dir) + self.assertEqual(ret_model_path, mock_model_path) + self.assertEqual(ret_code_dir, mock_code_dir) @patch("sagemaker.serve.model_server.djl_serving.prepare.Path") def test_create_dir_structure_invalid_path(self, mock_path): @@ -65,7 +65,7 @@ def test_create_dir_structure_invalid_path(self, mock_path): with self.assertRaises(ValueError) as context: _create_dir_structure(mock_model_path) - self.assertEquals("model_dir is not a valid directory", str(context.exception)) + self.assertEqual("model_dir is not a valid directory", str(context.exception)) @patch("sagemaker.serve.model_server.djl_serving.prepare.S3Downloader") @patch("builtins.open", new_callable=mock_open, read_data="data") diff --git a/tests/unit/sagemaker/serve/model_server/multi_model_server/test_multi_model_server_prepare.py b/tests/unit/sagemaker/serve/model_server/multi_model_server/test_multi_model_server_prepare.py index e877c1e7e9..567a72182a 100644 --- a/tests/unit/sagemaker/serve/model_server/multi_model_server/test_multi_model_server_prepare.py +++ b/tests/unit/sagemaker/serve/model_server/multi_model_server/test_multi_model_server_prepare.py @@ -91,8 +91,8 @@ def test_create_dir_structure_from_new(self, mock_path, mock_disk_usage, mock_di mock_disk_space.assert_called_once_with(mock_model_path) mock_disk_usage.assert_called_once() - self.assertEquals(ret_model_path, mock_model_path) - self.assertEquals(ret_code_dir, mock_code_dir) + self.assertEqual(ret_model_path, mock_model_path) + self.assertEqual(ret_code_dir, mock_code_dir) @patch("sagemaker.serve.model_server.multi_model_server.prepare.Path") def test_create_dir_structure_invalid_path(self, mock_path): @@ -104,4 +104,4 @@ def test_create_dir_structure_invalid_path(self, mock_path): with self.assertRaises(ValueError) as context: _create_dir_structure(mock_model_path) - self.assertEquals("model_dir is not a valid directory", str(context.exception)) + self.assertEqual("model_dir is not a valid directory", str(context.exception)) diff --git a/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py b/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py index 88d109831d..ed94f10ce9 100644 --- a/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py +++ b/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py @@ -50,8 +50,8 @@ def test_create_dir_structure_from_new(self, mock_path, mock_disk_usage, mock_di mock_disk_space.assert_called_once_with(mock_model_path) mock_disk_usage.assert_called_once() - self.assertEquals(ret_model_path, mock_model_path) - self.assertEquals(ret_code_dir, mock_code_dir) + self.assertEqual(ret_model_path, mock_model_path) + self.assertEqual(ret_code_dir, mock_code_dir) @patch("sagemaker.serve.model_server.tgi.prepare.Path") def test_create_dir_structure_invalid_path(self, mock_path): @@ -63,7 +63,7 @@ def test_create_dir_structure_invalid_path(self, mock_path): with self.assertRaises(ValueError) as context: _create_dir_structure(mock_model_path) - self.assertEquals("model_dir is not a valid directory", str(context.exception)) + self.assertEqual("model_dir is not a valid directory", str(context.exception)) @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") @patch("builtins.open", read_data="data") diff --git a/tests/unit/sagemaker/serve/utils/test_hardware_detector.py b/tests/unit/sagemaker/serve/utils/test_hardware_detector.py index d383f95809..58839bfc50 100644 --- a/tests/unit/sagemaker/serve/utils/test_hardware_detector.py +++ b/tests/unit/sagemaker/serve/utils/test_hardware_detector.py @@ -21,7 +21,7 @@ REGION = "us-west-2" VALID_INSTANCE_TYPE = "ml.g5.48xlarge" INVALID_INSTANCE_TYPE = "fl.c5.57xxlarge" -EXPECTED_INSTANCE_GPU_INFO = (8, 196608) +EXPECTED_INSTANCE_GPU_INFO = (8, 183104) MIB_CONVERSION_FACTOR = 0.00000095367431640625 MEMORY_BUFFER_MULTIPLIER = 1.2 # 20% buffer @@ -39,7 +39,7 @@ def test_get_gpu_info_success(sagemaker_session, boto_session): "MemoryInfo": {"SizeInMiB": 24576}, } ], - "TotalGpuMemoryInMiB": 196608, + "TotalGpuMemoryInMiB": 183104, }, } ] diff --git a/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py b/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py index 4729efbda4..fc832ad02d 100644 --- a/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py +++ b/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py @@ -14,7 +14,7 @@ import unittest from unittest.mock import Mock, patch, MagicMock from sagemaker.serve import Mode, ModelServer -from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH +from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH, MLFLOW_TRACKING_ARN from sagemaker.serve.utils.telemetry_logger import ( _send_telemetry, _capture_telemetry, @@ -40,7 +40,10 @@ MOCK_HUGGINGFACE_ID = "meta-llama/Llama-2-7b-hf" MOCK_EXCEPTION = LocalModelOutOfMemoryException("mock raise ex") MOCK_ENDPOINT_ARN = "arn:aws:sagemaker:us-west-2:123456789012:endpoint/test" -MOCK_MODEL_METADATA_FOR_MLFLOW = {MLFLOW_MODEL_PATH: "s3://some_path"} +MOCK_MODEL_METADATA_FOR_MLFLOW = { + MLFLOW_MODEL_PATH: "s3://some_path", + MLFLOW_TRACKING_ARN: "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/test", +} class ModelBuilderMock: @@ -274,6 +277,7 @@ def test_capture_telemetry_decorator_mlflow_success(self, mock_send_telemetry): f"&x-defaultImageUsage={ImageUriOption.DEFAULT_IMAGE.value}" f"&x-endpointArn={MOCK_ENDPOINT_ARN}" f"&x-mlflowModelPathType=2" + f"&x-mlflowTrackingServerArn={MOCK_MODEL_METADATA_FOR_MLFLOW[MLFLOW_TRACKING_ARN]}" f"&x-latency={latency}" ) diff --git a/tests/unit/sagemaker/telemetry/test_telemetry_logging.py b/tests/unit/sagemaker/telemetry/test_telemetry_logging.py index 9107256b5b..bd8db82a16 100644 --- a/tests/unit/sagemaker/telemetry/test_telemetry_logging.py +++ b/tests/unit/sagemaker/telemetry/test_telemetry_logging.py @@ -300,3 +300,39 @@ def test_get_default_sagemaker_session_with_no_region(self): assert "Must setup local AWS configuration with a region supported by SageMaker." in str( context.exception ) + + @patch("sagemaker.telemetry.telemetry_logging._get_accountId") + @patch("sagemaker.telemetry.telemetry_logging._get_region_or_default") + def test_send_telemetry_request_valid_region(self, mock_get_region, mock_get_accountId): + """Test to verify telemetry request is sent when region is valid""" + mock_get_accountId.return_value = "testAccountId" + mock_session = MagicMock() + + # Test with valid region + mock_get_region.return_value = "us-east-1" + with patch( + "sagemaker.telemetry.telemetry_logging._requests_helper" + ) as mock_requests_helper: + _send_telemetry_request(1, [1, 2], mock_session) + # Assert telemetry request was sent + mock_requests_helper.assert_called_once_with( + "https://sm-pysdk-t-us-east-1.s3.us-east-1.amazonaws.com/telemetry?" + "x-accountId=testAccountId&x-status=1&x-feature=1,2", + 2, + ) + + @patch("sagemaker.telemetry.telemetry_logging._get_accountId") + @patch("sagemaker.telemetry.telemetry_logging._get_region_or_default") + def test_send_telemetry_request_invalid_region(self, mock_get_region, mock_get_accountId): + """Test to verify telemetry request is not sent when region is invalid""" + mock_get_accountId.return_value = "testAccountId" + mock_session = MagicMock() + + # Test with invalid region + mock_get_region.return_value = "invalid-region" + with patch( + "sagemaker.telemetry.telemetry_logging._requests_helper" + ) as mock_requests_helper: + _send_telemetry_request(1, [1, 2], mock_session) + # Assert telemetry request was not sent + mock_requests_helper.assert_not_called() diff --git a/tests/unit/sagemaker/test_studio.py b/tests/unit/sagemaker/test_studio.py index 47528e1f36..81302894ab 100644 --- a/tests/unit/sagemaker/test_studio.py +++ b/tests/unit/sagemaker/test_studio.py @@ -12,7 +12,8 @@ # language governing permissions and limitations under the License. # language governing permissions and limitations under the License. from __future__ import absolute_import - +import os +from pathlib import Path from sagemaker._studio import ( _append_project_tags, _find_config, @@ -21,6 +22,66 @@ ) +def test_find_config_cross_platform(tmpdir): + """Test _find_config works correctly across different platforms.""" + # Create a completely separate directory for isolated tests + import tempfile + + with tempfile.TemporaryDirectory() as isolated_root: + # Setup test directory structure for positive tests + config = tmpdir.join(".sagemaker-code-config") + config.write('{"sagemakerProjectId": "proj-1234"}') + + # Test 1: Direct parent directory + working_dir = tmpdir.mkdir("sub") + found_path = _find_config(working_dir) + assert found_path == config + + # Test 2: Deeply nested directories + nested_dir = tmpdir.mkdir("deep").mkdir("nested").mkdir("path") + found_path = _find_config(nested_dir) + assert found_path == config + + # Test 3: Start from root directory + import os + + root_dir = os.path.abspath(os.sep) + found_path = _find_config(root_dir) + assert found_path is None + + # Test 4: No config file in path - using truly isolated directory + isolated_path = Path(isolated_root) / "nested" / "path" + isolated_path.mkdir(parents=True) + found_path = _find_config(isolated_path) + assert found_path is None + + +def test_find_config_path_separators(tmpdir): + """Test _find_config handles different path separator styles. + + Tests: + 1. Forward slashes + 2. Backslashes + 3. Mixed separators + """ + # Setup + config = tmpdir.join(".sagemaker-code-config") + config.write('{"sagemakerProjectId": "proj-1234"}') + base_path = str(tmpdir) + + # Always include the OS native path and forward slashes (which are equivalent on all OS) + paths = [os.path.join(base_path, "dir1", "dir2"), "/".join([base_path, "dir1", "dir2"])] + + # Only on Windows add the backslashes and mixed separator test cases. + if os.name == "nt": + paths.extend(["\\".join([base_path, "dir1", "dir2"]), base_path + "/dir1\\dir2"]) + + for path in paths: + os.makedirs(path, exist_ok=True) + found_path = _find_config(path) + assert found_path == config + + def test_find_config(tmpdir): path = tmpdir.join(".sagemaker-code-config") path.write('{"sagemakerProjectId": "proj-1234"}') diff --git a/tests/unit/sagemaker/workflow/test_notebook_job_step.py b/tests/unit/sagemaker/workflow/test_notebook_job_step.py index 9cc34ee243..6a5bb20daa 100644 --- a/tests/unit/sagemaker/workflow/test_notebook_job_step.py +++ b/tests/unit/sagemaker/workflow/test_notebook_job_step.py @@ -12,11 +12,13 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import os import unittest + from mock import Mock, patch -from sagemaker.workflow.notebook_job_step import NotebookJobStep from sagemaker.workflow.functions import Join +from sagemaker.workflow.notebook_job_step import NotebookJobStep REGION = "us-west-2" PIPELINE_NAME = "test-pipeline-name" @@ -573,3 +575,62 @@ def _create_step_with_required_fields(self): image_uri=IMAGE_URI, kernel_name=KERNEL_NAME, ) + + def test_environment_variables_not_shared(self): + """Test that environment variables are not shared between NotebookJob steps""" + # Setup shared environment variables + shared_env_vars = {"test": "test"} + + # Create two steps with the same environment variables dictionary + step1 = NotebookJobStep( + name="step1", + input_notebook=INPUT_NOTEBOOK, + image_uri=IMAGE_URI, + kernel_name=KERNEL_NAME, + environment_variables=shared_env_vars, + ) + + step2 = NotebookJobStep( + name="step2", + input_notebook=INPUT_NOTEBOOK, + image_uri=IMAGE_URI, + kernel_name=KERNEL_NAME, + environment_variables=shared_env_vars, + ) + + # Get the arguments for both steps + step1_args = step1.arguments + step2_args = step2.arguments + + # Verify that the environment variables are different objects + self.assertIsNot( + step1_args["Environment"], + step2_args["Environment"], + "Environment dictionaries should be different objects", + ) + + # Verify that modifying one step's environment doesn't affect the other + step1_env = step1_args["Environment"] + step2_env = step2_args["Environment"] + + # Both should have the original test value + self.assertEqual(step1_env["test"], "test") + self.assertEqual(step2_env["test"], "test") + + # Modify step1's environment + step1_env["test"] = "modified" + + # Verify step2's environment remains unchanged + self.assertEqual(step2_env["test"], "test") + + # Verify notebook names are correct for each step + self.assertEqual( + step1_env["SM_INPUT_NOTEBOOK_NAME"], + os.path.basename(INPUT_NOTEBOOK), + "Step 1 should have its own notebook name", + ) + self.assertEqual( + step2_env["SM_INPUT_NOTEBOOK_NAME"], + os.path.basename(INPUT_NOTEBOOK), + "Step 2 should have its own notebook name", + ) diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index 14c2d442eb..523b981736 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -99,7 +99,7 @@ def test_pipeline_create_and_update_with_config_injection(sagemaker_session_mock RoleArn=pipeline_role_arn, ) pipeline.upsert() - assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with( + sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=pipeline_role_arn, @@ -130,7 +130,7 @@ def test_pipeline_create_with_parallelism_config(sagemaker_session_mock, role_ar role_arn=role_arn, parallelism_config=dict(MaxParallelExecutionSteps=10), ) - assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with( + sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn, @@ -149,7 +149,7 @@ def test_pipeline_create_and_start_with_parallelism_config(sagemaker_session_moc role_arn=role_arn, parallelism_config=dict(MaxParallelExecutionSteps=10), ) - assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with( + sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn, @@ -168,7 +168,7 @@ def test_pipeline_create_and_start_with_parallelism_config(sagemaker_session_moc # Specify ParallelismConfiguration to another value which will be honored in backend pipeline.start(parallelism_config=dict(MaxParallelExecutionSteps=20)) - assert sagemaker_session_mock.sagemaker_client.start_pipeline_execution.called_with( + sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with( PipelineName="MyPipeline", ParallelismConfiguration={"MaxParallelExecutionSteps": 20}, ) @@ -209,7 +209,7 @@ def test_pipeline_update(sagemaker_session_mock, role_arn): assert not pipeline.steps pipeline.update(role_arn=role_arn) assert len(json.loads(pipeline.definition())["Steps"]) == 0 - assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with( + sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn ) @@ -253,7 +253,7 @@ def test_pipeline_update(sagemaker_session_mock, role_arn): pipeline.update(role_arn=role_arn) assert len(json.loads(pipeline.definition())["Steps"]) == 3 - assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with( + sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn ) @@ -345,7 +345,11 @@ def test_pipeline_update_with_parallelism_config(sagemaker_session_mock, role_ar role_arn=role_arn, parallelism_config=dict(MaxParallelExecutionSteps=10), ) - assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with( + pipeline.update( + role_arn=role_arn, + parallelism_config={"MaxParallelExecutionSteps": 10}, + ) + sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn, @@ -418,13 +422,11 @@ def _raise_does_already_exists_client_error(**kwargs): sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_once_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn ) - assert sagemaker_session_mock.sagemaker_client.list_tags.called_with( - ResourceArn="mock_pipeline_arn" - ) + sagemaker_session_mock.sagemaker_client.list_tags.assert_called_with(ResourceArn="pipeline-arn") tags.append({"Key": "dummy", "Value": "dummy_tag"}) - assert sagemaker_session_mock.sagemaker_client.add_tags.called_with( - ResourceArn="mock_pipeline_arn", Tags=tags + sagemaker_session_mock.sagemaker_client.add_tags.assert_called_with( + ResourceArn="pipeline-arn", Tags=tags ) @@ -523,7 +525,7 @@ def test_pipeline_delete(sagemaker_session_mock): sagemaker_session=sagemaker_session_mock, ) pipeline.delete() - assert sagemaker_session_mock.sagemaker_client.delete_pipeline.called_with( + sagemaker_session_mock.sagemaker_client.delete_pipeline.assert_called_with( PipelineName="MyPipeline", ) @@ -536,7 +538,7 @@ def test_pipeline_describe(sagemaker_session_mock): sagemaker_session=sagemaker_session_mock, ) pipeline.describe() - assert sagemaker_session_mock.sagemaker_client.describe_pipeline.called_with( + sagemaker_session_mock.sagemaker_client.describe_pipeline.assert_called_with( PipelineName="MyPipeline", ) @@ -552,17 +554,17 @@ def test_pipeline_start(sagemaker_session_mock): sagemaker_session=sagemaker_session_mock, ) pipeline.start() - assert sagemaker_session_mock.start_pipeline_execution.called_with( + sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with( PipelineName="MyPipeline", ) pipeline.start(execution_display_name="pipeline-execution") - assert sagemaker_session_mock.start_pipeline_execution.called_with( + sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with( PipelineName="MyPipeline", PipelineExecutionDisplayName="pipeline-execution" ) pipeline.start(parameters=dict(alpha="epsilon")) - assert sagemaker_session_mock.start_pipeline_execution.called_with( + sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with( PipelineName="MyPipeline", PipelineParameters=[{"Name": "alpha", "Value": "epsilon"}] ) @@ -821,10 +823,8 @@ def test_pipeline_build_parameters_from_execution(sagemaker_session_mock): pipeline_execution_arn=reference_execution_arn, parameter_value_overrides=parameter_value_overrides, ) - assert ( - sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with( - PipelineExecutionArn=reference_execution_arn - ) + sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.assert_called_with( + PipelineExecutionArn=reference_execution_arn ) assert len(parameters) == 1 assert parameters["TestParameterName"] == "NewParameterValue" @@ -850,10 +850,8 @@ def test_pipeline_build_parameters_from_execution_with_invalid_overrides(sagemak + f"are not present in the pipeline execution: {reference_execution_arn}" in str(error) ) - assert ( - sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with( - PipelineExecutionArn=reference_execution_arn - ) + sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.assert_called_with( + PipelineExecutionArn=reference_execution_arn ) @@ -908,24 +906,23 @@ def test_pipeline_execution_basics(sagemaker_session_mock): ) execution = pipeline.start() execution.stop() - assert sagemaker_session_mock.sagemaker_client.stop_pipeline_execution.called_with( + sagemaker_session_mock.sagemaker_client.stop_pipeline_execution.assert_called_with( PipelineExecutionArn="my:arn" ) execution.describe() - assert sagemaker_session_mock.sagemaker_client.describe_pipeline_execution.called_with( + sagemaker_session_mock.sagemaker_client.describe_pipeline_execution.assert_called_with( PipelineExecutionArn="my:arn" ) steps = execution.list_steps() - assert sagemaker_session_mock.sagemaker_client.describe_pipeline_execution_steps.called_with( + sagemaker_session_mock.sagemaker_client.list_pipeline_execution_steps.assert_called_with( PipelineExecutionArn="my:arn" ) assert len(steps) == 1 list_parameters_response = execution.list_parameters() - assert ( - sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with( - PipelineExecutionArn="my:arn" - ) + sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.assert_called_with( + PipelineExecutionArn="my:arn" ) + parameter_list = list_parameters_response["PipelineParameters"] assert len(parameter_list) == 1 assert parameter_list[0]["Name"] == "TestParameterName" diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index b3d667a1c3..84906ce620 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -671,7 +671,7 @@ def test_processing_step_normalizes_args_with_local_code(mock_normalize_args, sc mock_normalize_args.return_value = [step.inputs, step.outputs] step.to_request() mock_normalize_args.assert_called_with( - job_name="MyProcessingStep-3e89f0c7e101c356cbedf27d9d27e9db", + job_name=None, arguments=step.job_arguments, inputs=step.inputs, outputs=step.outputs, diff --git a/tests/unit/sagemaker/workflow/test_utilities.py b/tests/unit/sagemaker/workflow/test_utilities.py index e65d3ea933..b284ced91e 100644 --- a/tests/unit/sagemaker/workflow/test_utilities.py +++ b/tests/unit/sagemaker/workflow/test_utilities.py @@ -31,14 +31,14 @@ def test_hash_file(): with tempfile.NamedTemporaryFile() as tmp: tmp.write("hashme".encode()) hash = hash_file(tmp.name) - assert hash == "d41d8cd98f00b204e9800998ecf8427e" + assert hash == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" def test_hash_file_uri(): with tempfile.NamedTemporaryFile() as tmp: tmp.write("hashme".encode()) hash = hash_file(f"file:///{tmp.name}") - assert hash == "d41d8cd98f00b204e9800998ecf8427e" + assert hash == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" def test_hash_files_or_dirs_with_file(): diff --git a/tests/unit/sagemaker/workflow/test_utils.py b/tests/unit/sagemaker/workflow/test_utils.py index 48b1d762c3..b18ed71f9b 100644 --- a/tests/unit/sagemaker/workflow/test_utils.py +++ b/tests/unit/sagemaker/workflow/test_utils.py @@ -80,10 +80,11 @@ def test_repack_model_step(estimator): assert hyperparameters["inference_script"] == '"dummy_script.py"' assert hyperparameters["model_archive"] == '"s3://my-bucket/model.tar.gz"' assert hyperparameters["sagemaker_program"] == f'"{REPACK_SCRIPT_LAUNCHER}"' - assert ( - hyperparameters["sagemaker_submit_directory"] - == '"s3://my-bucket/MyRepackModelStep-b5ea77f701b47a8d075605497462ccc2/source/sourcedir.tar.gz"' - ) + + # ex: "gits3://my-bucket/sagemaker-scikit-learn-2025-04-07-20-39-38-854/source/sourcedir.tar.gz" + sagemaker_submit_directory = hyperparameters["sagemaker_submit_directory"] + assert sagemaker_submit_directory.startswith('"s3://my-bucket/sagemaker-scikit-learn-') + assert sagemaker_submit_directory.endswith('/source/sourcedir.tar.gz"') del request_dict["Arguments"]["HyperParameters"] del request_dict["Arguments"]["AlgorithmSpecification"]["TrainingImage"] diff --git a/tests/unit/test_common.py b/tests/unit/test_common.py index 8fe7383fe4..9fe49ad448 100644 --- a/tests/unit/test_common.py +++ b/tests/unit/test_common.py @@ -16,12 +16,12 @@ import tempfile import pytest import itertools +from sagemaker.deserializers import RecordDeserializer +from sagemaker.serializers import RecordSerializer from scipy.sparse import coo_matrix from sagemaker.amazon.common import ( - RecordDeserializer, write_numpy_to_dense_tensor, read_recordio, - RecordSerializer, write_spmatrix_to_sparse_tensor, ) from sagemaker.amazon.record_pb2 import Record diff --git a/tests/unit/test_exception_on_bad_status.py b/tests/unit/test_exception_on_bad_status.py index 2ef017efd3..dc53c97799 100644 --- a/tests/unit/test_exception_on_bad_status.py +++ b/tests/unit/test_exception_on_bad_status.py @@ -52,7 +52,7 @@ def test_raise_when_failed_created_package(): False ), "sagemaker.exceptions.UnexpectedStatusException should have been raised but was not" except Exception as e: - assert type(e) == sagemaker.exceptions.UnexpectedStatusException + assert isinstance(e, sagemaker.exceptions.UnexpectedStatusException) assert e.actual_status == "EnRoute" assert "Completed" in e.allowed_statuses @@ -73,7 +73,7 @@ def test_does_raise_when_incorrect_job_status(): False ), "sagemaker.exceptions.UnexpectedStatusException should have been raised but was not" except Exception as e: - assert type(e) == sagemaker.exceptions.UnexpectedStatusException + assert isinstance(e, sagemaker.exceptions.UnexpectedStatusException) assert e.actual_status == "Failed" assert "Completed" in e.allowed_statuses assert "Stopped" in e.allowed_statuses @@ -92,7 +92,7 @@ def test_does_raise_capacity_error_when_incorrect_job_status(): ) assert False, "sagemaker.exceptions.CapacityError should have been raised but was not" except Exception as e: - assert type(e) == sagemaker.exceptions.CapacityError + assert isinstance(e, sagemaker.exceptions.CapacityError) assert e.actual_status == "Failed" assert "Completed" in e.allowed_statuses assert "Stopped" in e.allowed_statuses @@ -114,6 +114,6 @@ def test_raise_when_failed_to_deploy_endpoint(): False ), "sagemaker.exceptions.UnexpectedStatusException should have been raised but was not" except Exception as e: - assert type(e) == sagemaker.exceptions.UnexpectedStatusException + assert isinstance(e, sagemaker.exceptions.UnexpectedStatusException) assert e.actual_status == "Failed" assert "InService" in e.allowed_statuses diff --git a/tests/unit/test_hyperparameter.py b/tests/unit/test_hyperparameter.py index ba7a363c40..edb2de97ee 100644 --- a/tests/unit/test_hyperparameter.py +++ b/tests/unit/test_hyperparameter.py @@ -62,7 +62,7 @@ def test_validated(): def test_data_type(): x = Test() x.validated = 66 - assert type(x.validated) == Test.__dict__["validated"].data_type + assert isinstance(x.validated, Test.__dict__["validated"].data_type) def test_from_string(): diff --git a/tests/unit/test_inputs.py b/tests/unit/test_inputs.py index 7d9c2b2c2f..133c31eb75 100644 --- a/tests/unit/test_inputs.py +++ b/tests/unit/test_inputs.py @@ -41,6 +41,8 @@ def test_training_input_all_arguments(): record_wrapping = "RecordIO" s3_data_type = "Manifestfile" input_mode = "Pipe" + hub_access_config = {"HubContentArn": "some-hub-content-arn"} + model_access_config = {"AcceptEula": True} result = TrainingInput( s3_data=prefix, distribution=distribution, @@ -49,6 +51,8 @@ def test_training_input_all_arguments(): content_type=content_type, record_wrapping=record_wrapping, s3_data_type=s3_data_type, + hub_access_config=hub_access_config, + model_access_config=model_access_config, ) expected = { "DataSource": { @@ -56,6 +60,8 @@ def test_training_input_all_arguments(): "S3DataDistributionType": distribution, "S3DataType": s3_data_type, "S3Uri": prefix, + "ModelAccessConfig": model_access_config, + "HubAccessConfig": hub_access_config, } }, "CompressionType": compression, @@ -76,6 +82,8 @@ def test_training_input_all_arguments_heterogeneous_cluster(): s3_data_type = "Manifestfile" instance_groups = ["data-server"] input_mode = "Pipe" + hub_access_config = {"HubContentArn": "some-hub-content-arn"} + model_access_config = {"AcceptEula": True} result = TrainingInput( s3_data=prefix, distribution=distribution, @@ -85,6 +93,8 @@ def test_training_input_all_arguments_heterogeneous_cluster(): record_wrapping=record_wrapping, s3_data_type=s3_data_type, instance_groups=instance_groups, + hub_access_config=hub_access_config, + model_access_config=model_access_config, ) expected = { @@ -94,6 +104,8 @@ def test_training_input_all_arguments_heterogeneous_cluster(): "S3DataType": s3_data_type, "S3Uri": prefix, "InstanceGroupNames": instance_groups, + "ModelAccessConfig": model_access_config, + "HubAccessConfig": hub_access_config, } }, "CompressionType": compression, diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index c93a381c11..dc21f50b68 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -206,6 +206,32 @@ def test_load_config_with_model_channel_no_inputs(estimator): assert config["stop_condition"]["MaxRuntimeInSeconds"] == MAX_RUNTIME +def test_load_config_with_access_configs(estimator): + estimator.model_uri = MODEL_URI + estimator.model_channel_name = MODEL_CHANNEL_NAME + estimator.model_access_config = {"AcceptEula": True} + estimator.hub_access_config = {"HubContentArn": "dummy_arn"} + + config = _Job._load_config(inputs=None, estimator=estimator) + assert config["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] == MODEL_URI + assert config["input_config"][0]["ChannelName"] == MODEL_CHANNEL_NAME + assert config["role"] == ROLE + assert config["output_config"]["S3OutputPath"] == S3_OUTPUT_PATH + assert "KmsKeyId" not in config["output_config"] + assert config["resource_config"]["InstanceCount"] == INSTANCE_COUNT + assert config["resource_config"]["InstanceType"] == INSTANCE_TYPE + assert config["resource_config"]["VolumeSizeInGB"] == VOLUME_SIZE + assert config["stop_condition"]["MaxRuntimeInSeconds"] == MAX_RUNTIME + assert ( + config["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] + == estimator.model_access_config + ) + assert ( + config["input_config"][0]["DataSource"]["S3DataSource"]["HubAccessConfig"] + == estimator.hub_access_config + ) + + def test_load_config_with_code_channel(framework): inputs = TrainingInput(BUCKET_NAME) @@ -347,20 +373,43 @@ def test_format_record_set_list_input(): @pytest.mark.parametrize( - "channel_uri, channel_name, content_type, input_mode", + "channel_uri, channel_name, content_type, input_mode, model_access_config, hub_access_config", [ - [MODEL_URI, MODEL_CHANNEL_NAME, "application/x-sagemaker-model", "File"], - [CODE_URI, CODE_CHANNEL_NAME, None, None], + [ + MODEL_URI, + MODEL_CHANNEL_NAME, + "application/x-sagemaker-model", + "File", + {"AcceptEula": True}, + None, + ], + [CODE_URI, CODE_CHANNEL_NAME, None, None, None, {"HubContentArn": "dummy_arn"}], ], ) -def test_prepare_channel(channel_uri, channel_name, content_type, input_mode): +def test_prepare_channel( + channel_uri, channel_name, content_type, input_mode, model_access_config, hub_access_config +): channel = _Job._prepare_channel( - [], channel_uri, channel_name, content_type=content_type, input_mode=input_mode + [], + channel_uri, + channel_name, + content_type=content_type, + input_mode=input_mode, + model_access_config=model_access_config, + hub_access_config=hub_access_config, ) assert channel["DataSource"]["S3DataSource"]["S3Uri"] == channel_uri assert channel["DataSource"]["S3DataSource"]["S3DataDistributionType"] == "FullyReplicated" assert channel["DataSource"]["S3DataSource"]["S3DataType"] == "S3Prefix" + if hub_access_config: + assert channel["DataSource"]["S3DataSource"]["HubAccessConfig"] == hub_access_config + else: + assert "HubAccessConfig" not in channel["DataSource"]["S3DataSource"] + if model_access_config: + assert channel["DataSource"]["S3DataSource"]["ModelAccessConfig"] == model_access_config + else: + assert "ModelAccessConfig" not in channel["DataSource"]["S3DataSource"] assert channel["ChannelName"] == channel_name assert "CompressionType" not in channel assert "RecordWrapperType" not in channel @@ -546,6 +595,23 @@ def test_format_string_uri_input_string(): assert s3_uri_input.config["DataSource"]["S3DataSource"]["S3Uri"] == inputs +def test_format_string_uri_input_string_with_access_configs(): + inputs = BUCKET_NAME + model_access_config = {"AcceptEula": True} + hub_access_config = {"HubContentArn": "dummy_arn"} + + s3_uri_input = _Job._format_string_uri_input( + inputs, model_access_config=model_access_config, hub_access_config=hub_access_config + ) + + assert s3_uri_input.config["DataSource"]["S3DataSource"]["S3Uri"] == inputs + assert s3_uri_input.config["DataSource"]["S3DataSource"]["HubAccessConfig"] == hub_access_config + assert ( + s3_uri_input.config["DataSource"]["S3DataSource"]["ModelAccessConfig"] + == model_access_config + ) + + def test_format_string_uri_file_system_input(): file_system_id = "fs-fd85e556" file_system_type = "EFS" @@ -585,6 +651,26 @@ def test_format_string_uri_input(): ) +def test_format_string_uri_input_with_access_configs(): + inputs = TrainingInput(BUCKET_NAME) + model_access_config = {"AcceptEula": True} + hub_access_config = {"HubContentArn": "dummy_arn"} + + s3_uri_input = _Job._format_string_uri_input( + inputs, model_access_config=model_access_config, hub_access_config=hub_access_config + ) + + assert ( + s3_uri_input.config["DataSource"]["S3DataSource"]["S3Uri"] + == inputs.config["DataSource"]["S3DataSource"]["S3Uri"] + ) + assert s3_uri_input.config["DataSource"]["S3DataSource"]["HubAccessConfig"] == hub_access_config + assert ( + s3_uri_input.config["DataSource"]["S3DataSource"]["ModelAccessConfig"] + == model_access_config + ) + + def test_format_string_uri_input_exception(): inputs = 1 diff --git a/tests/unit/test_predictor_async.py b/tests/unit/test_predictor_async.py index fa2d6da6c7..c9f12ff023 100644 --- a/tests/unit/test_predictor_async.py +++ b/tests/unit/test_predictor_async.py @@ -233,7 +233,7 @@ def test_async_predict_call_verify_exceptions(): with pytest.raises( PollingTimeoutError, match=f"No result at {ASYNC_OUTPUT_LOCATION} after polling for " - f"{DEFAULT_WAITER_CONFIG.delay*DEFAULT_WAITER_CONFIG.max_attempts}" + f"{DEFAULT_WAITER_CONFIG.delay * DEFAULT_WAITER_CONFIG.max_attempts}" f" seconds. Inference could still be running", ): predictor_async.predict(input_path=input_location, waiter_config=DEFAULT_WAITER_CONFIG) @@ -253,7 +253,7 @@ def test_async_predict_call_verify_exceptions_with_null_failure_path(): with pytest.raises( PollingTimeoutError, match=f"No result at {ASYNC_OUTPUT_LOCATION} after polling for " - f"{DEFAULT_WAITER_CONFIG.delay*DEFAULT_WAITER_CONFIG.max_attempts}" + f"{DEFAULT_WAITER_CONFIG.delay * DEFAULT_WAITER_CONFIG.max_attempts}" f" seconds. Inference could still be running", ): predictor_async.predict(input_path=input_location, waiter_config=DEFAULT_WAITER_CONFIG) diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 6076d44e90..34d3c6784b 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -23,7 +23,10 @@ from sagemaker import image_uris from sagemaker.pytorch import defaults from sagemaker.pytorch import PyTorch, PyTorchPredictor, PyTorchModel -from sagemaker.pytorch.estimator import _get_training_recipe_image_uri +from sagemaker.pytorch.estimator import ( + _get_training_recipe_image_uri, + _get_training_recipe_gpu_script, +) from sagemaker.instance_group import InstanceGroup from sagemaker.session_settings import SessionSettings @@ -1049,6 +1052,52 @@ def test_training_recipe_for_trainium(sagemaker_session): assert pytorch.distribution == expected_distribution +@pytest.mark.parametrize( + "test_case", + [ + { + "script": "llama_pretrain.py", + "recipe": { + "model": { + "model_type": "llama_v3", + }, + }, + }, + { + "script": "mistral_pretrain.py", + "recipe": { + "model": { + "model_type": "mistral", + }, + }, + }, + { + "script": "deepseek_pretrain.py", + "recipe": { + "model": { + "model_type": "deepseek_llamav3", + }, + }, + }, + { + "script": "deepseek_pretrain.py", + "recipe": { + "model": { + "model_type": "deepseek_qwenv2", + }, + }, + }, + ], +) +@patch("shutil.copyfile") +def test_get_training_recipe_gpu_script(mock_copyfile, test_case): + script = test_case["script"] + recipe = test_case["recipe"] + mock_copyfile.return_value = None + + assert _get_training_recipe_gpu_script("code_dir", recipe, "source_dir") == script + + def test_training_recipe_for_trainium_custom_source_dir(sagemaker_session): container_log_level = '"logging.INFO"' diff --git a/tests/unit/test_s3.py b/tests/unit/test_s3.py index a226954986..b54552cacb 100644 --- a/tests/unit/test_s3.py +++ b/tests/unit/test_s3.py @@ -17,6 +17,7 @@ from mock import Mock from sagemaker import s3 +from sagemaker.s3_utils import is_s3_url BUCKET_NAME = "mybucket" REGION = "us-west-2" @@ -132,6 +133,34 @@ def test_parse_s3_url_fail(): assert "Expecting 's3' scheme" in str(error) +@pytest.mark.parametrize( + "input_url", + [ + ("s3://bucket/code_location"), + ("s3://bucket/code_location/sub_location"), + ("s3://bucket/code_location/sub_location/"), + ("s3://bucket/"), + ("s3://bucket"), + ], +) +def test_is_s3_url_true(input_url): + assert is_s3_url(input_url) is True + + +@pytest.mark.parametrize( + "input_url", + [ + ("bucket/code_location"), + ("bucket/code_location/sub_location"), + ("sub_location/"), + ("s3/bucket/"), + ("t3://bucket"), + ], +) +def test_is_s3_url_false(input_url): + assert is_s3_url(input_url) is False + + @pytest.mark.parametrize( "expected_output, input_args", [ diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index d2d2c3bcfb..e3d763e612 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -5006,6 +5006,7 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_session) domain = "COMPUTER_VISION" task = "IMAGE_CLASSIFICATION" sample_payload_url = "s3://test-bucket/model" + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -5094,6 +5095,8 @@ def test_create_model_package_from_containers_with_source_uri_and_inference_spec skip_model_validation = "All" source_uri = "dummy-source-uri" + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} + created_versioned_mp_arn = ( "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1" ) @@ -5149,6 +5152,7 @@ def test_create_model_package_from_containers_with_source_uri_for_unversioned_mp approval_status = ("Approved",) skip_model_validation = "All" source_uri = "dummy-source-uri" + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} with pytest.raises( ValueError, @@ -5221,6 +5225,8 @@ def test_create_model_package_from_containers_with_source_uri_set_to_mp(sagemake return_value={"ModelPackageArn": created_versioned_mp_arn} ) + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} + sagemaker_session.create_model_package_from_containers( model_package_group_name=model_package_group_name, containers=containers, @@ -5443,6 +5449,7 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s approval_status = ("Approved",) description = "description" customer_metadata_properties = {"key1": "value1"} + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -5510,6 +5517,7 @@ def test_create_model_package_from_containers_with_one_instance_types( approval_status = ("Approved",) description = "description" customer_metadata_properties = {"key1": "value1"} + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -7183,3 +7191,97 @@ def test_delete_hub_content_reference(sagemaker_session): } sagemaker_session.sagemaker_client.delete_hub_content_reference.assert_called_with(**request) + + +def test_create_model_package_from_containers_to_create_mpg_if_not_present_without_search( + sagemaker_session, +): + sagemaker_session.sagemaker_client.search.side_effect = Exception() + sagemaker_session.sagemaker_client.search.return_value = {} + sagemaker_session.sagemaker_client.list_model_package_groups.side_effect = [ + { + "ModelPackageGroupSummaryList": [{"ModelPackageGroupName": "mock-mpg"}], + "NextToken": "NextToken", + }, + {"ModelPackageGroupSummaryList": [{"ModelPackageGroupName": "mock-mpg-test"}]}, + ] + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", + model_package_group_name="arn:aws:sagemaker:us-east-1:215995503607:model-package-group/mock-mpg", + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.sagemaker_client.list_model_package_groups.side_effect = [ + {"ModelPackageGroupSummaryList": []} + ] + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with( + ModelPackageGroupName="mock-mpg" + ) + + +def test_create_model_package_from_containers_to_create_mpg_if_not_present(sagemaker_session): + # with search api + sagemaker_session.sagemaker_client.search.return_value = { + "Results": [ + { + "ModelPackageGroup": { + "ModelPackageGroupName": "mock-mpg", + "ModelPackageGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:model-package-group/mock-mpg", + } + } + ] + } + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", + model_package_group_name="arn:aws:sagemaker:us-east-1:215995503607:model-package-group/mock-mpg", + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with( + ModelPackageGroupName="mock-mpg" + ) + + +def test_get_most_recently_created_approved_model_package(sagemaker_session): + sagemaker_session.sagemaker_client.list_model_packages.side_effect = [ + ( + { + "ModelPackageSummaryList": [], + "NextToken": "NextToken", + } + ), + ( + { + "ModelPackageSummaryList": [ + { + "CreationTime": 1697440162, + "ModelApprovalStatus": "Approved", + "ModelPackageArn": "arn:aws:sagemaker:us-west-2:123456789012:model-package/model-version/3", + "ModelPackageGroupName": "model-version", + "ModelPackageVersion": 3, + }, + ], + } + ), + ] + model_package = sagemaker_session.get_most_recently_created_approved_model_package( + model_package_group_name="mpg" + ) + assert model_package is not None + assert ( + model_package.model_package_arn + == "arn:aws:sagemaker:us-west-2:123456789012:model-package/model-version/3" + ) diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index f0325b79e9..b4d21008b5 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -46,7 +46,54 @@ from sagemaker.workflow.parameters import ParameterString, ParameterInteger from src.sagemaker.tuner import InstanceConfig -from .tuner_test_utils import * # noqa: F403 +from .tuner_test_utils import ( + BASE_JOB_NAME, + BUCKET_NAME, + CategoricalParameter, + ContinuousParameter, + DATA_DIR, + EARLY_STOPPING_TYPE, + Estimator, + ESTIMATOR, + ESTIMATOR_NAME, + ESTIMATOR_NAME_TWO, + ESTIMATOR_TWO, + FRAMEWORK_VERSION, + HYPERPARAMETER_RANGES, + HYPERPARAMETER_RANGES_TWO, + IMAGE_NAME, + INPUTS, + INSTANCE_COUNT, + INSTANCE_TYPE, + IntegerParameter, + JOB_NAME, + LIST_TAGS_RESULT, + MAX_JOBS, + MAX_PARALLEL_JOBS, + METRIC_DEFINITIONS, + MODEL_DATA, + MULTI_ALGO_TUNING_JOB_DETAILS, + NUM_COMPONENTS, + OBJECTIVE_METRIC_NAME, + OBJECTIVE_METRIC_NAME_TWO, + OBJECTIVE_TYPE, + PCA, + PY_VERSION, + REGION, + ROLE, + SAGEMAKER_SESSION, + SCRIPT_NAME, + STRATEGY, + TAGS, + TRAINING_JOB_DESCRIPTION, + TRAINING_JOB_NAME, + TUNING_JOB_DETAILS, + WarmStartConfig, + WarmStartTypes, + WARM_START_CONFIG, + ENDPOINT_DESC, + ENDPOINT_CONFIG_DESC, +) @pytest.fixture() diff --git a/tox.ini b/tox.ini index b16c0d2f0b..c47d206380 100644 --- a/tox.ini +++ b/tox.ini @@ -5,7 +5,7 @@ [tox] isolated_build = true -envlist = black-format,flake8,pylint,docstyle,sphinx,doc8,twine,py38,py39,py310,py311 +envlist = black-format,flake8,pylint,docstyle,sphinx,doc8,twine,py39,py310,py311,py312 skip_missing_interpreters = False @@ -21,13 +21,13 @@ exclude = tests/data/ venv/ env/ - tests/unit/test_tensorboard.py # excluding this file for time being + tests/unit/test_tensorboard.py max-complexity = 10 ignore = C901, - E203, # whitespace before ':': Black disagrees with and explicitly violates this. + E203, FI10, FI12, FI13, @@ -35,7 +35,7 @@ ignore = FI15, FI16, FI17, - FI18, # __future__ import "annotations" missing -> check only Python 3.7 compatible + FI18, FI50, FI51, FI52, @@ -67,7 +67,7 @@ markers = [testenv] setenv = PYTHONHASHSEED=42 -pip_version = pip==21.3 +pip_version = pip==24.3 passenv = AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY @@ -82,15 +82,18 @@ passenv = # Can be used to specify which tests to run, e.g.: tox -- -s commands = python -c "import os; os.system('install-custom-pkgs --install-boto-wheels')" - pip install 'apache-airflow==2.9.3' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.9.3/constraints-3.8.txt" - pip install 'torch==2.0.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' - pip install 'torchvision==0.15.2+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' - pip install 'dill>=0.3.8' + pip install 'apache-airflow==2.10.4' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.4/constraints-3.9.txt" + pip install 'torch==2.3.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' + pip install 'torchvision==0.18.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' + pip install 'dill>=0.3.9' pytest {posargs} deps = .[test] depends = - {py38,py39,py310,p311}: clean + {py39,py310,py311,py312}: clean + +[testenv:py312] +basepython = python3.12 [testenv:runcoverage] description = run unit tests with coverage @@ -105,6 +108,7 @@ deps = -r requirements/tox/flake8_requirements.txt commands = flake8 +basepython = python3.12 [testenv:pylint] skipdist = true @@ -112,7 +116,7 @@ skip_install = true deps = -r requirements/tox/pylint_requirements.txt commands = - python -m pylint --rcfile=.pylintrc -j 0 src/sagemaker + python -m pylint --rcfile=.pylintrc -j 0 src/sagemaker --fail-under=9.9 [testenv:spelling] skipdist = true @@ -132,14 +136,14 @@ commands = twine check dist/*.tar.gz [testenv:sphinx] -pip_version = pip==21.3 +pip_version = pip==24.3 changedir = doc # pip install requirements.txt is separate as RTD does it in separate steps # having the requirements.txt installed in deps above results in Double Requirement exception # https://github.com/pypa/pip/issues/988 commands = pip install --exists-action=w -r requirements.txt - sphinx-build -T -W -b html -d _build/doctrees-readthedocs -D language=en . _build/html + sphinx-build -T -b html -d _build/doctrees-readthedocs -D language=en . _build/html [testenv:doc8] deps =