diff --git a/.github/workflows/codebuild-canaries.yml b/.github/workflows/codebuild-canaries.yml new file mode 100644 index 0000000000..a6b5a978ef --- /dev/null +++ b/.github/workflows/codebuild-canaries.yml @@ -0,0 +1,24 @@ +name: Canaries +on: + schedule: + - cron: "0 */3 * * *" + workflow_dispatch: + +permissions: + id-token: write # This is required for requesting the JWT + +jobs: + tests: + runs-on: ubuntu-latest + steps: + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }} + aws-region: us-west-2 + role-duration-seconds: 10800 + - name: Run Integ Tests + uses: aws-actions/aws-codebuild-run-build@v1 + id: codebuild + with: + project-name: sagemaker-python-sdk-canaries diff --git a/CHANGELOG.md b/CHANGELOG.md index e68653ce0d..742e46d127 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,98 @@ # Changelog +## v2.240.0 (2025-02-25) + +### Features + + * Add support for TGI Neuronx 0.0.27 and HF PT 2.3.0 image in PySDK + +### Bug Fixes and Other Changes + + * Remove main function entrypoint in ModelBuilder dependency manager. + * forbid extras in Configs + * altconfig hubcontent and reenable integ test + * Merge branch 'master-rba' into local_merge + * py_version doc fixes + * Add backward compatbility for RecordSerializer and RecordDeserializer + * update image_uri_configs 02-21-2025 06:18:10 PST + * update image_uri_configs 02-20-2025 06:18:08 PST + +### Documentation Changes + + * Removed a line about python version requirements of training script which can misguide users. + +## v2.239.3 (2025-02-19) + +### Bug Fixes and Other Changes + + * added ap-southeast-7 and mx-central-1 for Jumpstart + * update image_uri_configs 02-19-2025 06:18:15 PST + +## v2.239.2 (2025-02-18) + +### Bug Fixes and Other Changes + + * Add warning about not supporting torch.nn.SyncBatchNorm + * pass in inference_ami_version to model_based endpoint type + * Fix hyperparameter strategy docs + * Add framework_version to all TensorFlowModel examples + * Move RecordSerializer and RecordDeserializer to sagemaker.serializers and sagemaker.deserialzers + +## v2.239.1 (2025-02-14) + +### Bug Fixes and Other Changes + + * keep sagemaker_session from being overridden to None + * Fix all type hint and docstrings for callable + * Fix the workshop link for Step Functions + * Fix Tensorflow doc link + * Fix FeatureGroup docstring + * Add type hint for ProcessingOutput + * Fix sourcedir.tar.gz filenames in docstrings + * Fix documentation for local mode + * bug in get latest version was getting the max sorted alphabetically + * Add cleanup logic to model builder integ tests for endpoints + * Fixed pagination failing while listing collections + * fix ValueError when updating a data quality monitoring schedule + * Add docstring for image_uris.retrieve + * Create GitHub action to trigger canaries + * update image_uri_configs 02-04-2025 06:18:00 PST + +## v2.239.0 (2025-02-01) + +### Features + + * Add support for deepseek recipes + +### Bug Fixes and Other Changes + + * mpirun protocol - distributed training with @remote decorator + * Allow telemetry only in supported regions + * Fix ssh host policy + +## v2.238.0 (2025-01-29) + +### Features + + * use jumpstart deployment config image as default optimization image + +### Bug Fixes and Other Changes + + * chore: add new images for HF TGI + * update image_uri_configs 01-29-2025 06:18:08 PST + * skip TF tests for unsupported versions + * Merge branch 'master-rba' into local_merge + * Add missing attributes to local resourceconfig + * update image_uri_configs 01-27-2025 06:18:13 PST + * update image_uri_configs 01-24-2025 06:18:11 PST + * add missing schema definition in docs + * Omegaconf upgrade + * SageMaker @remote function: Added multi-node functionality + * remove option + * fix typo + * fix tests + * Add an option for user to remove inputs and container artifacts when using local model trainer + ## v2.237.3 (2025-01-09) ### Bug Fixes and Other Changes diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 24226af4ee..65b7c0ee0c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -61,6 +61,10 @@ Before sending us a pull request, please ensure that: 1. Follow the instructions at [Modifying an EBS Volume Using Elastic Volumes (Console)](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/requesting-ebs-volume-modifications.html#modify-ebs-volume) to increase the EBS volume size associated with the newly created EC2 instance. 1. Wait 5-10min for the new EBS volume increase to finalize. 1. Allow EC2 to claim the additional space by stopping and then starting your EC2 host. +2. Set up a venv to manage dependencies: + 1. `python -m venv ~/.venv/myproject-env` to create the venv + 2. `source ~/.venv/myproject-env/bin/activate` to activate the venv + 3. `deactivate` to exit the venv ### Pull Down the Code @@ -74,8 +78,8 @@ Before sending us a pull request, please ensure that: ### Run the Unit Tests 1. Install tox using `pip install tox` -1. Install coverage using `pip install .[test]` -1. cd into the sagemaker-python-sdk folder: `cd sagemaker-python-sdk` or `cd /environment/sagemaker-python-sdk` +1. cd into the github project sagemaker-python-sdk folder: `cd sagemaker-python-sdk` or `cd /environment/sagemaker-python-sdk` +1. Install coverage using `pip install '.[test]'` 1. Run the following tox command and verify that all code checks and unit tests pass: `tox tests/unit` 1. You can also run a single test with the following command: `tox -e py310 -- -s -vv ::` 1. You can run coverage via runcvoerage env : `tox -e runcoverage -- tests/unit` or `tox -e py310 -- tests/unit --cov=sagemaker --cov-append --cov-report xml` diff --git a/VERSION b/VERSION index 1ca006360a..1b1f3a78e8 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.237.4.dev0 +2.240.1.dev0 diff --git a/doc/frameworks/pytorch/using_pytorch.rst b/doc/frameworks/pytorch/using_pytorch.rst index d415f38c27..4141dd84db 100644 --- a/doc/frameworks/pytorch/using_pytorch.rst +++ b/doc/frameworks/pytorch/using_pytorch.rst @@ -28,8 +28,6 @@ To train a PyTorch model by using the SageMaker Python SDK: Prepare a PyTorch Training Script ================================= -Your PyTorch training script must be a Python 3.6 compatible source file. - Prepare your script in a separate source file than the notebook, terminal session, or source file you're using to submit the script to SageMaker via a ``PyTorch`` Estimator. This will be discussed in further detail below. @@ -375,6 +373,9 @@ To initialize distributed training in your script, call `torch.distributed.init_process_group `_ with the desired backend and the rank of the current host. +Warning: Some torch features, such as (and likely not limited to) ``torch.nn.SyncBatchNorm`` +is not supported and its existence in ``init_process_group`` will cause an exception during +distributed training. .. code:: python diff --git a/doc/frameworks/tensorflow/deploying_tensorflow_serving.rst b/doc/frameworks/tensorflow/deploying_tensorflow_serving.rst index 1d7344fbbb..a645cd5a62 100644 --- a/doc/frameworks/tensorflow/deploying_tensorflow_serving.rst +++ b/doc/frameworks/tensorflow/deploying_tensorflow_serving.rst @@ -64,7 +64,7 @@ If you already have existing model artifacts in S3, you can skip training and de from sagemaker.tensorflow import TensorFlowModel - model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole') + model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole', framework_version='x.x.x') predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge') @@ -74,7 +74,7 @@ Python-based TensorFlow serving on SageMaker has support for `Elastic Inference from sagemaker.tensorflow import TensorFlowModel - model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole') + model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole', framework_version='x.x.x') predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge', accelerator_type='ml.eia1.medium') diff --git a/doc/frameworks/tensorflow/using_tf.rst b/doc/frameworks/tensorflow/using_tf.rst index 1e51b5f43a..5b888f95be 100644 --- a/doc/frameworks/tensorflow/using_tf.rst +++ b/doc/frameworks/tensorflow/using_tf.rst @@ -246,7 +246,7 @@ Training with parameter servers If you specify parameter_server as the value of the distribution parameter, the container launches a parameter server thread on each instance in the training cluster, and then executes your training code. You can find more information on -TensorFlow distributed training at `TensorFlow docs `__. +TensorFlow distributed training at `TensorFlow docs `__. To enable parameter server training: .. code:: python @@ -468,7 +468,7 @@ If you already have existing model artifacts in S3, you can skip training and de from sagemaker.tensorflow import TensorFlowModel - model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole') + model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole', framework_version='x.x.x') predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge') @@ -478,7 +478,7 @@ Python-based TensorFlow serving on SageMaker has support for `Elastic Inference from sagemaker.tensorflow import TensorFlowModel - model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole') + model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole', framework_version='x.x.x') predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge', accelerator_type='ml.eia1.medium') @@ -767,7 +767,8 @@ This customized Python code must be named ``inference.py`` and is specified thro model = TensorFlowModel(entry_point='inference.py', model_data='s3://mybucket/model.tar.gz', - role='MySageMakerRole') + role='MySageMakerRole', + framework_version='x.x.x') In the example above, ``inference.py`` is assumed to be a file inside ``model.tar.gz``. If you want to use a local file instead, you must add the ``source_dir`` argument. See the documentation on `TensorFlowModel `_. @@ -923,7 +924,8 @@ processing. There are 2 ways to do this: model = TensorFlowModel(entry_point='inference.py', dependencies=['requirements.txt'], model_data='s3://mybucket/model.tar.gz', - role='MySageMakerRole') + role='MySageMakerRole', + framework_version='x.x.x') 2. If you are working in a network-isolation situation or if you don't @@ -941,7 +943,8 @@ processing. There are 2 ways to do this: model = TensorFlowModel(entry_point='inference.py', dependencies=['/path/to/folder/named/lib'], model_data='s3://mybucket/model.tar.gz', - role='MySageMakerRole') + role='MySageMakerRole', + framework_version='x.x.x') For more information, see: https://github.com/aws/sagemaker-tensorflow-serving-container#prepost-processing diff --git a/doc/overview.rst b/doc/overview.rst index a6267f6fd6..77e6bd0c3b 100644 --- a/doc/overview.rst +++ b/doc/overview.rst @@ -1958,7 +1958,7 @@ Make sure to have a Compose Version compatible with your Docker Engine installat Local mode configuration ======================== -The local mode uses a YAML configuration file located at ``~/.sagemaker/config.yaml`` to define the default values that are automatically passed to the ``config`` attribute of ``LocalSession``. This is an example of the configuration, for the full schema, see `sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA `_. +The local mode uses a YAML configuration file located at ``${user_config_directory}/sagemaker/config.yaml`` to define the default values that are automatically passed to the ``config`` attribute of ``LocalSession``. This is an example of the configuration, for the full schema, see `sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA `_. .. code:: yaml @@ -1966,7 +1966,7 @@ The local mode uses a YAML configuration file located at ``~/.sagemaker/config.y local_code: true # Using everything locally region_name: "us-west-2" # Name of the region container_config: # Additional docker container config - shm_size: "128M + shm_size: "128M" If you want to keep everything local, and not use Amazon S3 either, you can enable "local code" in one of two ways: diff --git a/doc/v2.rst b/doc/v2.rst index 0677594b31..bca663af33 100644 --- a/doc/v2.rst +++ b/doc/v2.rst @@ -324,9 +324,9 @@ The follow serializer/deserializer classes have been renamed and/or moved: +--------------------------------------------------------+-------------------------------------------------------+ | ``sagemaker.predictor._NPYSerializer`` | ``sagemaker.serializers.NumpySerializer`` | +--------------------------------------------------------+-------------------------------------------------------+ -| ``sagemaker.amazon.common.numpy_to_record_serializer`` | ``sagemaker.amazon.common.RecordSerializer`` | +| ``sagemaker.amazon.common.numpy_to_record_serializer`` | ``sagemaker.serializers.RecordSerializer`` | +--------------------------------------------------------+-------------------------------------------------------+ -| ``sagemaker.amazon.common.record_deserializer`` | ``sagemaker.amazon.common.RecordDeserializer`` | +| ``sagemaker.amazon.common.record_deserializer`` | ``sagemaker.deserializers.RecordDeserializer`` | +--------------------------------------------------------+-------------------------------------------------------+ | ``sagemaker.predictor._JsonDeserializer`` | ``sagemaker.deserializers.JSONDeserializer`` | +--------------------------------------------------------+-------------------------------------------------------+ diff --git a/doc/workflows/step_functions/index.rst b/doc/workflows/step_functions/index.rst index a327d376a0..bfe9582341 100644 --- a/doc/workflows/step_functions/index.rst +++ b/doc/workflows/step_functions/index.rst @@ -11,5 +11,5 @@ without having to provision and integrate the AWS services separately. The AWS Step Functions Python SDK uses the SageMaker Python SDK as a dependency. To get started with step functions, try the workshop or visit the SDK's website: -* `Workshop on using AWS Step Functions with SageMaker `__ +* `Create and manage Amazon SageMaker AI jobs with Step Functions `__ * `AWS Step Functions Python SDK website `__ diff --git a/src/sagemaker/amazon/common.py b/src/sagemaker/amazon/common.py index 4632bda628..fc5d355749 100644 --- a/src/sagemaker/amazon/common.py +++ b/src/sagemaker/amazon/common.py @@ -13,282 +13,13 @@ """Placeholder docstring""" from __future__ import absolute_import -import io -import logging -import struct -import sys - -import numpy as np - -from sagemaker.amazon.record_pb2 import Record -from sagemaker.deprecations import deprecated_class -from sagemaker.deserializers import SimpleBaseDeserializer -from sagemaker.serializers import SimpleBaseSerializer -from sagemaker.utils import DeferredError - - -class RecordSerializer(SimpleBaseSerializer): - """Serialize a NumPy array for an inference request.""" - - def __init__(self, content_type="application/x-recordio-protobuf"): - """Initialize a ``RecordSerializer`` instance. - - Args: - content_type (str): The MIME type to signal to the inference endpoint when sending - request data (default: "application/x-recordio-protobuf"). - """ - super(RecordSerializer, self).__init__(content_type=content_type) - - def serialize(self, data): - """Serialize a NumPy array into a buffer containing RecordIO records. - - Args: - data (numpy.ndarray): The data to serialize. - - Returns: - io.BytesIO: A buffer containing the data serialized as records. - """ - if len(data.shape) == 1: - data = data.reshape(1, data.shape[0]) - - if len(data.shape) != 2: - raise ValueError( - "Expected a 1D or 2D array, but got a %dD array instead." % len(data.shape) - ) - - buffer = io.BytesIO() - write_numpy_to_dense_tensor(buffer, data) - buffer.seek(0) - - return buffer - - -class RecordDeserializer(SimpleBaseDeserializer): - """Deserialize RecordIO Protobuf data from an inference endpoint.""" - - def __init__(self, accept="application/x-recordio-protobuf"): - """Initialize a ``RecordDeserializer`` instance. - - Args: - accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that - is expected from the inference endpoint (default: - "application/x-recordio-protobuf"). - """ - super(RecordDeserializer, self).__init__(accept=accept) - - def deserialize(self, data, content_type): - """Deserialize RecordIO Protobuf data from an inference endpoint. - - Args: - data (object): The protobuf message to deserialize. - content_type (str): The MIME type of the data. - Returns: - list: A list of records. - """ - try: - return read_records(data) - finally: - data.close() - - -def _write_feature_tensor(resolved_type, record, vector): - """Placeholder Docstring""" - if resolved_type == "Int32": - record.features["values"].int32_tensor.values.extend(vector) - elif resolved_type == "Float64": - record.features["values"].float64_tensor.values.extend(vector) - elif resolved_type == "Float32": - record.features["values"].float32_tensor.values.extend(vector) - - -def _write_label_tensor(resolved_type, record, scalar): - """Placeholder Docstring""" - if resolved_type == "Int32": - record.label["values"].int32_tensor.values.extend([scalar]) - elif resolved_type == "Float64": - record.label["values"].float64_tensor.values.extend([scalar]) - elif resolved_type == "Float32": - record.label["values"].float32_tensor.values.extend([scalar]) - - -def _write_keys_tensor(resolved_type, record, vector): - """Placeholder Docstring""" - if resolved_type == "Int32": - record.features["values"].int32_tensor.keys.extend(vector) - elif resolved_type == "Float64": - record.features["values"].float64_tensor.keys.extend(vector) - elif resolved_type == "Float32": - record.features["values"].float32_tensor.keys.extend(vector) - - -def _write_shape(resolved_type, record, scalar): - """Placeholder Docstring""" - if resolved_type == "Int32": - record.features["values"].int32_tensor.shape.extend([scalar]) - elif resolved_type == "Float64": - record.features["values"].float64_tensor.shape.extend([scalar]) - elif resolved_type == "Float32": - record.features["values"].float32_tensor.shape.extend([scalar]) - - -def write_numpy_to_dense_tensor(file, array, labels=None): - """Writes a numpy array to a dense tensor - - Args: - file: - array: - labels: - """ - - # Validate shape of array and labels, resolve array and label types - if not len(array.shape) == 2: - raise ValueError("Array must be a Matrix") - if labels is not None: - if not len(labels.shape) == 1: - raise ValueError("Labels must be a Vector") - if labels.shape[0] not in array.shape: - raise ValueError( - "Label shape {} not compatible with array shape {}".format( - labels.shape, array.shape - ) - ) - resolved_label_type = _resolve_type(labels.dtype) - resolved_type = _resolve_type(array.dtype) - - # Write each vector in array into a Record in the file object - record = Record() - for index, vector in enumerate(array): - record.Clear() - _write_feature_tensor(resolved_type, record, vector) - if labels is not None: - _write_label_tensor(resolved_label_type, record, labels[index]) - _write_recordio(file, record.SerializeToString()) - - -def write_spmatrix_to_sparse_tensor(file, array, labels=None): - """Writes a scipy sparse matrix to a sparse tensor - - Args: - file: - array: - labels: - """ - try: - import scipy - except ImportError as e: - logging.warning( - "scipy failed to import. Sparse matrix functions will be impaired or broken." - ) - # Any subsequent attempt to use scipy will raise the ImportError - scipy = DeferredError(e) - - if not scipy.sparse.issparse(array): - raise TypeError("Array must be sparse") - - # Validate shape of array and labels, resolve array and label types - if not len(array.shape) == 2: - raise ValueError("Array must be a Matrix") - if labels is not None: - if not len(labels.shape) == 1: - raise ValueError("Labels must be a Vector") - if labels.shape[0] not in array.shape: - raise ValueError( - "Label shape {} not compatible with array shape {}".format( - labels.shape, array.shape - ) - ) - resolved_label_type = _resolve_type(labels.dtype) - resolved_type = _resolve_type(array.dtype) - - csr_array = array.tocsr() - n_rows, n_cols = csr_array.shape - - record = Record() - for row_idx in range(n_rows): - record.Clear() - row = csr_array.getrow(row_idx) - # Write values - _write_feature_tensor(resolved_type, record, row.data) - # Write keys - _write_keys_tensor(resolved_type, record, row.indices.astype(np.uint64)) - - # Write labels - if labels is not None: - _write_label_tensor(resolved_label_type, record, labels[row_idx]) - - # Write shape - _write_shape(resolved_type, record, n_cols) - - _write_recordio(file, record.SerializeToString()) - - -def read_records(file): - """Eagerly read a collection of amazon Record protobuf objects from file. - - Args: - file: - """ - records = [] - for record_data in read_recordio(file): - record = Record() - record.ParseFromString(record_data) - records.append(record) - return records - - -# MXNet requires recordio records have length in bytes that's a multiple of 4 -# This sets up padding bytes to append to the end of the record, for diferent -# amounts of padding required. -padding = {} -for amount in range(4): - if sys.version_info >= (3,): - padding[amount] = bytes([0x00 for _ in range(amount)]) - else: - padding[amount] = bytearray([0x00 for _ in range(amount)]) - -_kmagic = 0xCED7230A - - -def _write_recordio(f, data): - """Writes a single data point as a RecordIO record to the given file. - - Args: - f: - data: - """ - length = len(data) - f.write(struct.pack("I", _kmagic)) - f.write(struct.pack("I", length)) - pad = (((length + 3) >> 2) << 2) - length - f.write(data) - f.write(padding[pad]) - - -def read_recordio(f): - """Placeholder Docstring""" - while True: - try: - (read_kmagic,) = struct.unpack("I", f.read(4)) - except struct.error: - return - assert read_kmagic == _kmagic - (len_record,) = struct.unpack("I", f.read(4)) - pad = (((len_record + 3) >> 2) << 2) - len_record - yield f.read(len_record) - if pad: - f.read(pad) - - -def _resolve_type(dtype): - """Placeholder Docstring""" - if dtype == np.dtype(int): - return "Int32" - if dtype == np.dtype(float): - return "Float64" - if dtype == np.dtype("float32"): - return "Float32" - raise ValueError("Unsupported dtype {} on array".format(dtype)) - - -numpy_to_record_serializer = deprecated_class(RecordSerializer, "numpy_to_record_serializer") -record_deserializer = deprecated_class(RecordDeserializer, "record_deserializer") +# these imports ensure backward compatibility. +from sagemaker.deserializers import RecordDeserializer # noqa: F401 # pylint: disable=W0611 +from sagemaker.serializers import RecordSerializer # noqa: F401 # pylint: disable=W0611 +from sagemaker.serializer_utils import ( # noqa: F401 # pylint: disable=W0611 + read_recordio, + read_records, + write_numpy_to_dense_tensor, + write_spmatrix_to_sparse_tensor, + _write_recordio, +) diff --git a/src/sagemaker/amazon/factorization_machines.py b/src/sagemaker/amazon/factorization_machines.py index 2b24356ee9..1149cd02b2 100644 --- a/src/sagemaker/amazon/factorization_machines.py +++ b/src/sagemaker/amazon/factorization_machines.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import gt, isin, ge +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT diff --git a/src/sagemaker/amazon/hyperparameter.py b/src/sagemaker/amazon/hyperparameter.py index 856927cb13..b479f8a271 100644 --- a/src/sagemaker/amazon/hyperparameter.py +++ b/src/sagemaker/amazon/hyperparameter.py @@ -28,7 +28,7 @@ def __init__(self, name, validate=lambda _: True, validation_message="", data_ty """Args: name (str): The name of this hyperparameter validate - (callable[object]->[bool]): A validation function or list of validation + (Callable[object]->[bool]): A validation function or list of validation functions. Each function validates an object and returns False if the object diff --git a/src/sagemaker/amazon/ipinsights.py b/src/sagemaker/amazon/ipinsights.py index 737d13dd44..bc8e1b5d86 100644 --- a/src/sagemaker/amazon/ipinsights.py +++ b/src/sagemaker/amazon/ipinsights.py @@ -209,7 +209,7 @@ def __init__( chain. serializer (sagemaker.serializers.BaseSerializer): Optional. Default serializes input data to text/csv. - deserializer (callable): Optional. Default parses JSON responses + deserializer (Callable): Optional. Default parses JSON responses using ``json.load(...)``. component_name (str): Optional. Name of the Amazon SageMaker inference component corresponding the predictor. diff --git a/src/sagemaker/amazon/kmeans.py b/src/sagemaker/amazon/kmeans.py index 144cdc934a..25abb9cb27 100644 --- a/src/sagemaker/amazon/kmeans.py +++ b/src/sagemaker/amazon/kmeans.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import gt, isin, ge, le +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT diff --git a/src/sagemaker/amazon/knn.py b/src/sagemaker/amazon/knn.py index f9c73381b4..89ec979e09 100644 --- a/src/sagemaker/amazon/knn.py +++ b/src/sagemaker/amazon/knn.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import ge, isin +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT diff --git a/src/sagemaker/amazon/lda.py b/src/sagemaker/amazon/lda.py index bd64d3ae2e..c57da9643e 100644 --- a/src/sagemaker/amazon/lda.py +++ b/src/sagemaker/amazon/lda.py @@ -18,11 +18,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer +from sagemaker.deserializers import RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import gt from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT diff --git a/src/sagemaker/amazon/linear_learner.py b/src/sagemaker/amazon/linear_learner.py index 695eb31dc1..4533dcdaea 100644 --- a/src/sagemaker/amazon/linear_learner.py +++ b/src/sagemaker/amazon/linear_learner.py @@ -18,11 +18,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer +from sagemaker.deserializers import RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import isin, gt, lt, ge, le from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT diff --git a/src/sagemaker/amazon/ntm.py b/src/sagemaker/amazon/ntm.py index 4267ac8969..41dde1c33c 100644 --- a/src/sagemaker/amazon/ntm.py +++ b/src/sagemaker/amazon/ntm.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import ge, le, isin +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT diff --git a/src/sagemaker/amazon/pca.py b/src/sagemaker/amazon/pca.py index 953fff9d0b..b724435afa 100644 --- a/src/sagemaker/amazon/pca.py +++ b/src/sagemaker/amazon/pca.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import gt, isin +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT diff --git a/src/sagemaker/amazon/randomcutforest.py b/src/sagemaker/amazon/randomcutforest.py index 21d98741b0..d60d5a7741 100644 --- a/src/sagemaker/amazon/randomcutforest.py +++ b/src/sagemaker/amazon/randomcutforest.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import ge, le +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT diff --git a/src/sagemaker/automl/automl.py b/src/sagemaker/automl/automl.py index bb4059c03a..e18d7ba2b9 100644 --- a/src/sagemaker/automl/automl.py +++ b/src/sagemaker/automl/automl.py @@ -478,7 +478,7 @@ def create_model( training cluster for distributed training. Default: False model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file if the model is repacked - predictor_cls (callable[string, sagemaker.session.Session]): A + Callable[[string, sagemaker.session.Session], Any]: A function to call to create a predictor (default: None). If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. @@ -591,7 +591,7 @@ def deploy( training cluster for distributed training. Default: False model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file if the model is repacked - predictor_cls (callable[string, sagemaker.session.Session]): A + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor (default: None). If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. @@ -609,7 +609,7 @@ def deploy( https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests Returns: - callable[string, sagemaker.session.Session] or ``None``: + Optional[Callable[[string, sagemaker.session.Session], Any]]: If ``predictor_cls`` is specified, the invocation of ``self.predictor_cls`` on the created endpoint name. Otherwise, ``None``. """ diff --git a/src/sagemaker/automl/automlv2.py b/src/sagemaker/automl/automlv2.py index 0819e5384e..b071be3b24 100644 --- a/src/sagemaker/automl/automlv2.py +++ b/src/sagemaker/automl/automlv2.py @@ -1022,7 +1022,7 @@ def create_model( training cluster for distributed training. Default: False model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file if the model is repacked - predictor_cls (callable[string, sagemaker.session.Session]): A + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor (default: None). If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. @@ -1130,7 +1130,7 @@ def deploy( training cluster for distributed training. Default: False model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file if the model is repacked - predictor_cls (callable[string, sagemaker.session.Session]): A + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor (default: None). If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. @@ -1148,7 +1148,7 @@ def deploy( https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests Returns: - callable[string, sagemaker.session.Session] or ``None``: + Optional[Callable[[string, sagemaker.session.Session], Any]]: If ``predictor_cls`` is specified, the invocation of ``self.predictor_cls`` on the created endpoint name. Otherwise, ``None``. """ diff --git a/src/sagemaker/base_deserializers.py b/src/sagemaker/base_deserializers.py index a152f0144d..ded68fc4b0 100644 --- a/src/sagemaker/base_deserializers.py +++ b/src/sagemaker/base_deserializers.py @@ -23,6 +23,7 @@ import numpy as np from six import with_metaclass +from sagemaker.serializer_utils import read_records from sagemaker.utils import DeferredError try: @@ -388,3 +389,31 @@ def deserialize(self, stream, content_type="tensor/pt"): "Unable to deserialize your data to torch.Tensor.\ Please provide custom deserializer in InferenceSpec." ) + + +class RecordDeserializer(SimpleBaseDeserializer): + """Deserialize RecordIO Protobuf data from an inference endpoint.""" + + def __init__(self, accept="application/x-recordio-protobuf"): + """Initialize a ``RecordDeserializer`` instance. + + Args: + accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that + is expected from the inference endpoint (default: + "application/x-recordio-protobuf"). + """ + super(RecordDeserializer, self).__init__(accept=accept) + + def deserialize(self, data, content_type): + """Deserialize RecordIO Protobuf data from an inference endpoint. + + Args: + data (object): The protobuf message to deserialize. + content_type (str): The MIME type of the data. + Returns: + list: A list of records. + """ + try: + return read_records(data) + finally: + data.close() diff --git a/src/sagemaker/base_serializers.py b/src/sagemaker/base_serializers.py index 45fea23493..0e1df120ff 100644 --- a/src/sagemaker/base_serializers.py +++ b/src/sagemaker/base_serializers.py @@ -22,6 +22,7 @@ from pandas import DataFrame from six import with_metaclass +from sagemaker.serializer_utils import write_numpy_to_dense_tensor from sagemaker.utils import DeferredError try: @@ -466,3 +467,39 @@ def serialize(self, data): ) raise ValueError("Object of type %s is not a torch.Tensor" % type(data)) + + +class RecordSerializer(SimpleBaseSerializer): + """Serialize a NumPy array for an inference request.""" + + def __init__(self, content_type="application/x-recordio-protobuf"): + """Initialize a ``RecordSerializer`` instance. + + Args: + content_type (str): The MIME type to signal to the inference endpoint when sending + request data (default: "application/x-recordio-protobuf"). + """ + super(RecordSerializer, self).__init__(content_type=content_type) + + def serialize(self, data): + """Serialize a NumPy array into a buffer containing RecordIO records. + + Args: + data (numpy.ndarray): The data to serialize. + + Returns: + io.BytesIO: A buffer containing the data serialized as records. + """ + if len(data.shape) == 1: + data = data.reshape(1, data.shape[0]) + + if len(data.shape) != 2: + raise ValueError( + "Expected a 1D or 2D array, but got a %dD array instead." % len(data.shape) + ) + + buffer = io.BytesIO() + write_numpy_to_dense_tensor(buffer, data) + buffer.seek(0) + + return buffer diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index 806009b0f6..c2d2187b69 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Optional, Union, List, Dict +from typing import Callable, Optional, Union, List, Dict import sagemaker from sagemaker import image_uris, ModelMetrics @@ -96,7 +96,7 @@ def __init__( image_uri: Optional[Union[str, PipelineVariable]] = None, framework_version: Optional[str] = None, py_version: Optional[str] = None, - predictor_cls: callable = ChainerPredictor, + predictor_cls: Optional[Callable] = ChainerPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs, ): @@ -125,7 +125,7 @@ def __init__( py_version (str): Python version you want to use for executing your model training code. Defaults to ``None``. Required unless ``image_uri`` is provided. - predictor_cls (callable[str, sagemaker.session.Session]): A function + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py index 0e2aabbec4..54bccba55e 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py @@ -51,8 +51,8 @@ "StreamDeserializer": ("sagemaker.deserializers",), "NumpyDeserializer": ("sagemaker.deserializers",), "JSONDeserializer": ("sagemaker.deserializers",), - "RecordSerializer ": ("sagemaker.amazon.common",), - "RecordDeserializer": ("sagemaker.amazon.common",), + "RecordSerializer ": ("sagemaker.serializers",), + "RecordDeserializer": ("sagemaker.deserializers",), } OLD_CLASS_NAME_TO_NEW_CLASS_NAME = { @@ -101,8 +101,8 @@ def node_should_be_modified(self, node): - ``sagemaker.predictor.StreamDeserializer`` - ``sagemaker.predictor._NumpyDeserializer`` - ``sagemaker.predictor._JsonDeserializer`` - - ``sagemaker.amazon.common.numpy_to_record_serializer`` - - ``sagemaker.amazon.common.record_deserializer`` + - ``sagemaker.serializers.numpy_to_record_serializer`` + - ``sagemaker.deserializers.record_deserializer`` Args: node (ast.Call): a node that represents a function call. For more, @@ -128,8 +128,8 @@ def modify_node(self, node): - ``sagemaker.deserializers.StreamDeserializer`` - ``sagemaker.deserializers.NumpyDeserializer`` - ``sagemaker.deserializers._JsonDeserializer`` - - ``sagemaker.amazon.common.RecordSerializer`` - - ``sagemaker.amazon.common.RecordDeserializer`` + - ``sagemaker.serializers.RecordSerializer`` + - ``sagemaker.deserializers.RecordDeserializer`` Args: node (ast.Call): a node that represents a SerDe constructor. @@ -303,8 +303,8 @@ def node_should_be_modified(self, node): """Checks if the import statement imports a SerDe from the ``sagemaker.amazon.common``. This checks for: - - ``sagemaker.amazon.common.numpy_to_record_serializer`` - - ``sagemaker.amazon.common.record_deserializer`` + - ``sagemaker.serializers.numpy_to_record_serializer`` + - ``sagemaker.deserializers.record_deserializer`` Args: node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. @@ -322,8 +322,8 @@ def modify_node(self, node): """Upgrades the ``numpy_to_record_serializer`` and ``record_deserializer`` imports. This upgrades the classes to (if applicable): - - ``sagemaker.amazon.common.RecordSerializer`` - - ``sagemaker.amazon.common.RecordDeserializer`` + - ``sagemaker.serializers.RecordSerializer`` + - ``sagemaker.deserializers.RecordDeserializer`` Args: node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 957a9dfb0c..dad5137329 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -31,8 +31,10 @@ StreamDeserializer, StringDeserializer, TorchTensorDeserializer, + RecordDeserializer, ) +from sagemaker.deprecations import deprecated_class from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.enums import JumpStartModelType @@ -150,3 +152,6 @@ def retrieve_default( model_type=model_type, config_name=config_name, ) + + +record_deserializer = deprecated_class(RecordDeserializer, "record_deserializer") diff --git a/src/sagemaker/djl_inference/model.py b/src/sagemaker/djl_inference/model.py index 8c724a6502..94db4efe29 100644 --- a/src/sagemaker/djl_inference/model.py +++ b/src/sagemaker/djl_inference/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Optional, Dict, Any +from typing import Callable, Optional, Dict, Any from sagemaker import image_uris from sagemaker.model import Model @@ -54,7 +54,7 @@ def __init__( parallel_loading: bool = False, model_loading_timeout: Optional[int] = None, prediction_timeout: Optional[int] = None, - predictor_cls: callable = DJLPredictor, + predictor_cls: Optional[Callable] = DJLPredictor, huggingface_hub_token: Optional[str] = None, **kwargs, ): @@ -97,10 +97,10 @@ def __init__( None. If not provided, the default is 240 seconds. prediction_timeout (int): The worker predict call (handler) timeout in seconds. Defaults to None. If not provided, the default is 120 seconds. - predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a - predictor with an endpoint name and SageMaker ``Session``. If specified, - ``deploy()`` returns - the result of invoking this function on the created endpoint name. + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call + to create a predictor with an endpoint name and SageMaker ``Session``. If + specified, ``deploy()`` returns the result of invoking this function on the created + endpoint name. huggingface_hub_token (str): The HuggingFace Hub token to use for downloading the model artifacts for a model stored on the huggingface hub. Defaults to None. If not provided, the token must be specified in the diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 6efc04c88e..fa40719c9f 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -387,8 +387,8 @@ def __init__( source_dir (str or PipelineVariable): The absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. The structure within this directory is preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. The structure within this directory + is preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. With the following GitHub repo directory structure: @@ -2550,7 +2550,6 @@ def _get_train_args(cls, estimator, inputs, experiment_config): raise ValueError( "File URIs are supported in local mode only. Please use a S3 URI instead." ) - config = _Job._load_config(inputs, estimator) current_hyperparameters = estimator.hyperparameters() @@ -3421,8 +3420,8 @@ def __init__( source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + are preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. diff --git a/src/sagemaker/experiments/_metrics.py b/src/sagemaker/experiments/_metrics.py index 31dd679cc8..026e73e8a6 100644 --- a/src/sagemaker/experiments/_metrics.py +++ b/src/sagemaker/experiments/_metrics.py @@ -197,8 +197,8 @@ def _send_metrics(self, metrics): response = self._metrics_client.batch_put_metrics(**request) errors = response["Errors"] if "Errors" in response else None if errors: - message = errors[0]["Message"] - raise Exception(f'{len(errors)} errors with message "{message}"') + error_code = errors[0]["Code"] + raise Exception(f'{len(errors)} errors with error code "{error_code}"') def _construct_batch_put_metrics_request(self, batch): """Creates dictionary object used as request to metrics service.""" diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index 39915b60dc..4eb8d82b0c 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -631,7 +631,7 @@ def __str__(self) -> str: class FeatureGroup: """FeatureGroup definition. - This class instantiates a FeatureGroup object that comprises of a name for the FeatureGroup, + This class instantiates a FeatureGroup object that comprises a name for the FeatureGroup, session instance, and a list of feature definition objects i.e., FeatureDefinition. Attributes: diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 0ddb3cd255..0e4e582261 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -10,30 +10,29 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Utility methods used by framework classes""" +"""Utility methods used by framework classes.""" from __future__ import absolute_import import json import logging import os import re -import time import shutil import tempfile +import time from collections import namedtuple -from typing import List, Optional, Union, Dict +from typing import Dict, List, Optional, Union + from packaging import version import sagemaker.image_uris +import sagemaker.utils +from sagemaker.deprecations import deprecation_warn_base, renamed_kwargs, renamed_warning from sagemaker.instance_group import InstanceGroup from sagemaker.s3_utils import s3_path_join from sagemaker.session_settings import SessionSettings -import sagemaker.utils from sagemaker.workflow import is_pipeline_variable - -from sagemaker.deprecations import renamed_warning, renamed_kwargs from sagemaker.workflow.entities import PipelineVariable -from sagemaker.deprecations import deprecation_warn_base logger = logging.getLogger(__name__) @@ -41,6 +40,7 @@ UploadedCode = namedtuple("UploadedCode", ["s3_prefix", "script_name"]) """sagemaker.fw_utils.UploadedCode: An object containing the S3 prefix and script name. + This is for the source code used for the entry point with an ``Estimator``. It can be instantiated with positional or keyword arguments. """ @@ -152,6 +152,7 @@ "2.1.0", "2.1.2", "2.2.0", + "2.3.0", "2.3.1", "2.4.1", ] @@ -210,7 +211,7 @@ def validate_source_code_input_against_pipeline_variables( git_config: Optional[Dict[str, str]] = None, enable_network_isolation: Union[bool, PipelineVariable] = False, ): - """Validate source code input against pipeline variables + """Validate source code input against pipeline variables. Args: entry_point (str or PipelineVariable): The path to the local Python source file that @@ -251,7 +252,7 @@ def validate_source_code_input_against_pipeline_variables( logger.warning( "The source_dir is a pipeline variable: %s. During pipeline execution, " "the interpreted value of source_dir has to be an S3 URI and " - "must point to a tar.gz file", + "must point to a file with name ``sourcedir.tar.gz``", type(source_dir), ) @@ -480,7 +481,7 @@ def tar_and_upload_dir( def _list_files_to_compress(script, directory): - """Placeholder docstring""" + """Placeholder docstring.""" if directory is None: return [script] @@ -584,7 +585,6 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image): The location returned is a potential concatenation of 2 parts 1. code_location_key_prefix if it exists 2. model_name or a name derived from the image - Args: code_location_key_prefix (str): the s3 key prefix from code_location model_name (str): the name of the model @@ -619,8 +619,6 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution "enabled": True } } - - """ if training_instance_type == "local" or distribution is None: return @@ -645,7 +643,7 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution def profiler_config_deprecation_warning( profiler_config, image_uri, framework_name, framework_version ): - """Put out a deprecation message for if framework profiling is specified TF >= 2.12 and PT >= 2.0""" + """Deprecation message if framework profiling is specified TF >= 2.12 and PT >= 2.0.""" if profiler_config is None or profiler_config.framework_profile_params is None: return @@ -691,6 +689,7 @@ def validate_smdistributed( framework_name (str): A string representing the name of framework selected. framework_version (str): A string representing the framework version selected. py_version (str): A string representing the python version selected. + Ex: `py38, py39, py310, py311` distribution (dict): A dictionary with information to enable distributed training. (Defaults to None if distributed training is not enabled.) For example: @@ -762,7 +761,8 @@ def _validate_smdataparallel_args( instance_type (str): A string representing the type of training instance selected. Ex: `ml.p3.16xlarge` framework_name (str): A string representing the name of framework selected. Ex: `tensorflow` framework_version (str): A string representing the framework version selected. Ex: `2.3.1` - py_version (str): A string representing the python version selected. Ex: `py3` + py_version (str): A string representing the python version selected. + Ex: `py38, py39, py310, py311` distribution (dict): A dictionary with information to enable distributed training. (Defaults to None if distributed training is not enabled.) Ex: @@ -846,6 +846,7 @@ def validate_distribution( framework_name (str): A string representing the name of framework selected. framework_version (str): A string representing the framework version selected. py_version (str): A string representing the python version selected. + Ex: `py38, py39, py310, py311` image_uri (str): A string representing a Docker image URI. kwargs(dict): Additional kwargs passed to this function @@ -952,7 +953,7 @@ def validate_distribution( def validate_distribution_for_instance_type(instance_type, distribution): - """Check if the provided distribution strategy is supported for the instance_type + """Check if the provided distribution strategy is supported for the instance_type. Args: instance_type (str): A string representing the type of training instance selected. @@ -1009,6 +1010,7 @@ def validate_torch_distributed_distribution( } framework_version (str): A string representing the framework version selected. py_version (str): A string representing the python version selected. + Ex: `py38, py39, py310, py311` image_uri (str): A string representing a Docker image URI. entry_point (str or PipelineVariable): The absolute or relative path to the local Python source file that should be executed as the entry point to @@ -1071,7 +1073,7 @@ def validate_torch_distributed_distribution( def _is_gpu_instance(instance_type): - """Returns bool indicating whether instance_type supports GPU + """Returns bool indicating whether instance_type supports GPU. Args: instance_type (str): Name of the instance_type to check against. @@ -1090,7 +1092,7 @@ def _is_gpu_instance(instance_type): def _is_trainium_instance(instance_type): - """Returns bool indicating whether instance_type is a Trainium instance + """Returns bool indicating whether instance_type is a Trainium instance. Args: instance_type (str): Name of the instance_type to check against. @@ -1106,7 +1108,7 @@ def _is_trainium_instance(instance_type): def python_deprecation_warning(framework, latest_supported_version): - """Placeholder docstring""" + """Placeholder docstring.""" return PYTHON_2_DEPRECATION_WARNING.format( framework=framework, latest_supported_version=latest_supported_version ) @@ -1120,7 +1122,6 @@ def _region_supports_debugger(region_name): Returns: bool: Whether or not the region supports Amazon SageMaker Debugger. - """ return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS @@ -1133,7 +1134,6 @@ def _region_supports_profiler(region_name): Returns: bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature. - """ return region_name.lower() not in PROFILER_UNSUPPORTED_REGIONS @@ -1161,7 +1161,8 @@ def validate_version_or_image_args(framework_version, py_version, image_uri): Args: framework_version (str): The version of the framework. - py_version (str): The version of Python. + py_version (str): A string representing the python version selected. + Ex: `py38, py39, py310, py311` image_uri (str): The URI of the image. Raises: @@ -1193,9 +1194,8 @@ def create_image_uri( instance_type (str): SageMaker instance type. Used to determine device type (cpu/gpu/family-specific optimized). framework_version (str): The version of the framework. - py_version (str): Optional. Python version. If specified, should be one - of 'py2' or 'py3'. If not specified, image uri will not include a - python component. + py_version (str): Optional. Python version Ex: `py38, py39, py310, py311`. + If not specified, image uri will not include a python component. account (str): AWS account that contains the image. (default: '520713654638') accelerator_type (str): SageMaker Elastic Inference accelerator type. diff --git a/src/sagemaker/huggingface/estimator.py b/src/sagemaker/huggingface/estimator.py index 86df43d4e9..70cc17b209 100644 --- a/src/sagemaker/huggingface/estimator.py +++ b/src/sagemaker/huggingface/estimator.py @@ -15,17 +15,13 @@ import logging import re -from typing import Optional, Union, Dict +from typing import Dict, Optional, Union -from sagemaker.estimator import Framework, EstimatorBase -from sagemaker.fw_utils import ( - framework_name_from_image, - validate_distribution, -) +from sagemaker.estimator import EstimatorBase, Framework +from sagemaker.fw_utils import framework_name_from_image, validate_distribution from sagemaker.huggingface.model import HuggingFaceModel -from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT - from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig +from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -66,7 +62,7 @@ def __init__( Args: py_version (str): Python version you want to use for executing your model training code. Defaults to ``None``. Required unless ``image_uri`` is provided. If - using PyTorch, the current supported version is ``py36``. If using TensorFlow, + using PyTorch, the current supported version is ``py39``. If using TensorFlow, the current supported version is ``py37``. entry_point (str or PipelineVariable): Path (absolute or relative) to the Python source file which should be executed as the entry point to training. @@ -84,8 +80,8 @@ def __init__( source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. + point to a file with name ``sourcedir.tar.gz``. Structure within this directory are + preserved when training on Amazon SageMaker. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index ea99be2fc0..05b981d21b 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Optional, Union, List, Dict +from typing import Callable, Optional, Union, List, Dict import sagemaker from sagemaker import image_uris, ModelMetrics @@ -123,7 +123,7 @@ def __init__( pytorch_version: Optional[str] = None, py_version: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, - predictor_cls: callable = HuggingFacePredictor, + predictor_cls: Optional[Callable] = HuggingFacePredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs, ): @@ -158,7 +158,7 @@ def __init__( If not specified, a default image for PyTorch will be used. If ``framework_version`` or ``py_version`` are ``None``, then ``image_uri`` is required. If also ``None``, then a ``ValueError`` will be raised. - predictor_cls (callable[str, sagemaker.session.Session]): A function + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. @@ -304,7 +304,7 @@ def deploy( - If a wrong type of object is provided as serverless inference config or async inference config Returns: - callable[string, sagemaker.session.Session] or None: Invocation of + Optional[Callable[[string, sagemaker.session.Session], Any]]: Invocation of ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` is not None. Otherwise, return None. """ diff --git a/src/sagemaker/image_uri_config/djl-neuronx.json b/src/sagemaker/image_uri_config/djl-neuronx.json index 3fd3c7619f..1fd7492ff4 100644 --- a/src/sagemaker/image_uri_config/djl-neuronx.json +++ b/src/sagemaker/image_uri_config/djl-neuronx.json @@ -13,12 +13,14 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -37,12 +39,14 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -61,12 +65,14 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -85,12 +91,14 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -109,12 +117,14 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -133,12 +143,14 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -157,12 +169,14 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -181,12 +195,14 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json b/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json index 0fdb190d30..478d6ff597 100644 --- a/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json +++ b/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json @@ -4,7 +4,7 @@ "inf2" ], "version_aliases": { - "0.0": "0.0.25" + "0.0": "0.0.27" }, "versions": { "0.0.16": { @@ -19,6 +19,7 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -27,6 +28,7 @@ "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -53,6 +55,7 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -61,6 +64,7 @@ "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -87,6 +91,7 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -95,6 +100,7 @@ "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -121,6 +127,7 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -129,6 +136,7 @@ "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -155,6 +163,7 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -163,6 +172,7 @@ "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -189,6 +199,7 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -197,6 +208,7 @@ "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -223,6 +235,7 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -231,6 +244,7 @@ "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -255,6 +269,7 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -263,6 +278,7 @@ "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -289,6 +305,7 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -297,6 +314,7 @@ "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -323,6 +341,7 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -331,6 +350,7 @@ "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -344,6 +364,40 @@ "container_version": { "inf2": "ubuntu22.04" } + }, + "0.0.27": { + "py_versions": [ + "py310" + ], + "registries": { + "ap-northeast-1": "763104351884", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "tag_prefix": "2.1.2-optimum0.0.27", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "inf2": "ubuntu22.04" + } } } } diff --git a/src/sagemaker/image_uri_config/huggingface-llm.json b/src/sagemaker/image_uri_config/huggingface-llm.json index 24cbd5ca96..cc6b2b20a0 100644 --- a/src/sagemaker/image_uri_config/huggingface-llm.json +++ b/src/sagemaker/image_uri_config/huggingface-llm.json @@ -12,7 +12,8 @@ "1.2": "1.2.0", "1.3": "1.3.3", "1.4": "1.4.5", - "2.0": "2.3.1" + "2.0": "2.4.0", + "3.0": "3.0.1" }, "versions": { "0.6.0": { @@ -766,6 +767,100 @@ "container_version": { "gpu": "cu124-ubuntu22.04" } + }, + "2.4.0": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "tag_prefix": "2.4.0-tgi2.4.0", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu124-ubuntu22.04-v2.2" + } + }, + "3.0.1": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "tag_prefix": "2.4.0-tgi3.0.1", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu124-ubuntu22.04-v2.1" + } } } } diff --git a/src/sagemaker/image_uri_config/huggingface-neuron.json b/src/sagemaker/image_uri_config/huggingface-neuron.json index 6ed4a62bc7..4e950bdb70 100644 --- a/src/sagemaker/image_uri_config/huggingface-neuron.json +++ b/src/sagemaker/image_uri_config/huggingface-neuron.json @@ -24,12 +24,14 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface-neuronx.json b/src/sagemaker/image_uri_config/huggingface-neuronx.json index d0f0960ef7..a3426d5e0c 100644 --- a/src/sagemaker/image_uri_config/huggingface-neuronx.json +++ b/src/sagemaker/image_uri_config/huggingface-neuronx.json @@ -26,6 +26,7 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -34,6 +35,7 @@ "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -67,6 +69,7 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -75,6 +78,7 @@ "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -108,6 +112,7 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -116,6 +121,7 @@ "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -166,6 +172,7 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -178,6 +185,7 @@ "eu-south-1": "692866216735", "eu-south-2": "503227376785", "me-south-1": "217643126080", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -219,6 +227,7 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -231,6 +240,7 @@ "eu-south-1": "692866216735", "eu-south-2": "503227376785", "me-south-1": "217643126080", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -273,6 +283,7 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -285,6 +296,7 @@ "eu-south-1": "692866216735", "eu-south-2": "503227376785", "me-south-1": "217643126080", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -321,6 +333,7 @@ "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -333,6 +346,7 @@ "eu-south-1": "692866216735", "eu-south-2": "503227376785", "me-south-1": "217643126080", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface-training-compiler.json b/src/sagemaker/image_uri_config/huggingface-training-compiler.json index 5104edfecc..fa3a4119ca 100644 --- a/src/sagemaker/image_uri_config/huggingface-training-compiler.json +++ b/src/sagemaker/image_uri_config/huggingface-training-compiler.json @@ -70,6 +70,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "eu-central-1": "763104351884", "eu-central-2": "380420809688", @@ -81,6 +82,7 @@ "eu-west-3": "763104351884", "me-south-1": "217643126080", "me-central-1": "914824155844", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -111,6 +113,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "eu-central-1": "763104351884", "eu-central-2": "380420809688", @@ -122,6 +125,7 @@ "eu-west-3": "763104351884", "me-south-1": "217643126080", "me-central-1": "914824155844", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -157,6 +161,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "eu-central-1": "763104351884", "eu-central-2": "380420809688", @@ -168,6 +173,7 @@ "eu-west-3": "763104351884", "me-south-1": "217643126080", "me-central-1": "914824155844", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface.json b/src/sagemaker/image_uri_config/huggingface.json index 930b24566d..c314436346 100644 --- a/src/sagemaker/image_uri_config/huggingface.json +++ b/src/sagemaker/image_uri_config/huggingface.json @@ -13,7 +13,8 @@ "4.17": "4.17.0", "4.26": "4.26.0", "4.28": "4.28.1", - "4.36": "4.36.0" + "4.36": "4.36.0", + "4.46": "4.46.1" }, "versions": { "4.4.2": { @@ -1018,6 +1019,53 @@ "gpu": "cu121-ubuntu20.04" } } + }, + "4.46.1": { + "version_aliases": { + "pytorch2.3": "pytorch2.3.0" + }, + "pytorch2.3.0": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "huggingface-pytorch-training", + "container_version": { + "gpu": "cu121-ubuntu20.04" + } + } } } }, @@ -1883,6 +1931,58 @@ "cpu": "ubuntu22.04" } } + }, + "4.48.0": { + "version_aliases": { + "pytorch2.3": "pytorch2.3.0" + }, + "pytorch2.3.0": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "huggingface-pytorch-inference", + "container_version": { + "gpu": "cu121-ubuntu22.04", + "cpu": "ubuntu22.04" + } + } } } } diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index 66150da2b0..b3a23733ae 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -208,6 +208,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -223,6 +224,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -253,6 +255,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -268,6 +271,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -298,6 +302,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -313,6 +318,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -343,6 +349,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -358,6 +365,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -388,6 +396,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -403,6 +412,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -433,6 +443,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -448,6 +459,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -478,6 +490,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -493,6 +506,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -523,6 +537,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -538,6 +553,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -567,6 +583,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -582,6 +599,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -611,6 +629,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -626,6 +645,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -655,6 +675,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -670,6 +691,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -699,6 +721,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -714,6 +737,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -743,6 +767,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -758,6 +783,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -787,6 +813,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -802,6 +829,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -831,6 +859,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -846,6 +875,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -875,6 +905,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -890,6 +921,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -919,6 +951,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -934,6 +967,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -963,6 +997,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -978,6 +1013,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1007,6 +1043,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1022,6 +1059,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1053,6 +1091,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1068,6 +1107,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1099,6 +1139,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1114,6 +1155,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1141,6 +1183,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1156,6 +1199,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1183,6 +1227,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1198,6 +1243,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1243,6 +1289,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1258,6 +1305,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1290,6 +1338,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1305,6 +1354,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1335,6 +1385,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1350,6 +1401,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1380,6 +1432,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1395,6 +1448,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1425,6 +1479,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1440,6 +1495,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1470,6 +1526,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1485,6 +1542,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1515,6 +1573,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1530,6 +1589,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1691,6 +1751,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1706,6 +1767,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1736,6 +1798,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1751,6 +1814,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1782,6 +1846,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1797,6 +1862,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1827,6 +1893,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1842,6 +1909,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1872,6 +1940,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1887,6 +1956,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1917,6 +1987,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1932,6 +2003,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1962,6 +2034,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1977,6 +2050,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2007,6 +2081,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2022,6 +2097,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2051,6 +2127,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2066,6 +2143,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2095,6 +2173,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2110,6 +2189,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2139,6 +2219,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2154,6 +2235,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2183,6 +2265,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2198,6 +2281,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2227,6 +2311,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2242,6 +2327,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2271,6 +2357,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2286,6 +2373,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2315,6 +2403,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2330,6 +2419,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2359,6 +2449,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2374,6 +2465,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2403,6 +2495,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2418,6 +2511,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2447,6 +2541,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2462,6 +2557,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2491,6 +2587,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2506,6 +2603,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2537,6 +2635,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2552,6 +2651,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2583,6 +2683,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2598,6 +2699,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2629,6 +2731,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2644,6 +2747,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2671,6 +2775,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2686,6 +2791,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/sagemaker-base-python.json b/src/sagemaker/image_uri_config/sagemaker-base-python.json index d4bb35f77b..e1de6bfd21 100644 --- a/src/sagemaker/image_uri_config/sagemaker-base-python.json +++ b/src/sagemaker/image_uri_config/sagemaker-base-python.json @@ -11,6 +11,7 @@ "ap-southeast-1": "492261229750", "ap-southeast-2": "452832661640", "ap-southeast-3": "276181064229", + "ap-southeast-5": "148761635175", "ca-central-1": "310906938811", "cn-north-1": "390048526115", "cn-northwest-1": "390780980154", @@ -30,6 +31,8 @@ "us-east-2": "429704687514", "us-gov-east-1": "107072934176", "us-gov-west-1": "107173498710", + "us-isof-east-1": "840123138293", + "us-isof-south-1": "883091641454", "us-west-1": "742091327244", "us-west-2": "236514542706" }, diff --git a/src/sagemaker/image_uri_config/sagemaker-distribution.json b/src/sagemaker/image_uri_config/sagemaker-distribution.json new file mode 100644 index 0000000000..d9ffca5d7b --- /dev/null +++ b/src/sagemaker/image_uri_config/sagemaker-distribution.json @@ -0,0 +1,37 @@ +{ + "processors": ["cpu", "gpu"], + "scope": ["inference"], + "version_aliases": { + "3.0": "3.0.0" + }, + "versions": { + "3.0.0": { + "registries": { + "us-east-1": "885854791233", + "us-east-2": "137914896644", + "us-west-1": "053634841547", + "us-west-2": "542918446943", + "af-south-1": "238384257742", + "ap-east-1": "523751269255", + "ap-south-1": "245090515133", + "ap-northeast-2": "064688005998", + "ap-southeast-1": "022667117163", + "ap-southeast-2": "648430277019", + "ap-northeast-1": "010972774902", + "ca-central-1": "481561238223", + "eu-central-1": "545423591354", + "eu-west-1": "819792524951", + "eu-west-2": "021081402939", + "eu-west-3": "856416204555", + "eu-north-1": "175620155138", + "eu-south-1": "810671768855", + "sa-east-1": "567556641782", + "ap-northeast-3": "564864627153", + "ap-southeast-3": "370607712162", + "me-south-1": "523774347010", + "me-central-1": "358593528301" + }, + "repository": "sagemaker-distribution-prod" + } + } +} diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 5f12889fd0..37fa7ee46d 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -332,7 +332,8 @@ "2.12": "2.12.1", "2.13": "2.13.0", "2.14": "2.14.1", - "2.16": "2.16.1" + "2.16": "2.16.1", + "2.18": "2.18.0" }, "versions": { "1.4.1": { @@ -640,6 +641,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -655,6 +657,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -681,6 +684,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -696,6 +700,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -722,6 +727,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -737,6 +743,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -763,6 +770,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -778,6 +786,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -804,6 +813,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -819,6 +829,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -845,6 +856,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -860,6 +872,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -886,6 +899,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -901,6 +915,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -927,6 +942,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -942,6 +958,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -968,6 +985,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -983,6 +1001,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1009,6 +1028,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1024,6 +1044,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1050,6 +1071,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1065,6 +1087,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1091,6 +1114,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1106,6 +1130,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1132,6 +1157,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1147,6 +1173,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1173,6 +1200,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1188,6 +1216,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1214,6 +1243,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1229,6 +1259,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1255,6 +1286,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1270,6 +1302,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1296,6 +1329,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1311,6 +1345,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1337,6 +1372,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1352,6 +1388,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1378,6 +1415,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1393,6 +1431,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1419,6 +1458,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1434,6 +1474,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1460,6 +1501,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1475,6 +1517,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1501,6 +1544,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1516,6 +1560,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1542,6 +1587,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1557,6 +1603,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1583,6 +1630,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1598,6 +1646,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1624,6 +1673,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1639,6 +1689,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1665,6 +1716,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1680,6 +1732,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1706,6 +1759,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1721,6 +1775,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1747,6 +1802,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1762,6 +1818,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1788,6 +1845,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1803,6 +1861,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1829,6 +1888,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1844,6 +1904,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1870,6 +1931,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1885,6 +1947,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1911,6 +1974,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1926,6 +1990,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1952,6 +2017,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -1967,6 +2033,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1993,6 +2060,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2008,6 +2076,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2034,6 +2103,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2049,6 +2119,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2075,6 +2146,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2090,6 +2162,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2116,6 +2189,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2131,6 +2205,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2157,6 +2232,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2172,6 +2248,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2200,6 +2277,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2215,6 +2293,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2243,6 +2322,48 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" + }, + "2.18.0": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2258,6 +2379,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2302,6 +2424,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2317,6 +2440,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2349,6 +2473,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2364,6 +2489,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2396,6 +2522,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2411,6 +2538,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2443,6 +2571,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2458,6 +2587,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2490,6 +2620,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2505,6 +2636,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2550,7 +2682,8 @@ "2.12": "2.12.0", "2.13": "2.13.0", "2.14": "2.14.1", - "2.16": "2.16.2" + "2.16": "2.16.2", + "2.18": "2.18.0" }, "versions": { "1.4.1": { @@ -2942,6 +3075,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2957,6 +3091,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2987,6 +3122,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3002,6 +3138,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3033,6 +3170,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3048,6 +3186,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3079,6 +3218,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3094,6 +3234,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3125,6 +3266,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3140,6 +3282,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3171,6 +3314,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3186,6 +3330,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3216,6 +3361,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3231,6 +3377,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3261,6 +3408,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3276,6 +3424,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3306,6 +3455,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3321,6 +3471,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3351,6 +3502,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3366,6 +3518,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3396,6 +3549,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3411,6 +3565,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3441,6 +3596,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3456,6 +3612,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3486,6 +3643,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3501,6 +3659,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3531,6 +3690,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3546,6 +3706,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3576,6 +3737,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3591,6 +3753,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3620,6 +3783,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3635,6 +3799,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3664,6 +3829,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3679,6 +3845,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3708,6 +3875,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3723,6 +3891,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3752,6 +3921,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3767,6 +3937,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3796,6 +3967,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3811,6 +3983,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3840,6 +4013,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3855,6 +4029,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3884,6 +4059,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3899,6 +4075,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3928,6 +4105,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3943,6 +4121,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3972,6 +4151,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -3987,6 +4167,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4016,6 +4197,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4031,6 +4213,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4060,6 +4243,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4075,6 +4259,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4104,6 +4289,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4119,6 +4305,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4148,6 +4335,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4163,6 +4351,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4192,6 +4381,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4207,6 +4397,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4236,6 +4427,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4251,6 +4443,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4280,6 +4473,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4295,6 +4489,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4324,6 +4519,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4339,6 +4535,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4368,6 +4565,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4383,6 +4581,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4412,6 +4611,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4427,6 +4627,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4454,6 +4655,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4469,6 +4671,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4500,6 +4703,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4515,6 +4719,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4546,6 +4751,51 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-training" + }, + "2.18.0": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4561,6 +4811,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 8c3449e8c4..7d277cd854 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -101,6 +101,8 @@ def retrieve( https://github.com/aws/deep-learning-containers/blob/master/available_images.md (default: None). distribution (dict): A dictionary with information on how to run distributed training + base_framework_version (str): The base version number of PyTorch or Tensorflow. + (default: None). training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`): A configuration class for the SageMaker Training Compiler (default: None). diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py index 89779bef44..71678021d4 100644 --- a/src/sagemaker/inputs.py +++ b/src/sagemaker/inputs.py @@ -43,6 +43,8 @@ def __init__( attribute_names: Optional[List[Union[str, PipelineVariable]]] = None, target_attribute_name: Optional[Union[str, PipelineVariable]] = None, shuffle_config: Optional["ShuffleConfig"] = None, + hub_access_config: Optional[dict] = None, + model_access_config: Optional[dict] = None, ): r"""Create a definition for input data used by an SageMaker training job. @@ -102,6 +104,13 @@ def __init__( shuffle_config (sagemaker.inputs.ShuffleConfig): If specified this configuration enables shuffling on this channel. See the SageMaker API documentation for more info: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html + hub_access_config (dict): Specify the HubAccessConfig of a + Model Reference for which a training job is being created for. + model_access_config (dict): For models that require a Model Access Config, specify True + or False for to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). """ self.config = { "DataSource": {"S3DataSource": {"S3DataType": s3_data_type, "S3Uri": s3_data}} @@ -129,6 +138,27 @@ def __init__( self.config["TargetAttributeName"] = target_attribute_name if shuffle_config is not None: self.config["ShuffleConfig"] = {"Seed": shuffle_config.seed} + self.add_hub_access_config(hub_access_config) + self.add_model_access_config(model_access_config) + + def add_hub_access_config(self, hub_access_config=None): + """Add Hub Access Config to the channel's configuration. + + Args: + hub_access_config (dict): The HubAccessConfig to be added to the + channel's configuration. + """ + if hub_access_config is not None: + self.config["DataSource"]["S3DataSource"]["HubAccessConfig"] = hub_access_config + + def add_model_access_config(self, model_access_config=None): + """Add Model Access Config to the channel's configuration. + + Args: + model_access_config (dict): Whether model terms of use have been accepted. + """ + if model_access_config is not None: + self.config["DataSource"]["S3DataSource"]["ModelAccessConfig"] = model_access_config class ShuffleConfig(object): diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index 210dd426c5..1ad7e3b981 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -65,6 +65,7 @@ def stop(self): @staticmethod def _load_config(inputs, estimator, expand_role=True, validate_uri=True): """Placeholder docstring""" + model_access_config, hub_access_config = _Job._get_access_configs(estimator) input_config = _Job._format_inputs_to_input_config(inputs, validate_uri) role = ( estimator.sagemaker_session.expand_role(estimator.role) @@ -95,19 +96,23 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True): validate_uri, content_type="application/x-sagemaker-model", input_mode="File", + model_access_config=model_access_config, + hub_access_config=hub_access_config, ) if model_channel: input_config = [] if input_config is None else input_config input_config.append(model_channel) - if estimator.enable_network_isolation(): - code_channel = _Job._prepare_channel( - input_config, estimator.code_uri, estimator.code_channel_name, validate_uri - ) + code_channel = _Job._prepare_channel( + input_config, + estimator.code_uri, + estimator.code_channel_name, + validate_uri, + ) - if code_channel: - input_config = [] if input_config is None else input_config - input_config.append(code_channel) + if code_channel: + input_config = [] if input_config is None else input_config + input_config.append(code_channel) return { "input_config": input_config, @@ -118,6 +123,23 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True): "vpc_config": vpc_config, } + @staticmethod + def _get_access_configs(estimator): + """Return access configs from estimator object. + + JumpStartEstimator uses access configs which need to be added to the model channel, + so they are passed down to the job level. + + Args: + estimator (EstimatorBase): estimator object with access config field if applicable + """ + model_access_config, hub_access_config = None, None + if hasattr(estimator, "model_access_config"): + model_access_config = estimator.model_access_config + if hasattr(estimator, "hub_access_config"): + hub_access_config = estimator.hub_access_config + return model_access_config, hub_access_config + @staticmethod def _format_inputs_to_input_config(inputs, validate_uri=True): """Placeholder docstring""" @@ -173,6 +195,8 @@ def _format_string_uri_input( input_mode=None, compression=None, target_attribute_name=None, + model_access_config=None, + hub_access_config=None, ): """Placeholder docstring""" s3_input_result = TrainingInput( @@ -181,6 +205,8 @@ def _format_string_uri_input( input_mode=input_mode, compression=compression, target_attribute_name=target_attribute_name, + model_access_config=model_access_config, + hub_access_config=hub_access_config, ) if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"): return s3_input_result @@ -193,7 +219,11 @@ def _format_string_uri_input( ) if isinstance(uri_input, str): return s3_input_result - if isinstance(uri_input, (TrainingInput, file_input, FileSystemInput)): + if isinstance(uri_input, (file_input, FileSystemInput)): + return uri_input + if isinstance(uri_input, TrainingInput): + uri_input.add_hub_access_config(hub_access_config=hub_access_config) + uri_input.add_model_access_config(model_access_config=model_access_config) return uri_input if is_pipeline_variable(uri_input): return s3_input_result @@ -211,6 +241,8 @@ def _prepare_channel( validate_uri=True, content_type=None, input_mode=None, + model_access_config=None, + hub_access_config=None, ): """Placeholder docstring""" if not channel_uri: @@ -226,7 +258,12 @@ def _prepare_channel( raise ValueError("Duplicate channel {} not allowed.".format(channel_name)) channel_input = _Job._format_string_uri_input( - channel_uri, validate_uri, content_type, input_mode + channel_uri, + validate_uri, + content_type, + input_mode, + model_access_config=model_access_config, + hub_access_config=hub_access_config, ) channel = _Job._convert_input_to_channel(channel_name, channel_input) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 20a2d16c15..2ed2deb803 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -288,6 +288,7 @@ def get_model_specs( ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) + # Users only input model id, not contentType, so first try to describe with ModelReference, then with Model if hub_arn: try: hub_model_arn = construct_hub_model_reference_arn_from_inputs( @@ -308,11 +309,22 @@ def get_model_specs( hub_model_arn = construct_hub_model_arn_from_inputs( hub_arn=hub_arn, model_name=model_id, version=version ) - model_specs = JumpStartModelsAccessor._cache.get_hub_model( - hub_model_arn=hub_model_arn - ) - model_specs.set_hub_content_type(HubContentType.MODEL) - return model_specs + + # Failed to describe ModelReference, try with Model + try: + model_specs = JumpStartModelsAccessor._cache.get_hub_model( + hub_model_arn=hub_model_arn + ) + model_specs.set_hub_content_type(HubContentType.MODEL) + + return model_specs + except Exception as ex: + # Failed with both, throw a custom error message + raise RuntimeError( + f"Cannot get details for {model_id} in Hub {hub_arn}. \ + {model_id} does not exist as a Model or ModelReference: \n" + + str(ex) + ) return JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, version_str=version, model_type=model_type diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 90ee7dea8d..c1ad9710f1 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -29,6 +29,7 @@ get_region_fallback, verify_model_region_and_return_specs, ) +from sagemaker.s3_utils import is_s3_url from sagemaker.session import Session from sagemaker.jumpstart.types import JumpStartModelSpecs @@ -74,7 +75,7 @@ def _retrieve_hosting_artifact_key(model_specs: JumpStartModelSpecs, instance_ty def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str: """Returns instance specific training artifact key or default one as fallback.""" instance_specific_training_artifact_key: Optional[str] = ( - model_specs.training_instance_type_variants.get_instance_specific_artifact_key( + model_specs.training_instance_type_variants.get_instance_specific_training_artifact_key( instance_type=instance_type ) if instance_type @@ -185,8 +186,8 @@ def _retrieve_model_uri( os.environ.get(ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE) or default_jumpstart_bucket ) - - model_s3_uri = f"s3://{bucket}/{model_artifact_key}" + if not is_s3_url(model_artifact_key): + model_s3_uri = f"s3://{bucket}/{model_artifact_key}" return model_s3_uri diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 8ac813a6c5..f862d4702a 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -150,7 +150,8 @@ def __init__( if s3_client_config else boto3.client("s3", region_name=self._region) ) - self._sagemaker_session = sagemaker_session + # Fallback in case a caller overrides sagemaker_session to None + self._sagemaker_session = sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION def set_region(self, region: str) -> None: """Set region for cache. Clears cache after new region is set.""" @@ -262,7 +263,7 @@ def _model_id_retrieval_function( return JumpStartVersionedModelId(model_id, sm_compatible_model_version) versions_incompatible_with_sagemaker = [ - Version(header.version) + header.version for header in manifest.values() # type: ignore if header.model_id == model_id ] @@ -540,9 +541,7 @@ def _select_version( """ if version_str == "*": - if len(available_versions) == 0: - return None - return str(max(available_versions)) + return utils.get_latest_version(available_versions) if model_type == JumpStartModelType.PROPRIETARY: if "*" in version_str: diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index f3f7ecad1b..530e7ad16f 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -142,6 +142,11 @@ content_bucket="jumpstart-cache-prod-ap-southeast-5", gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-5", ), + JumpStartLaunchedRegionInfo( + region_name="ap-southeast-7", + content_bucket="jumpstart-cache-prod-ap-southeast-7", + gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-7", + ), JumpStartLaunchedRegionInfo( region_name="eu-west-2", content_bucket="jumpstart-cache-prod-eu-west-2", @@ -198,6 +203,11 @@ content_bucket="jumpstart-cache-prod-il-central-1", gated_content_bucket="jumpstart-private-cache-prod-il-central-1", ), + JumpStartLaunchedRegionInfo( + region_name="mx-central-1", + content_bucket="jumpstart-cache-prod-mx-central-1", + gated_content_bucket="jumpstart-private-cache-prod-mx-central-1", + ), JumpStartLaunchedRegionInfo( region_name="us-gov-east-1", content_bucket="jumpstart-cache-prod-us-gov-east-1", diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index a41c9ed952..af2fb5bc54 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -14,7 +14,7 @@ from __future__ import absolute_import -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union from sagemaker import session from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer @@ -41,6 +41,9 @@ validate_model_id_and_get_type, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, + remove_env_var_from_estimator_kwargs_if_accept_eula_present, + get_model_access_config, + get_hub_access_config, ) from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model_monitor.data_capture_config import DataCaptureConfig @@ -350,8 +353,8 @@ def __init__( source_dir (Optional[Union[str, PipelineVariable]]): The absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file. If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory is preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + is preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. (Default: None). @@ -619,6 +622,10 @@ def _validate_model_id_and_get_type_hook(): self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation self.config_name = estimator_init_kwargs.config_name self.init_kwargs = estimator_init_kwargs.to_kwargs_dict(False) + # Access configs initialized to None, would be given a value when .fit() is called + # if applicable + self.model_access_config = None + self.hub_access_config = None super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict()) @@ -629,6 +636,7 @@ def fit( logs: Optional[str] = None, job_name: Optional[str] = None, experiment_config: Optional[Dict[str, str]] = None, + accept_eula: Optional[bool] = None, ) -> None: """Start training job by calling base ``Estimator`` class ``fit`` method. @@ -679,8 +687,16 @@ def fit( is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`. However, the value of `TrialComponentDisplayName` is honored for display in Studio. (Default: None). + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). """ - + self.model_access_config = get_model_access_config(accept_eula) + self.hub_access_config = get_hub_access_config( + hub_content_arn=self.init_kwargs.get("model_reference_arn", None) + ) estimator_fit_kwargs = get_fit_kwargs( model_id=self.model_id, model_version=self.model_version, @@ -695,7 +711,9 @@ def fit( tolerate_deprecated_model=self.tolerate_deprecated_model, sagemaker_session=self.sagemaker_session, config_name=self.config_name, + hub_access_config=self.hub_access_config, ) + remove_env_var_from_estimator_kwargs_if_accept_eula_present(self.init_kwargs, accept_eula) return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict()) @@ -817,7 +835,7 @@ def deploy( explainer_config: Optional[ExplainerConfig] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, model_name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, @@ -918,7 +936,7 @@ def deploy( It can be null if this is being used to create a Model to pass to a ``PipelineModel`` which has its own Role field. (Default: None). - predictor_cls (Optional[callable[string, sagemaker.session.Session]]): A + predictor_cls (Optional[Callable[[string, sagemaker.session.Session], Any]]): A function to call to create a predictor (Default: None). If not None, ``deploy`` will return the result of invoking this function on the created endpoint name. (Default: None). @@ -947,8 +965,8 @@ def deploy( source_dir (Optional[str]): The absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (Default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory is preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. Structure within this directory is + preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. If the directory points to S3, no code is uploaded and the S3 location is used instead. (Default: None). diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index e4020a39bd..17ad7a76f5 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -14,7 +14,7 @@ from __future__ import absolute_import -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union from sagemaker import ( environment_variables, hyperparameters as hyperparameters_utils, @@ -71,7 +71,6 @@ from sagemaker.jumpstart.utils import ( add_hub_content_arn_tags, add_jumpstart_model_info_tags, - get_eula_message, get_default_jumpstart_session_with_user_agent_suffix, get_top_ranked_config_name, update_dict_if_key_not_present, @@ -265,6 +264,7 @@ def get_fit_kwargs( tolerate_deprecated_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, config_name: Optional[str] = None, + hub_access_config: Optional[Dict] = None, ) -> JumpStartEstimatorFitKwargs: """Returns kwargs required call `fit` on `sagemaker.estimator.Estimator` object.""" @@ -301,10 +301,32 @@ def get_fit_kwargs( estimator_fit_kwargs = _add_region_to_kwargs(estimator_fit_kwargs) estimator_fit_kwargs = _add_training_job_name_to_kwargs(estimator_fit_kwargs) estimator_fit_kwargs = _add_fit_extra_kwargs(estimator_fit_kwargs) + estimator_fit_kwargs = _add_hub_access_config_to_kwargs_inputs( + estimator_fit_kwargs, hub_access_config + ) return estimator_fit_kwargs +def _add_hub_access_config_to_kwargs_inputs( + kwargs: JumpStartEstimatorFitKwargs, hub_access_config=None +): + """Adds HubAccessConfig to kwargs inputs""" + + if isinstance(kwargs.inputs, str): + kwargs.inputs = TrainingInput(s3_data=kwargs.inputs, hub_access_config=hub_access_config) + elif isinstance(kwargs.inputs, TrainingInput): + kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config) + elif isinstance(kwargs.inputs, dict): + for k, v in kwargs.inputs.items(): + if isinstance(v, str): + kwargs.inputs[k] = TrainingInput(s3_data=v, hub_access_config=hub_access_config) + elif isinstance(kwargs.inputs, TrainingInput): + kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config) + + return kwargs + + def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, @@ -330,7 +352,7 @@ def get_deploy_kwargs( explainer_config: Optional[ExplainerConfig] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, sagemaker_session: Optional[Session] = None, @@ -668,18 +690,6 @@ def _add_env_to_kwargs( value, ) - environment = getattr(kwargs, "environment", {}) or {} - if ( - environment.get(SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY) - and str(environment.get("accept_eula", "")).lower() != "true" - ): - model_specs = kwargs.specs - if model_specs.is_gated_model(): - raise ValueError( - "Need to define ‘accept_eula'='true' within Environment. " - f"{get_eula_message(model_specs, kwargs.region)}" - ) - return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 328e1e8227..4245c5ac91 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -15,7 +15,7 @@ import json -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from sagemaker_core.shapes import ModelAccessConfig from sagemaker import environment_variables, image_uris, instance_types, model_uris, script_uris from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig @@ -855,7 +855,7 @@ def get_init_kwargs( image_uri: Optional[Union[str, PipelineVariable]] = None, model_data: Optional[Union[str, PipelineVariable, dict]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index bc42eebea0..402b2ce534 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -272,18 +272,21 @@ def delete_model_reference(self, model_name: str) -> None: def describe_model( self, model_name: str, hub_name: Optional[str] = None, model_version: Optional[str] = None ) -> DescribeHubContentResponse: - """Describe model in the SageMaker Hub.""" + """Describe Model or ModelReference in a Hub.""" + hub_name = hub_name or self.hub_name + + # Users only input model id, not contentType, so first try to describe with ModelReference, then with Model try: model_version = get_hub_model_version( hub_model_name=model_name, hub_model_type=HubContentType.MODEL_REFERENCE.value, - hub_name=self.hub_name if not hub_name else hub_name, + hub_name=hub_name, sagemaker_session=self._sagemaker_session, hub_model_version=model_version, ) hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( - hub_name=self.hub_name if not hub_name else hub_name, + hub_name=hub_name, hub_content_name=model_name, hub_content_version=model_version, hub_content_type=HubContentType.MODEL_REFERENCE.value, @@ -294,19 +297,32 @@ def describe_model( "Received exeption while calling APIs for ContentType ModelReference, retrying with ContentType Model: " + str(ex) ) - model_version = get_hub_model_version( - hub_model_name=model_name, - hub_model_type=HubContentType.MODEL.value, - hub_name=self.hub_name if not hub_name else hub_name, - sagemaker_session=self._sagemaker_session, - hub_model_version=model_version, - ) - hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( - hub_name=self.hub_name if not hub_name else hub_name, - hub_content_name=model_name, - hub_content_version=model_version, - hub_content_type=HubContentType.MODEL.value, - ) + # Failed to describe ModelReference, try with Model + try: + model_version = get_hub_model_version( + hub_model_name=model_name, + hub_model_type=HubContentType.MODEL.value, + hub_name=hub_name, + sagemaker_session=self._sagemaker_session, + hub_model_version=model_version, + ) + + hub_content_description: Dict[str, Any] = ( + self._sagemaker_session.describe_hub_content( + hub_name=hub_name, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL.value, + ) + ) + + except Exception as ex: + # Failed with both, throw a custom error message + raise RuntimeError( + f"Cannot get details for {model_name} in Hub {hub_name}. \ + {model_name} does not exist as a Model or ModelReference in {hub_name}: \n" + + str(ex) + ) return DescribeHubContentResponse(hub_content_description) diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index edc3c08fa7..1bbc6198a2 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -78,6 +78,9 @@ def construct_hub_arn_from_name( account_id: Optional[str] = None, ) -> str: """Constructs a Hub arn from the Hub name using default Session values.""" + if session is None: + # session is overridden to none by some callers + session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION account_id = account_id or session.account_id() region = region or session.boto_region_name @@ -211,6 +214,9 @@ def get_hub_model_version( ClientError: If the specified model is not found in the hub. KeyError: If the specified model version is not found. """ + if sagemaker_session is None: + # sagemaker_session is overridden to none by some callers + sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION try: hub_content_summaries = sagemaker_session.list_hub_content_versions( diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index b0b54db557..7dec3d78f9 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import -from typing import Dict, List, Optional, Any, Union +from typing import Callable, Dict, List, Optional, Any, Union import pandas as pd from botocore.exceptions import ClientError @@ -95,7 +95,7 @@ def __init__( image_uri: Optional[Union[str, PipelineVariable]] = None, model_data: Optional[Union[str, PipelineVariable, dict]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, @@ -149,7 +149,7 @@ def __init__( It can be null if this is being used to create a Model to pass to a ``PipelineModel`` which has its own Role field. (Default: None). - predictor_cls (Optional[callable[string, sagemaker.session.Session]]): A + predictor_cls (Optional[Callable[[string, sagemaker.session.Session], Any]]): A function to call to create a predictor (Default: None). If not None, ``deploy`` will return the result of invoking this function on the created endpoint name. (Default: None). @@ -178,8 +178,8 @@ def __init__( source_dir (Optional[str]): The absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (Default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory is preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. Structure within this directory is + preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. If the directory points to S3, no code is uploaded and the S3 location is used instead. (Default: None). diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index f59e2eddf4..349396205e 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -16,7 +16,7 @@ import re from copy import deepcopy from enum import Enum -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, List, Optional, Set, Union from sagemaker_core.shapes import ModelAccessConfig as CoreModelAccessConfig from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard from sagemaker.utils import ( @@ -619,6 +619,19 @@ def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str instance_type=instance_type, property_name="artifact_key" ) + def get_instance_specific_training_artifact_key(self, instance_type: str) -> Optional[str]: + """Returns instance specific training artifact key. + + Returns None if a model, instance type tuple does not have specific + training artifact key. + """ + + return self._get_instance_specific_property( + instance_type=instance_type, property_name="training_artifact_uri" + ) or self._get_instance_specific_property( + instance_type=instance_type, property_name="training_artifact_key" + ) + def get_instance_specific_resource_requirements(self, instance_type: str) -> Optional[str]: """Returns instance specific resource requirements. @@ -1363,9 +1376,10 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.deploy_kwargs = deepcopy(json_obj.get("deploy_kwargs", {})) self.predictor_specs: Optional[JumpStartPredictorSpecs] = ( JumpStartPredictorSpecs( - json_obj["predictor_specs"], is_hub_content=self._is_hub_content + json_obj.get("predictor_specs"), + is_hub_content=self._is_hub_content, ) - if "predictor_specs" in json_obj + if json_obj.get("predictor_specs") else None ) self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = ( @@ -1501,6 +1515,9 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields): "incremental_training_supported", ] + # Map of HubContent fields that map to custom names in MetadataBaseFields + CUSTOM_FIELD_MAP = {"sage_maker_sdk_predictor_specifications": "predictor_specs"} + __slots__ = slots + JumpStartMetadataBaseFields.__slots__ def __init__( @@ -1532,6 +1549,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if field in self.__slots__: setattr(self, field, json_obj[field]) + # Handle custom fields + for custom_field, field in self.CUSTOM_FIELD_MAP.items(): + if custom_field in json_obj: + setattr(self, field, json_obj.get(custom_field)) + class JumpStartMetadataConfig(JumpStartDataHolderType): """Data class of JumpStart metadata config.""" @@ -2150,7 +2172,7 @@ def __init__( image_uri: Optional[Union[str, Any]] = None, model_data: Optional[Union[str, Any, dict]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, Any]]] = None, name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, Any]]]] = None, @@ -2698,7 +2720,7 @@ def __init__( explainer_config: Optional[Any] = None, image_uri: Optional[Union[str, Any]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, Any]]] = None, model_name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, Any]]]] = None, diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 46e5f8a847..bd81226727 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -21,7 +21,7 @@ from urllib.parse import urlparse import boto3 from botocore.exceptions import ClientError -from packaging.version import Version +from packaging.version import Version, InvalidVersion import botocore from sagemaker_core.shapes import ModelAccessConfig import sagemaker @@ -1630,3 +1630,52 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str: return get_jumpstart_gated_content_bucket(region=region) return get_jumpstart_content_bucket(region=region) return neo_bucket + + +def remove_env_var_from_estimator_kwargs_if_accept_eula_present( + init_kwargs: dict, accept_eula: Optional[bool] +): + """Remove env vars if access configs are used + + Args: + init_kwargs (dict): Dictionary of kwargs when Estimator is instantiated. + accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit(). + """ + if accept_eula is not None and init_kwargs["environment"]: + del init_kwargs["environment"][constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY] + + +def get_hub_access_config(hub_content_arn: Optional[str]): + """Get hub access config + + Args: + hub_content_arn (Optional[bool]): Arn of the model reference hub content + """ + if hub_content_arn is not None: + hub_access_config = {"HubContentArn": hub_content_arn} + else: + hub_access_config = None + + return hub_access_config + + +def get_model_access_config(accept_eula: Optional[bool]): + """Get access configs + + Args: + accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit(). + """ + if accept_eula is not None: + model_access_config = {"AcceptEula": accept_eula} + else: + model_access_config = None + + return model_access_config + + +def get_latest_version(versions: List[str]) -> Optional[str]: + """Returns the latest version using sem-ver when possible.""" + try: + return None if not versions else max(versions, key=Version) + except InvalidVersion: + return max(versions) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 863bbf376c..e5ea1ea314 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -20,7 +20,7 @@ import os import re import copy -from typing import List, Dict, Optional, Union, Any +from typing import Callable, List, Dict, Optional, Union, Any import sagemaker from sagemaker import ( @@ -154,7 +154,7 @@ def __init__( image_uri: Optional[Union[str, PipelineVariable]] = None, model_data: Optional[Union[str, PipelineVariable, dict]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, @@ -186,7 +186,7 @@ def __init__( It can be null if this is being used to create a Model to pass to a ``PipelineModel`` which has its own Role field. (default: None) - predictor_cls (callable[string, sagemaker.session.Session]): A + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor (default: None). If not None, ``deploy`` will return the result of invoking this function on the created endpoint name. @@ -215,8 +215,8 @@ def __init__( source_dir (str): The absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory is preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + is preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. If the directory points to S3, no code is uploaded and the S3 location is used instead. @@ -745,6 +745,8 @@ def is_repack(self) -> bool: Returns: bool: if the source need to be repacked or not """ + if self.source_dir is None or self.entry_point is None: + return False return self.source_dir and self.entry_point and not self.git_config def _upload_code(self, key_prefix: str, repack: bool = False) -> None: @@ -1492,6 +1494,9 @@ def deploy( } model_reference_arn (Optional [str]): Hub Content Arn of a Model Reference type content (default: None). + inference_ami_version (Optional [str]): Specifies an option from a collection of preconfigured + Amazon Machine Image (AMI) images. For a full list of options, see: + https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ProductionVariant.html Raises: ValueError: If arguments combination check failed in these circumstances: - If no role is specified or @@ -1501,7 +1506,7 @@ def deploy( inference config or - If inference recommendation id is specified along with incompatible parameters Returns: - callable[string, sagemaker.session.Session] or None: Invocation of + Callable[[string, sagemaker.session.Session], Any] or None: Invocation of ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` is not None. Otherwise, return None. """ @@ -1743,6 +1748,7 @@ def deploy( model_data_download_timeout=model_data_download_timeout, container_startup_health_check_timeout=container_startup_health_check_timeout, routing_config=routing_config, + inference_ami_version=inference_ami_version, ) if endpoint_name: self.endpoint_name = endpoint_name @@ -1959,7 +1965,7 @@ def __init__( role: Optional[str] = None, entry_point: Optional[str] = None, source_dir: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, name: Optional[str] = None, container_log_level: Union[int, PipelineVariable] = logging.INFO, @@ -1996,11 +2002,11 @@ def __init__( source_dir (str): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. If 'git_config' is provided, - 'source_dir' should be a relative location to a directory in the Git repo. - If the directory points to S3, no code will be uploaded and the S3 location - will be used instead. + point to a file with name ``sourcedir.tar.gz``. Structure within this + directory are preserved when training on Amazon SageMaker. If 'git_config' + is provided, 'source_dir' should be a relative location to a directory in the + Git repo. If the directory points to S3, no code will be uploaded and the S3 + location will be used instead. .. admonition:: Example @@ -2012,7 +2018,7 @@ def __init__( >>> |----- test.py You can assign entry_point='inference.py', source_dir='src'. - predictor_cls (callable[string, sagemaker.session.Session]): A + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor (default: None). If not None, ``deploy`` will return the result of invoking this function on the created endpoint name. @@ -2139,6 +2145,8 @@ def is_repack(self) -> bool: Returns: bool: if the source need to be repacked or not """ + if self.source_dir is None or self.entry_point is None: + return False return self.source_dir and self.entry_point and not (self.key_prefix or self.git_config) diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index 436377fea5..3bc29a1cf4 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -2413,7 +2413,12 @@ def _update_data_quality_monitoring_schedule( ) self.sagemaker_session.sagemaker_client.create_data_quality_job_definition(**request_dict) try: - self._update_monitoring_schedule(new_job_definition_name, schedule_cron_expression) + self._update_monitoring_schedule( + job_definition_name=new_job_definition_name, + schedule_cron_expression=schedule_cron_expression, + data_analysis_start_time=data_analysis_start_time, + data_analysis_end_time=data_analysis_end_time, + ) self.job_definition_name = new_job_definition_name if role is not None: self.role = role diff --git a/src/sagemaker/modules/configs.py b/src/sagemaker/modules/configs.py index ec0df519f5..458c596a36 100644 --- a/src/sagemaker/modules/configs.py +++ b/src/sagemaker/modules/configs.py @@ -22,7 +22,7 @@ from __future__ import absolute_import from typing import Optional, Union -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, model_validator, ConfigDict import sagemaker_core.shapes as shapes @@ -74,7 +74,13 @@ ] -class SourceCode(BaseModel): +class BaseConfig(BaseModel): + """BaseConfig""" + + model_config = ConfigDict(validate_assignment=True, extra="forbid") + + +class SourceCode(BaseConfig): """SourceCode. The SourceCode class allows the user to specify the source code location, dependencies, @@ -194,7 +200,7 @@ def _to_vpc_config(self) -> shapes.VpcConfig: return shapes.VpcConfig(**filtered_dict) -class InputData(BaseModel): +class InputData(BaseConfig): """InputData. This config allows the user to specify an input data source for the training job. diff --git a/src/sagemaker/modules/distributed.py b/src/sagemaker/modules/distributed.py index 6cdc136dcf..f248b9b77c 100644 --- a/src/sagemaker/modules/distributed.py +++ b/src/sagemaker/modules/distributed.py @@ -13,12 +13,16 @@ """Distributed module.""" from __future__ import absolute_import +import os + +from abc import ABC, abstractmethod from typing import Optional, Dict, Any, List -from pydantic import BaseModel, PrivateAttr from sagemaker.modules.utils import safe_serialize +from sagemaker.modules.constants import SM_DRIVERS_LOCAL_PATH +from sagemaker.modules.configs import BaseConfig -class SMP(BaseModel): +class SMP(BaseConfig): """SMP. This class is used for configuring the SageMaker Model Parallelism v2 parameters. @@ -72,16 +76,37 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]: return hyperparameters -class DistributedConfig(BaseModel): - """Base class for distributed training configurations.""" +class DistributedConfig(BaseConfig, ABC): + """Abstract base class for distributed training configurations. + + This class defines the interface that all distributed training configurations + must implement. It provides a standardized way to specify driver scripts and + their locations for distributed training jobs. + """ + + @property + @abstractmethod + def driver_dir(self) -> str: + """Directory containing the driver script. + + This property should return the path to the directory containing + the driver script, relative to the container's working directory. - _type: str = PrivateAttr() + Returns: + str: Path to directory containing the driver script + """ - def model_dump(self, *args, **kwargs): - """Dump the model to a dictionary.""" - result = super().model_dump(*args, **kwargs) - result["_type"] = self._type - return result + @property + @abstractmethod + def driver_script(self) -> str: + """Name of the driver script. + + This property should return the name of the Python script that implements + the distributed training driver logic. + + Returns: + str: Name of the driver script file + """ class Torchrun(DistributedConfig): @@ -98,11 +123,27 @@ class Torchrun(DistributedConfig): The SageMaker Model Parallelism v2 parameters. """ - _type: str = PrivateAttr(default="torchrun") - process_count_per_node: Optional[int] = None smp: Optional["SMP"] = None + @property + def driver_dir(self) -> str: + """Directory containing the driver script. + + Returns: + str: Path to directory containing the driver script + """ + return os.path.join(SM_DRIVERS_LOCAL_PATH, "distributed_drivers") + + @property + def driver_script(self) -> str: + """Name of the driver script. + + Returns: + str: Name of the driver script file + """ + return "torchrun_driver.py" + class MPI(DistributedConfig): """MPI. @@ -118,7 +159,23 @@ class MPI(DistributedConfig): The custom MPI options to use for the training job. """ - _type: str = PrivateAttr(default="mpi") - process_count_per_node: Optional[int] = None mpi_additional_options: Optional[List[str]] = None + + @property + def driver_dir(self) -> str: + """Directory containing the driver script. + + Returns: + str: Path to directory containing the driver script + """ + return os.path.join(SM_DRIVERS_LOCAL_PATH, "distributed_drivers") + + @property + def driver_script(self) -> str: + """Name of the driver script. + + Returns: + str: Name of the driver script + """ + return "mpi_driver.py" diff --git a/src/sagemaker/modules/templates.py b/src/sagemaker/modules/templates.py index fba60dda47..d888b7bcb9 100644 --- a/src/sagemaker/modules/templates.py +++ b/src/sagemaker/modules/templates.py @@ -21,17 +21,12 @@ EXECUTE_BASIC_SCRIPT_DRIVER = """ echo "Running Basic Script driver" -$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/basic_script_driver.py +$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/basic_script_driver.py """ -EXEUCTE_TORCHRUN_DRIVER = """ -echo "Running Torchrun driver" -$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/torchrun_driver.py -""" - -EXECUTE_MPI_DRIVER = """ -echo "Running MPI driver" -$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/mpi_driver.py +EXEUCTE_DISTRIBUTED_DRIVER = """ +echo "Running {driver_name} Driver" +$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/{driver_script} """ TRAIN_SCRIPT_TEMPLATE = """ diff --git a/src/sagemaker/modules/train/container_drivers/__init__.py b/src/sagemaker/modules/train/container_drivers/__init__.py index 18557a2eb5..864f3663b8 100644 --- a/src/sagemaker/modules/train/container_drivers/__init__.py +++ b/src/sagemaker/modules/train/container_drivers/__init__.py @@ -10,5 +10,5 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Sagemaker modules container_drivers directory.""" +"""Sagemaker modules container drivers directory.""" from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/common/__init__.py b/src/sagemaker/modules/train/container_drivers/common/__init__.py new file mode 100644 index 0000000000..aab88c6b97 --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/common/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Sagemaker modules container drivers - common directory.""" +from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/utils.py b/src/sagemaker/modules/train/container_drivers/common/utils.py similarity index 98% rename from src/sagemaker/modules/train/container_drivers/utils.py rename to src/sagemaker/modules/train/container_drivers/common/utils.py index e939a6e0b8..c07aa1359a 100644 --- a/src/sagemaker/modules/train/container_drivers/utils.py +++ b/src/sagemaker/modules/train/container_drivers/common/utils.py @@ -99,10 +99,10 @@ def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAME return hyperparameters_dict -def get_process_count(distributed_dict: Dict[str, Any]) -> int: +def get_process_count(process_count: Optional[int] = None) -> int: """Get the number of processes to run on each node in the training job.""" return ( - int(distributed_dict.get("process_count_per_node", 0)) + process_count or int(os.environ.get("SM_NUM_GPUS", 0)) or int(os.environ.get("SM_NUM_NEURONS", 0)) or 1 diff --git a/src/sagemaker/modules/train/container_drivers/distributed_drivers/__init__.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/__init__.py new file mode 100644 index 0000000000..a44e7e81a9 --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/distributed_drivers/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Sagemaker modules container drivers - drivers directory.""" +from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/basic_script_driver.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/basic_script_driver.py similarity index 88% rename from src/sagemaker/modules/train/container_drivers/basic_script_driver.py rename to src/sagemaker/modules/train/container_drivers/distributed_drivers/basic_script_driver.py index cb0278bc9f..0b086a8e4f 100644 --- a/src/sagemaker/modules/train/container_drivers/basic_script_driver.py +++ b/src/sagemaker/modules/train/container_drivers/distributed_drivers/basic_script_driver.py @@ -13,16 +13,19 @@ """This module is the entry point for the Basic Script Driver.""" from __future__ import absolute_import +import os import sys +import json import shlex +from pathlib import Path from typing import List -from utils import ( +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 logger, get_python_executable, - read_source_code_json, - read_hyperparameters_json, execute_commands, write_failure_file, hyperparameters_to_cli_args, @@ -31,11 +34,10 @@ def create_commands() -> List[str]: """Create the commands to execute.""" - source_code = read_source_code_json() - hyperparameters = read_hyperparameters_json() + entry_script = os.environ["SM_ENTRY_SCRIPT"] + hyperparameters = json.loads(os.environ["SM_HPS"]) python_executable = get_python_executable() - entry_script = source_code["entry_script"] args = hyperparameters_to_cli_args(hyperparameters) if entry_script.endswith(".py"): commands = [python_executable, entry_script] diff --git a/src/sagemaker/modules/train/container_drivers/mpi_driver.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_driver.py similarity index 83% rename from src/sagemaker/modules/train/container_drivers/mpi_driver.py rename to src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_driver.py index dceb748cc0..9946272617 100644 --- a/src/sagemaker/modules/train/container_drivers/mpi_driver.py +++ b/src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_driver.py @@ -16,18 +16,8 @@ import os import sys import json +from pathlib import Path -from utils import ( - logger, - read_source_code_json, - read_distributed_json, - read_hyperparameters_json, - hyperparameters_to_cli_args, - get_process_count, - execute_commands, - write_failure_file, - USER_CODE_PATH, -) from mpi_utils import ( start_sshd_daemon, bootstrap_master_node, @@ -38,6 +28,16 @@ ) +sys.path.insert(0, str(Path(__file__).parent.parent)) +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 + logger, + hyperparameters_to_cli_args, + get_process_count, + execute_commands, + write_failure_file, +) + + def main(): """Main function for the MPI driver script. @@ -58,9 +58,9 @@ def main(): 5. Exit """ - source_code = read_source_code_json() - distribution = read_distributed_json() - hyperparameters = read_hyperparameters_json() + entry_script = os.environ["SM_ENTRY_SCRIPT"] + distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) + hyperparameters = json.loads(os.environ["SM_HPS"]) sm_current_host = os.environ["SM_CURRENT_HOST"] sm_hosts = json.loads(os.environ["SM_HOSTS"]) @@ -77,7 +77,8 @@ def main(): host_list = json.loads(os.environ["SM_HOSTS"]) host_count = int(os.environ["SM_HOST_COUNT"]) - process_count = get_process_count(distribution) + process_count = int(distributed_config["process_count_per_node"] or 0) + process_count = get_process_count(process_count) if process_count > 1: host_list = ["{}:{}".format(host, process_count) for host in host_list] @@ -86,8 +87,8 @@ def main(): host_count=host_count, host_list=host_list, num_processes=process_count, - additional_options=distribution.get("mpi_additional_options", []), - entry_script_path=os.path.join(USER_CODE_PATH, source_code["entry_script"]), + additional_options=distributed_config["mpi_additional_options"] or [], + entry_script_path=entry_script, ) args = hyperparameters_to_cli_args(hyperparameters) diff --git a/src/sagemaker/modules/train/container_drivers/mpi_utils.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_utils.py similarity index 81% rename from src/sagemaker/modules/train/container_drivers/mpi_utils.py rename to src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_utils.py index c3c2b7effe..ec9e1fcef9 100644 --- a/src/sagemaker/modules/train/container_drivers/mpi_utils.py +++ b/src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_utils.py @@ -14,12 +14,23 @@ from __future__ import absolute_import import os -import time +import sys import subprocess +import time +from pathlib import Path from typing import List -from utils import logger, SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable +import paramiko + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 + SM_EFA_NCCL_INSTANCES, + SM_EFA_RDMA_INSTANCES, + get_python_executable, + logger, +) FINISHED_STATUS_FILE = "/tmp/done.algo-1" READY_FILE = "/tmp/ready.%s" @@ -75,19 +86,45 @@ def start_sshd_daemon(): logger.info("Started SSH daemon.") +class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy): + """Class to handle host key policy for SageMaker distributed training SSH connections. + + Example: + >>> client = paramiko.SSHClient() + >>> client.set_missing_host_key_policy(CustomHostKeyPolicy()) + >>> # Will succeed for SageMaker algorithm containers + >>> client.connect('algo-1234.internal') + >>> # Will raise SSHException for other unknown hosts + >>> client.connect('unknown-host') # raises SSHException + """ + + def missing_host_key(self, client, hostname, key): + """Accept host keys for algo-* hostnames, reject others. + + Args: + client: The SSHClient instance + hostname: The hostname attempting to connect + key: The host key + + Raises: + paramiko.SSHException: If hostname doesn't match algo-* pattern + """ + if hostname.startswith("algo-"): + client.get_host_keys().add(hostname, key.get_name(), key) + return + raise paramiko.SSHException(f"Unknown host key for {hostname}") + + def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool: """Check if the connection to the provided host and port is possible.""" try: - import paramiko - logger.debug("Testing connection to host %s", host) - client = paramiko.SSHClient() - client.load_system_host_keys() - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - client.connect(host, port=port) - client.close() - logger.info("Can connect to host %s", host) - return True + with paramiko.SSHClient() as client: + client.load_system_host_keys() + client.set_missing_host_key_policy(CustomHostKeyPolicy()) + client.connect(host, port=port) + logger.info("Can connect to host %s", host) + return True except Exception as e: # pylint: disable=W0703 logger.info("Cannot connect to host %s", host) logger.debug(f"Connection failed with exception: {e}") @@ -183,9 +220,9 @@ def validate_smddpmprun() -> bool: def write_env_vars_to_file(): """Write environment variables to /etc/environment file.""" - with open("/etc/environment", "a") as f: + with open("/etc/environment", "a", encoding="utf-8") as f: for name in os.environ: - f.write("{}={}\n".format(name, os.environ.get(name))) + f.write(f"{name}={os.environ.get(name)}\n") def get_mpirun_command( diff --git a/src/sagemaker/modules/train/container_drivers/torchrun_driver.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/torchrun_driver.py similarity index 87% rename from src/sagemaker/modules/train/container_drivers/torchrun_driver.py rename to src/sagemaker/modules/train/container_drivers/distributed_drivers/torchrun_driver.py index 666479ec84..7fcfabe05d 100644 --- a/src/sagemaker/modules/train/container_drivers/torchrun_driver.py +++ b/src/sagemaker/modules/train/container_drivers/distributed_drivers/torchrun_driver.py @@ -15,20 +15,20 @@ import os import sys +import json +from pathlib import Path from typing import List, Tuple -from utils import ( +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 logger, - read_source_code_json, - read_distributed_json, - read_hyperparameters_json, hyperparameters_to_cli_args, get_process_count, get_python_executable, execute_commands, write_failure_file, - USER_CODE_PATH, SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, ) @@ -65,11 +65,12 @@ def setup_env(): def create_commands(): """Create the Torch Distributed command to execute""" - source_code = read_source_code_json() - distribution = read_distributed_json() - hyperparameters = read_hyperparameters_json() + entry_script = os.environ["SM_ENTRY_SCRIPT"] + distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) + hyperparameters = json.loads(os.environ["SM_HPS"]) - process_count = get_process_count(distribution) + process_count = int(distributed_config["process_count_per_node"] or 0) + process_count = get_process_count(process_count) host_count = int(os.environ["SM_HOST_COUNT"]) torch_cmd = [] @@ -94,7 +95,7 @@ def create_commands(): ] ) - torch_cmd.extend([os.path.join(USER_CODE_PATH, source_code["entry_script"])]) + torch_cmd.extend([entry_script]) args = hyperparameters_to_cli_args(hyperparameters) torch_cmd += args diff --git a/src/sagemaker/modules/train/container_drivers/scripts/__init__.py b/src/sagemaker/modules/train/container_drivers/scripts/__init__.py index 1abbce4067..f04c5b17a0 100644 --- a/src/sagemaker/modules/train/container_drivers/scripts/__init__.py +++ b/src/sagemaker/modules/train/container_drivers/scripts/__init__.py @@ -10,5 +10,5 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Sagemaker modules scripts directory.""" +"""Sagemaker modules container drivers - scripts directory.""" from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/scripts/environment.py b/src/sagemaker/modules/train/container_drivers/scripts/environment.py index ea6abac425..897b1f8af4 100644 --- a/src/sagemaker/modules/train/container_drivers/scripts/environment.py +++ b/src/sagemaker/modules/train/container_drivers/scripts/environment.py @@ -19,12 +19,17 @@ import json import os import sys +from pathlib import Path import logging -parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -sys.path.insert(0, parent_dir) +sys.path.insert(0, str(Path(__file__).parent.parent)) -from utils import safe_serialize, safe_deserialize # noqa: E402 # pylint: disable=C0413 +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 + safe_serialize, + safe_deserialize, + read_distributed_json, + read_source_code_json, +) # Initialize logger SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20) @@ -42,6 +47,8 @@ SM_OUTPUT_DIR = "/opt/ml/output" SM_OUTPUT_FAILURE = "/opt/ml/output/failure" SM_OUTPUT_DATA_DIR = "/opt/ml/output/data" +SM_SOURCE_DIR_PATH = "/opt/ml/input/data/code" +SM_DISTRIBUTED_DRIVER_DIR_PATH = "/opt/ml/input/data/sm_drivers/distributed_drivers" SM_MASTER_ADDR = "algo-1" SM_MASTER_PORT = 7777 @@ -158,6 +165,17 @@ def set_env( "SM_MASTER_PORT": SM_MASTER_PORT, } + # SourceCode and DistributedConfig Environment Variables + source_code = read_source_code_json() + if source_code: + env_vars["SM_SOURCE_DIR"] = SM_SOURCE_DIR_PATH + env_vars["SM_ENTRY_SCRIPT"] = source_code.get("entry_script", "") + + distributed = read_distributed_json() + if distributed: + env_vars["SM_DISTRIBUTED_DRIVER_DIR"] = SM_DISTRIBUTED_DRIVER_DIR_PATH + env_vars["SM_DISTRIBUTED_CONFIG"] = distributed + # Data Channels channels = list(input_data_config.keys()) for channel in channels: diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 31decfaca9..aef6e3312b 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -18,8 +18,8 @@ import json import shutil from tempfile import TemporaryDirectory - from typing import Optional, List, Union, Dict, Any, ClassVar +import yaml from graphene.utils.str_converters import to_camel_case, to_snake_case @@ -70,7 +70,7 @@ ) from sagemaker.modules.local_core.local_container import _LocalContainer -from sagemaker.modules.distributed import Torchrun, MPI, DistributedConfig +from sagemaker.modules.distributed import Torchrun, DistributedConfig from sagemaker.modules.utils import ( _get_repo_name_from_image, _get_unique_name, @@ -94,8 +94,7 @@ from sagemaker.modules.templates import ( TRAIN_SCRIPT_TEMPLATE, EXECUTE_BASE_COMMANDS, - EXECUTE_MPI_DRIVER, - EXEUCTE_TORCHRUN_DRIVER, + EXEUCTE_DISTRIBUTED_DRIVER, EXECUTE_BASIC_SCRIPT_DRIVER, ) from sagemaker.telemetry.telemetry_logging import _telemetry_emitter @@ -153,7 +152,7 @@ class ModelTrainer(BaseModel): source_code (Optional[SourceCode]): The source code configuration. This is used to configure the source code for running the training job. - distributed (Optional[Union[MPI, Torchrun]]): + distributed (Optional[DistributedConfig]): The distributed runner for the training job. This is used to configure a distributed training job. If specifed, ``source_code`` must also be provided. @@ -195,8 +194,9 @@ class ModelTrainer(BaseModel): Defaults to "File". environment (Optional[Dict[str, str]]): The environment variables for the training job. - hyperparameters (Optional[Dict[str, Any]]): - The hyperparameters for the training job. + hyperparameters (Optional[Union[Dict[str, Any], str]): + The hyperparameters for the training job. Can be a dictionary of hyperparameters + or a path to hyperparameters json/yaml file. tags (Optional[List[Tag]]): An array of key-value pairs. You can use tags to categorize your AWS resources in different ways, for example, by purpose, owner, or environment. @@ -205,14 +205,16 @@ class ModelTrainer(BaseModel): "LOCAL_CONTAINER" mode. """ - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + model_config = ConfigDict( + arbitrary_types_allowed=True, validate_assignment=True, extra="forbid" + ) training_mode: Mode = Mode.SAGEMAKER_TRAINING_JOB sagemaker_session: Optional[Session] = None role: Optional[str] = None base_job_name: Optional[str] = None source_code: Optional[SourceCode] = None - distributed: Optional[Union[MPI, Torchrun]] = None + distributed: Optional[DistributedConfig] = None compute: Optional[Compute] = None networking: Optional[Networking] = None stopping_condition: Optional[StoppingCondition] = None @@ -224,7 +226,7 @@ class ModelTrainer(BaseModel): checkpoint_config: Optional[CheckpointConfig] = None training_input_mode: Optional[str] = "File" environment: Optional[Dict[str, str]] = {} - hyperparameters: Optional[Dict[str, Any]] = {} + hyperparameters: Optional[Union[Dict[str, Any], str]] = {} tags: Optional[List[Tag]] = None local_container_root: Optional[str] = os.getcwd() @@ -363,9 +365,10 @@ def _populate_intelligent_defaults_from_model_trainer_space(self): def __del__(self): """Destructor method to clean up the temporary directory.""" - # Clean up the temporary directory if it exists - if self._temp_recipe_train_dir is not None: - self._temp_recipe_train_dir.cleanup() + # Clean up the temporary directory if it exists and class was initialized + if hasattr(self, "__pydantic_fields_set__"): + if self._temp_recipe_train_dir is not None: + self._temp_recipe_train_dir.cleanup() def _validate_training_image_and_algorithm_name( self, training_image: Optional[str], algorithm_name: Optional[str] @@ -467,6 +470,29 @@ def model_post_init(self, __context: Any): f"StoppingCondition not provided. Using default:\n{self.stopping_condition}" ) + if self.hyperparameters and isinstance(self.hyperparameters, str): + if not os.path.exists(self.hyperparameters): + raise ValueError(f"Hyperparameters file not found: {self.hyperparameters}") + logger.info(f"Loading hyperparameters from file: {self.hyperparameters}") + with open(self.hyperparameters, "r") as f: + contents = f.read() + try: + self.hyperparameters = json.loads(contents) + logger.debug("Hyperparameters loaded as JSON") + except json.JSONDecodeError: + try: + logger.info(f"contents: {contents}") + self.hyperparameters = yaml.safe_load(contents) + if not isinstance(self.hyperparameters, dict): + raise ValueError("YAML contents must be a valid mapping") + logger.info(f"hyperparameters: {self.hyperparameters}") + logger.debug("Hyperparameters loaded as YAML") + except (yaml.YAMLError, ValueError): + raise ValueError( + f"Invalid hyperparameters file: {self.hyperparameters}. " + "Must be a valid JSON or YAML file." + ) + if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None: session = self.sagemaker_session base_job_name = self.base_job_name @@ -534,12 +560,17 @@ def train( container_arguments = None if self.source_code: if self.training_mode == Mode.LOCAL_CONTAINER: - drivers_dir = TemporaryDirectory( - prefix=os.path.join(self.local_container_root + "/") - ) + tmp_dir = TemporaryDirectory(prefix=os.path.join(self.local_container_root + "/")) else: - drivers_dir = TemporaryDirectory() - shutil.copytree(SM_DRIVERS_LOCAL_PATH, drivers_dir.name, dirs_exist_ok=True) + tmp_dir = TemporaryDirectory() + # Copy everything under container_drivers/ to a temporary directory + shutil.copytree(SM_DRIVERS_LOCAL_PATH, tmp_dir.name, dirs_exist_ok=True) + + # If distributed is provided, overwrite code under /drivers + if self.distributed: + distributed_driver_dir = self.distributed.driver_dir + driver_dir = os.path.join(tmp_dir.name, "distributed_drivers") + shutil.copytree(distributed_driver_dir, driver_dir, dirs_exist_ok=True) # If source code is provided, create a channel for the source code # The source code will be mounted at /opt/ml/input/data/code in the container @@ -552,7 +583,7 @@ def train( input_data_config.append(source_code_channel) self._prepare_train_script( - tmp_dir=drivers_dir, + tmp_dir=tmp_dir, source_code=self.source_code, distributed=self.distributed, ) @@ -561,13 +592,13 @@ def train( mp_parameters = self.distributed.smp._to_mp_hyperparameters() string_hyper_parameters.update(mp_parameters) - self._write_source_code_json(tmp_dir=drivers_dir, source_code=self.source_code) - self._write_distributed_json(tmp_dir=drivers_dir, distributed=self.distributed) + self._write_source_code_json(tmp_dir=tmp_dir, source_code=self.source_code) + self._write_distributed_json(tmp_dir=tmp_dir, distributed=self.distributed) # Create an input channel for drivers packaged by the sdk sm_drivers_channel = self.create_input_data_channel( channel_name=SM_DRIVERS, - data_source=drivers_dir.name, + data_source=tmp_dir.name, key_prefix=input_data_key_prefix, ) input_data_config.append(sm_drivers_channel) @@ -769,7 +800,7 @@ def _write_source_code_json(self, tmp_dir: TemporaryDirectory, source_code: Sour """Write the source code configuration to a JSON file.""" file_path = os.path.join(tmp_dir.name, SOURCE_CODE_JSON) with open(file_path, "w") as f: - dump = source_code.model_dump(exclude_none=True) if source_code else {} + dump = source_code.model_dump() if source_code else {} f.write(json.dumps(dump)) def _write_distributed_json( @@ -780,7 +811,7 @@ def _write_distributed_json( """Write the distributed runner configuration to a JSON file.""" file_path = os.path.join(tmp_dir.name, DISTRIBUTED_JSON) with open(file_path, "w") as f: - dump = distributed.model_dump(exclude_none=True) if distributed else {} + dump = distributed.model_dump() if distributed else {} f.write(json.dumps(dump)) def _prepare_train_script( @@ -792,14 +823,14 @@ def _prepare_train_script( """Prepare the training script to be executed in the training job container. Args: - source_code (SourceCodeConfig): The source code configuration. + source_code (SourceCode): The source code configuration. """ base_command = "" if source_code.command: if source_code.entry_script: logger.warning( - "Both 'command' and 'entry_script' are provided in the SourceCodeConfig. " + "Both 'command' and 'entry_script' are provided in the SourceCode. " + "Defaulting to 'command'." ) base_command = source_code.command.split() @@ -817,13 +848,10 @@ def _prepare_train_script( if base_command: execute_driver = EXECUTE_BASE_COMMANDS.format(base_command=base_command) elif distributed: - distribution_type = distributed._type - if distribution_type == "mpi": - execute_driver = EXECUTE_MPI_DRIVER - elif distribution_type == "torchrun": - execute_driver = EXEUCTE_TORCHRUN_DRIVER - else: - raise ValueError(f"Unsupported distribution type: {distribution_type}.") + execute_driver = EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name=distributed.__class__.__name__, + driver_script=distributed.driver_script, + ) elif source_code.entry_script and not source_code.command and not distributed: if not source_code.entry_script.endswith((".py", ".sh")): raise ValueError( @@ -831,6 +859,13 @@ def _prepare_train_script( + "Only .py and .sh scripts are supported." ) execute_driver = EXECUTE_BASIC_SCRIPT_DRIVER + else: + # This should never be reached, as the source_code should have been validated. + raise ValueError( + f"Unsupported SourceCode or DistributedConfig: {source_code}, {distributed}." + + "Please provide a valid configuration with atleast one of 'command'" + + " or entry_script'." + ) train_script = TRAIN_SCRIPT_TEMPLATE.format( working_dir=working_dir, diff --git a/src/sagemaker/modules/train/sm_recipes/utils.py b/src/sagemaker/modules/train/sm_recipes/utils.py index ff38bcbde8..549645cbe2 100644 --- a/src/sagemaker/modules/train/sm_recipes/utils.py +++ b/src/sagemaker/modules/train/sm_recipes/utils.py @@ -125,6 +125,27 @@ def _register_custom_resolvers(): OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers)) +def _get_trainining_recipe_gpu_model_name_and_script(model_type: str): + """Get the model base name and script for the training recipe.""" + + model_type_to_script = { + "llama_v3": ("llama", "llama_pretrain.py"), + "mistral": ("mistral", "mistral_pretrain.py"), + "mixtral": ("mixtral", "mixtral_pretrain.py"), + "deepseek": ("deepseek", "deepseek_pretrain.py"), + } + + for key in model_type_to_script: + if model_type.startswith(key): + model_type = key + break + + if model_type not in model_type_to_script: + raise ValueError(f"Model type {model_type} not supported") + + return model_type_to_script[model_type][0], model_type_to_script[model_type][1] + + def _configure_gpu_args( training_recipes_cfg: Dict[str, Any], region_name: str, @@ -140,24 +161,16 @@ def _configure_gpu_args( ) _run_clone_command_silent(adapter_repo, recipe_train_dir.name) - model_type_to_entry = { - "llama_v3": ("llama", "llama_pretrain.py"), - "mistral": ("mistral", "mistral_pretrain.py"), - "mixtral": ("mixtral", "mixtral_pretrain.py"), - } - if "model" not in recipe: raise ValueError("Supplied recipe does not contain required field model.") if "model_type" not in recipe["model"]: raise ValueError("Supplied recipe does not contain required field model_type.") model_type = recipe["model"]["model_type"] - if model_type not in model_type_to_entry: - raise ValueError(f"Model type {model_type} not supported") - source_code.source_dir = os.path.join( - recipe_train_dir.name, "examples", model_type_to_entry[model_type][0] - ) - source_code.entry_script = model_type_to_entry[model_type][1] + model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(model_type) + + source_code.source_dir = os.path.join(recipe_train_dir.name, "examples", model_base_name) + source_code.entry_script = script gpu_image_cfg = training_recipes_cfg.get("gpu_image") if isinstance(gpu_image_cfg, str): diff --git a/src/sagemaker/multidatamodel.py b/src/sagemaker/multidatamodel.py index 9ed348c927..43a3588e6f 100644 --- a/src/sagemaker/multidatamodel.py +++ b/src/sagemaker/multidatamodel.py @@ -223,7 +223,7 @@ def deploy( Amazon SageMaker Model Monitoring. Default: None. Returns: - callable[string, sagemaker.session.Session] or None: Invocation of + Optional[Callable[[string, sagemaker.session.Session], Any]]: Invocation of ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` is not None. Otherwise, return None. diff --git a/src/sagemaker/mxnet/estimator.py b/src/sagemaker/mxnet/estimator.py index 104b93e00a..5126a37a85 100644 --- a/src/sagemaker/mxnet/estimator.py +++ b/src/sagemaker/mxnet/estimator.py @@ -84,8 +84,8 @@ def __init__( source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + are preserved when training on Amazon SageMaker. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 0dcd71741d..fa0c691d2d 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Union, Optional, List, Dict +from typing import Callable, Union, Optional, List, Dict import packaging.version @@ -68,9 +68,9 @@ def __init__( manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. - serializer (callable): Optional. Default serializes input data to + serializer (Callable): Optional. Default serializes input data to json. Handles dicts, lists, and numpy arrays. - deserializer (callable): Optional. Default parses the response using + deserializer (Callable): Optional. Default parses the response using ``json.load(...)``. component_name (str): Optional. Name of the Amazon SageMaker inference component corresponding to the predictor. @@ -98,7 +98,7 @@ def __init__( framework_version: str = _LOWEST_MMS_VERSION, py_version: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, - predictor_cls: callable = MXNetPredictor, + predictor_cls: Optional[Callable] = MXNetPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs, ): @@ -127,7 +127,7 @@ def __init__( If ``framework_version`` or ``py_version`` are ``None``, then ``image_uri`` is required. If ``image_uri`` is also ``None``, then a ``ValueError`` will be raised. - predictor_cls (callable[str, sagemaker.session.Session]): A function + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 04fbc1cc93..b36cd4e917 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -13,10 +13,12 @@ """Placeholder docstring""" from __future__ import absolute_import -from typing import Optional, Dict, List, Union +from typing import Callable, Optional, Dict, List, Union import sagemaker from sagemaker import ModelMetrics, Model +from sagemaker import local +from sagemaker import session from sagemaker.config import ( ENDPOINT_CONFIG_KMS_KEY_ID_PATH, MODEL_VPC_CONFIG_PATH, @@ -54,7 +56,7 @@ def __init__( self, models: List[Model], role: str = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, sagemaker_session: Optional[Session] = None, @@ -75,7 +77,7 @@ def __init__( endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource. - predictor_cls (callable[string, sagemaker.session.Session]): A + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor (default: None). If not None, ``deploy`` will return the result of invoking this function on the created endpoint name. @@ -230,7 +232,7 @@ def deploy( https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests Returns: - callable[string, sagemaker.session.Session] or None: Invocation of + Optional[Callable[[string, sagemaker.session.Session], Any]]: Invocation of ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` is not None. Otherwise, return None. """ @@ -560,3 +562,16 @@ def delete_model(self): raise ValueError("The SageMaker model must be created before attempting to delete.") self.sagemaker_session.delete_model(self.name) + + def _init_sagemaker_session_if_does_not_exist(self, instance_type=None): + """Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already. + + The type of session object is determined by the instance type. + """ + if self.sagemaker_session: + return + + if instance_type in ("local", "local_gpu"): + self.sagemaker_session = local.LocalSession(sagemaker_config=self._sagemaker_config) + else: + self.sagemaker_session = session.Session(sagemaker_config=self._sagemaker_config) diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 36cb920dde..d8674f269d 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -18,51 +18,51 @@ """ from __future__ import absolute_import +import logging import os import pathlib -import logging +import re +from copy import copy from textwrap import dedent from typing import Dict, List, Optional, Union -from copy import copy -import re import attr - from six.moves.urllib.parse import urlparse from six.moves.urllib.request import url2pathname + from sagemaker import s3 +from sagemaker.apiutils._base_types import ApiObject from sagemaker.config import ( + PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + PROCESSING_JOB_ENVIRONMENT_PATH, + PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, PROCESSING_JOB_KMS_KEY_ID_PATH, + PROCESSING_JOB_ROLE_ARN_PATH, PROCESSING_JOB_SECURITY_GROUP_IDS_PATH, PROCESSING_JOB_SUBNETS_PATH, - PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH, PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH, - PROCESSING_JOB_ROLE_ARN_PATH, - PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, - PROCESSING_JOB_ENVIRONMENT_PATH, ) +from sagemaker.dataset_definition.inputs import DatasetDefinition, S3Input from sagemaker.job import _Job from sagemaker.local import LocalSession from sagemaker.network import NetworkConfig +from sagemaker.s3 import S3Uploader +from sagemaker.session import Session from sagemaker.utils import ( + Tags, base_name_from_image, + check_and_get_run_experiment_config, + format_tags, get_config_value, name_from_base, - check_and_get_run_experiment_config, - resolve_value_from_config, resolve_class_attribute_from_config, - Tags, - format_tags, + resolve_value_from_config, ) -from sagemaker.session import Session from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.functions import Join from sagemaker.workflow.pipeline_context import runnable_by_pipeline -from sagemaker.workflow.execution_variables import ExecutionVariables -from sagemaker.workflow.entities import PipelineVariable -from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition -from sagemaker.apiutils._base_types import ApiObject -from sagemaker.s3 import S3Uploader logger = logging.getLogger(__name__) @@ -1416,7 +1416,7 @@ class RunArgs(object): class FeatureStoreOutput(ApiObject): """Configuration for processing job outputs in Amazon SageMaker Feature Store.""" - feature_group_name = None + feature_group_name: Optional[str] = None class FrameworkProcessor(ScriptProcessor): @@ -1465,7 +1465,7 @@ def __init__( instance_type (str or PipelineVariable): The type of EC2 instance to use for processing, for example, 'ml.c4.xlarge'. py_version (str): Python version you want to use for executing your - model training code. One of 'py2' or 'py3'. Defaults to 'py3'. Value + model training code. Ex `py38, py39, py310, py311`. Value is ignored when ``image_uri`` is provided. image_uri (str or PipelineVariable): The URI of the Docker image to use for the processing jobs (default: None). diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 46c57581d1..d56c100546 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -95,6 +95,7 @@ def _get_training_recipe_gpu_script(code_dir, recipe, source_dir): "llama_v3": ("llama", "llama_pretrain.py"), "mistral": ("mistral", "mistral_pretrain.py"), "mixtral": ("mixtral", "mixtral_pretrain.py"), + "deepseek": ("deepseek", "deepseek_pretrain.py"), } if "model" not in recipe: @@ -102,6 +103,12 @@ def _get_training_recipe_gpu_script(code_dir, recipe, source_dir): if "model_type" not in recipe["model"]: raise ValueError("Supplied recipe does not contain required field model_type.") model_type = recipe["model"]["model_type"] + + for key in model_type_to_script: + if model_type.startswith(key): + model_type = key + break + if model_type not in model_type_to_script: raise ValueError(f"Model type {model_type} not supported") @@ -175,8 +182,8 @@ def __init__( unless ``image_uri`` is provided. source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry - point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved + point file (default: None). If ``source_dir`` is an S3 URI, it must point to a + file with name ``sourcedir.tar.gz``. Structure within this directory are preserved when training on Amazon SageMaker. Must be a local path when using training_recipe. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). The hyperparameters are made diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 329f9b83b5..958327ba08 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Optional, Union, List, Dict +from typing import Callable, Optional, Union, List, Dict import packaging.version @@ -99,7 +99,7 @@ def __init__( framework_version: str = "1.3", py_version: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, - predictor_cls: callable = PyTorchPredictor, + predictor_cls: Optional[Callable] = PyTorchPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs, ): @@ -128,7 +128,7 @@ def __init__( If ``framework_version`` or ``py_version`` are ``None``, then ``image_uri`` is required. If ``image_uri`` is also ``None``, then a ``ValueError`` will be raised. - predictor_cls (callable[str, sagemaker.session.Session]): A function + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. diff --git a/src/sagemaker/remote_function/client.py b/src/sagemaker/remote_function/client.py index 15051dc04a..76a8443fba 100644 --- a/src/sagemaker/remote_function/client.py +++ b/src/sagemaker/remote_function/client.py @@ -90,7 +90,8 @@ def remote( spark_config: SparkConfig = None, use_spot_instances=False, max_wait_time_in_seconds=None, - use_torchrun=False, + use_torchrun: bool = False, + use_mpirun: bool = False, nproc_per_node: Optional[int] = None, ): """Decorator for running the annotated function as a SageMaker training job. @@ -207,7 +208,8 @@ def remote( files are accepted and uploaded to S3. instance_count (int): The number of instances to use. Defaults to 1. - NOTE: Remote function does not support instance_count > 1 for non Spark jobs. + NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and + mpirun utilities instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown. @@ -284,6 +286,9 @@ def remote( use_torchrun (bool): Specifies whether to use torchrun for distributed training. Defaults to ``False``. + use_mpirun (bool): Specifies whether to use mpirun for distributed training. + Defaults to ``False``. + nproc_per_node (Optional int): Specifies the number of processes per node for distributed training. Defaults to ``None``. This is defined automatically configured on the instance type. @@ -320,6 +325,7 @@ def _remote(func): use_spot_instances=use_spot_instances, max_wait_time_in_seconds=max_wait_time_in_seconds, use_torchrun=use_torchrun, + use_mpirun=use_mpirun, nproc_per_node=nproc_per_node, ) @@ -327,12 +333,13 @@ def _remote(func): def wrapper(*args, **kwargs): if instance_count > 1 and not ( - (spark_config is not None and not use_torchrun) - or (spark_config is None and use_torchrun) + (spark_config is not None and not use_torchrun and not use_mpirun) + or (spark_config is None and use_torchrun and not use_mpirun) + or (spark_config is None and not use_torchrun and use_mpirun) ): raise ValueError( "Remote function do not support training on multi instances " - + "without spark_config or use_torchrun. " + + "without spark_config or use_torchrun or use_mpirun. " + "Please provide instance_count = 1" ) @@ -536,7 +543,8 @@ def __init__( spark_config: SparkConfig = None, use_spot_instances=False, max_wait_time_in_seconds=None, - use_torchrun=False, + use_torchrun: bool = False, + use_mpirun: bool = False, nproc_per_node: Optional[int] = None, ): """Constructor for RemoteExecutor @@ -650,7 +658,8 @@ def __init__( files are accepted and uploaded to S3. instance_count (int): The number of instances to use. Defaults to 1. - NOTE: Remote function does not support instance_count > 1 for non Spark jobs. + NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and + mpirun utilities instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown. @@ -730,6 +739,9 @@ def __init__( use_torchrun (bool): Specifies whether to use torchrun for distributed training. Defaults to ``False``. + use_mpirun (bool): Specifies whether to use mpirun for distributed training. + Defaults to ``False``. + nproc_per_node (Optional int): Specifies the number of processes per node for distributed training. Defaults to ``None``. This is defined automatically configured on the instance type. @@ -740,12 +752,13 @@ def __init__( raise ValueError("max_parallel_jobs must be greater than 0.") if instance_count > 1 and not ( - (spark_config is not None and not use_torchrun) - or (spark_config is None and use_torchrun) + (spark_config is not None and not use_torchrun and not use_mpirun) + or (spark_config is None and use_torchrun and not use_mpirun) + or (spark_config is None and not use_torchrun and use_mpirun) ): raise ValueError( "Remote function do not support training on multi instances " - + "without spark_config or use_torchrun. " + + "without spark_config or use_torchrun or use_mpirun. " + "Please provide instance_count = 1" ) @@ -778,6 +791,7 @@ def __init__( use_spot_instances=use_spot_instances, max_wait_time_in_seconds=max_wait_time_in_seconds, use_torchrun=use_torchrun, + use_mpirun=use_mpirun, nproc_per_node=nproc_per_node, ) diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py index 4e2e749bcb..52cb0ff04f 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -81,6 +81,7 @@ # runtime script names BOOTSTRAP_SCRIPT_NAME = "bootstrap_runtime_environment.py" +MPI_UTILS_SCRIPT_NAME = "mpi_utils_remote.py" ENTRYPOINT_SCRIPT_NAME = "job_driver.sh" PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh" RUNTIME_MANAGER_SCRIPT_NAME = "runtime_environment_manager.py" @@ -167,6 +168,99 @@ fi """ +ENTRYPOINT_MPIRUN_SCRIPT = f""" +#!/bin/bash + +# Entry point for bootstrapping runtime environment and invoking remote function with mpirun + +set -eu + +PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}} +export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs +printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n" +export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip +printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n" + +printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n" +cat /opt/ml/input/config/resourceconfig.json + +printf "INFO: Bootstraping runtime environment.\\n" +python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" +source /opt/ml/input/sm_training.env + +if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ] +then + if [ -f "remote_function_conda_env.txt" ] + then + cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt + fi + printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n" + cd {JOB_REMOTE_FUNCTION_WORKSPACE} +fi + +if [ -f "remote_function_conda_env.txt" ] +then + conda_env=$(cat remote_function_conda_env.txt) + + if which mamba >/dev/null; then + conda_exe="mamba" + else + conda_exe="conda" + fi + + if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + + printf "INFO: Invoking remote function with mpirun inside conda environment: $conda_env.\\n" + printf "INFO: $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + + python -m mpi4py -m sagemaker.remote_function.invoke_function \\n" + $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ + python -m mpi4py -m sagemaker.remote_function.invoke_function "$@" + + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 + else + printf "INFO: This is the instance $SM_CURRENT_HOST. mpirun command terminated\\n" + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + fi +else + if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + + printf "INFO: No conda env provided. Invoking remote function with mpirun\\n" + printf "INFO: mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ + python -m mpi4py -m sagemaker.remote_function.invoke_function \\n" + + mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ + python -m mpi4py -m sagemaker.remote_function.invoke_function "$@" + + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 + else + printf "INFO: This is the instance $SM_CURRENT_HOST.\\n" + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + fi +fi +""" + ENTRYPOINT_TORCHRUN_SCRIPT = f""" #!/bin/bash @@ -211,6 +305,7 @@ printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ -m sagemaker.remote_function.invoke_function \\n" + $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ -m sagemaker.remote_function.invoke_function "$@" @@ -218,6 +313,7 @@ printf "INFO: No conda env provided. Invoking remote function with torchrun\\n" printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function \\n" + torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function "$@" fi @@ -278,6 +374,7 @@ def __init__( use_spot_instances=False, max_wait_time_in_seconds=None, use_torchrun: bool = False, + use_mpirun: bool = False, nproc_per_node: Optional[int] = None, ): """Initialize a _JobSettings instance which configures the remote job. @@ -464,6 +561,9 @@ def __init__( use_torchrun (bool): Specifies whether to use torchrun for distributed training. Defaults to ``False``. + use_mpirun (bool): Specifies whether to use mpirun for distributed training. + Defaults to ``False``. + nproc_per_node (Optional int): Specifies the number of processes per node for distributed training. Defaults to ``None``. This is defined automatically configured on the instance type. @@ -626,6 +726,7 @@ def __init__( self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS) self.use_torchrun = use_torchrun + self.use_mpirun = use_mpirun self.nproc_per_node = nproc_per_node @staticmethod @@ -769,7 +870,7 @@ def compile( job_settings: _JobSettings, job_name: str, s3_base_uri: str, - func: callable, + func: Callable, func_args: tuple, func_kwargs: dict, run_info=None, @@ -874,6 +975,12 @@ def compile( ).to_string(), ] ) + if job_settings.use_torchrun: + container_args.extend(["--distribution", "torchrun"]) + elif job_settings.use_mpirun: + container_args.extend(["--distribution", "mpirun"]) + if job_settings.nproc_per_node is not None and int(job_settings.nproc_per_node) > 0: + container_args.extend(["--user_nproc_per_node", str(job_settings.nproc_per_node)]) if job_settings.s3_kms_key: container_args.extend(["--s3_kms_key", job_settings.s3_kms_key]) @@ -950,6 +1057,7 @@ def compile( request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key}) extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri) + extended_request = _extend_mpirun_to_request(extended_request, job_settings) extended_request = _extend_torchrun_to_request(extended_request, job_settings) return extended_request @@ -1031,7 +1139,7 @@ def _prepare_and_upload_runtime_scripts( s3_kms_key: str, sagemaker_session: Session, use_torchrun: bool = False, - nproc_per_node: Optional[int] = None, + use_mpirun: bool = False, ): """Copy runtime scripts to a folder and upload to S3. @@ -1050,6 +1158,8 @@ def _prepare_and_upload_runtime_scripts( use_torchrun (bool): Whether to use torchrun or not. + use_mpirun (bool): Whether to use mpirun or not. + nproc_per_node (Optional[int]): Number of processes per node """ @@ -1075,10 +1185,8 @@ def _prepare_and_upload_runtime_scripts( if use_torchrun: entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT - if nproc_per_node is not None and nproc_per_node > 0: - entry_point_script = entry_point_script.replace( - "$SM_NPROC_PER_NODE", str(nproc_per_node) - ) + if use_mpirun: + entry_point_script = ENTRYPOINT_MPIRUN_SCRIPT with open(entrypoint_script_path, "w", newline="\n") as file: file.writelines(entry_point_script) @@ -1086,12 +1194,16 @@ def _prepare_and_upload_runtime_scripts( bootstrap_script_path = os.path.join( os.path.dirname(__file__), "runtime_environment", BOOTSTRAP_SCRIPT_NAME ) + mpi_utils_path = os.path.join( + os.path.dirname(__file__), "runtime_environment", MPI_UTILS_SCRIPT_NAME + ) runtime_manager_script_path = os.path.join( os.path.dirname(__file__), "runtime_environment", RUNTIME_MANAGER_SCRIPT_NAME ) # copy runtime scripts to tmpdir shutil.copy2(bootstrap_script_path, bootstrap_scripts) + shutil.copy2(mpi_utils_path, bootstrap_scripts) shutil.copy2(runtime_manager_script_path, bootstrap_scripts) upload_path = S3Uploader.upload( @@ -1118,7 +1230,7 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str): s3_kms_key=job_settings.s3_kms_key, sagemaker_session=job_settings.sagemaker_session, use_torchrun=job_settings.use_torchrun, - nproc_per_node=job_settings.nproc_per_node, + use_mpirun=job_settings.use_mpirun, ) input_data_config = [ @@ -1459,6 +1571,35 @@ def _upload_serialized_spark_configuration( return config_file_s3_uri +def _extend_mpirun_to_request( + request_dict: Dict, + job_settings: _JobSettings, +) -> Dict: + """Extend the create training job request with mpirun configuration. + + Args: + request_dict (Dict): create training job request dict. + job_settings (_JobSettings): the job settings. + """ + use_mpirun = job_settings.use_mpirun + instance_count = job_settings.instance_count + + if not use_mpirun: + return request_dict + + if instance_count == 1: + return request_dict + + extended_request = request_dict.copy() + + for input_channel in extended_request["InputDataConfig"]: + s3_data_source = input_channel["DataSource"].get("S3DataSource", None) + if s3_data_source: + s3_data_source["S3DataDistributionType"] = "FullyReplicated" + + return extended_request + + def _extend_torchrun_to_request( request_dict: Dict, job_settings: _JobSettings, diff --git a/src/sagemaker/remote_function/runtime_environment/__init__.py b/src/sagemaker/remote_function/runtime_environment/__init__.py index e69de29bb2..18557a2eb5 100644 --- a/src/sagemaker/remote_function/runtime_environment/__init__.py +++ b/src/sagemaker/remote_function/runtime_environment/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Sagemaker modules container_drivers directory.""" +from __future__ import absolute_import diff --git a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py index 0b0823da77..da7c493ae5 100644 --- a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -22,7 +22,7 @@ import shutil import subprocess import sys -from typing import Dict, Any +from typing import Any, Dict if __package__ is None or __package__ == "": from runtime_environment_manager import ( @@ -271,6 +271,8 @@ def _parse_args(sys_args): parser.add_argument("--pipeline_execution_id", type=str) parser.add_argument("--dependency_settings", type=str) parser.add_argument("--func_step_s3_dir", type=str) + parser.add_argument("--distribution", type=str, default=None) + parser.add_argument("--user_nproc_per_node", type=str, default=None) args, _ = parser.parse_known_args(sys_args) return args @@ -401,6 +403,8 @@ def safe_serialize(data): def set_env( resource_config: Dict[str, Any], + distribution: str = None, + user_nproc_per_node: bool = None, output_file: str = ENV_OUTPUT_FILE, ): """Set environment variables for the training job container. @@ -442,12 +446,15 @@ def set_env( # Misc. env_vars["SM_RESOURCE_CONFIG"] = resource_config - if int(env_vars["SM_NUM_GPUS"]) > 0: - env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_GPUS"]) - elif int(env_vars["SM_NUM_NEURONS"]) > 0: - env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_NEURONS"]) + if user_nproc_per_node is not None and int(user_nproc_per_node) > 0: + env_vars["SM_NPROC_PER_NODE"] = int(user_nproc_per_node) else: - env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_CPUS"]) + if int(env_vars["SM_NUM_GPUS"]) > 0: + env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_GPUS"]) + elif int(env_vars["SM_NUM_NEURONS"]) > 0: + env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_NEURONS"]) + else: + env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_CPUS"]) # All Training Environment Variables env_vars["SM_TRAINING_ENV"] = { @@ -471,18 +478,45 @@ def set_env( "resource_config": env_vars["SM_RESOURCE_CONFIG"], } - instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"] - network_interface_name = env_vars.get("SM_NETWORK_INTERFACE_NAME", "eth0") + if distribution and distribution == "torchrun": + logger.info("Distribution: torchrun") + + instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"] + network_interface_name = env_vars.get("SM_NETWORK_INTERFACE_NAME", "eth0") + + if instance_type in SM_EFA_NCCL_INSTANCES: + # Enable EFA use + env_vars["FI_PROVIDER"] = "efa" + if instance_type in SM_EFA_RDMA_INSTANCES: + # Use EFA's RDMA functionality for one-sided and two-sided transfer + env_vars["FI_EFA_USE_DEVICE_RDMA"] = "1" + env_vars["RDMAV_FORK_SAFE"] = "1" + env_vars["NCCL_SOCKET_IFNAME"] = str(network_interface_name) + env_vars["NCCL_PROTO"] = "simple" + elif distribution and distribution == "mpirun": + logger.info("Distribution: mpirun") + + env_vars["MASTER_ADDR"] = env_vars["SM_MASTER_ADDR"] + env_vars["MASTER_PORT"] = str(env_vars["SM_MASTER_PORT"]) + + host_list = [ + "{}:{}".format(host, int(env_vars["SM_NPROC_PER_NODE"])) for host in sorted_hosts + ] + env_vars["SM_HOSTS_LIST"] = ",".join(host_list) + + instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"] + + if instance_type in SM_EFA_NCCL_INSTANCES: + env_vars["SM_FI_PROVIDER"] = "-x FI_PROVIDER=efa" + env_vars["SM_NCCL_PROTO"] = "-x NCCL_PROTO=simple" + else: + env_vars["SM_FI_PROVIDER"] = "" + env_vars["SM_NCCL_PROTO"] = "" - if instance_type in SM_EFA_NCCL_INSTANCES: - # Enable EFA use - env_vars["FI_PROVIDER"] = "efa" - if instance_type in SM_EFA_RDMA_INSTANCES: - # Use EFA's RDMA functionality for one-sided and two-sided transfer - env_vars["FI_EFA_USE_DEVICE_RDMA"] = "1" - env_vars["RDMAV_FORK_SAFE"] = "1" - env_vars["NCCL_SOCKET_IFNAME"] = str(network_interface_name) - env_vars["NCCL_PROTO"] = "simple" + if instance_type in SM_EFA_RDMA_INSTANCES: + env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = "-x FI_EFA_USE_DEVICE_RDMA=1" + else: + env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = "" with open(output_file, "w") as f: for key, value in env_vars.items(): @@ -499,12 +533,19 @@ def main(sys_args=None): try: args = _parse_args(sys_args) + + logger.info("Arguments:") + for arg in vars(args): + logger.info("%s=%s", arg, getattr(args, arg)) + client_python_version = args.client_python_version client_sagemaker_pysdk_version = args.client_sagemaker_pysdk_version job_conda_env = args.job_conda_env pipeline_execution_id = args.pipeline_execution_id dependency_settings = _DependencySettings.from_string(args.dependency_settings) func_step_workspace = args.func_step_s3_dir + distribution = args.distribution + user_nproc_per_node = args.user_nproc_per_node conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") @@ -539,7 +580,11 @@ def main(sys_args=None): logger.info("Found %s", RESOURCE_CONFIG) with open(RESOURCE_CONFIG, "r") as f: resource_config = json.load(f) - set_env(resource_config=resource_config) + set_env( + resource_config=resource_config, + distribution=distribution, + user_nproc_per_node=user_nproc_per_node, + ) except (json.JSONDecodeError, FileNotFoundError) as e: # Optionally, you might want to log this error logger.info("ERROR: Error processing %s: %s", RESOURCE_CONFIG, str(e)) diff --git a/src/sagemaker/remote_function/runtime_environment/mpi_utils_remote.py b/src/sagemaker/remote_function/runtime_environment/mpi_utils_remote.py new file mode 100644 index 0000000000..6f3897fb0b --- /dev/null +++ b/src/sagemaker/remote_function/runtime_environment/mpi_utils_remote.py @@ -0,0 +1,252 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""An utils function for runtime environment. This must be kept independent of SageMaker PySDK""" +from __future__ import absolute_import + +import argparse +import json +import os +import subprocess +import sys +import time +from typing import List + +import paramiko + +if __package__ is None or __package__ == "": + from runtime_environment_manager import ( + get_logger, + ) +else: + from sagemaker.remote_function.runtime_environment.runtime_environment_manager import ( + get_logger, + ) + +SUCCESS_EXIT_CODE = 0 +DEFAULT_FAILURE_CODE = 1 + +FINISHED_STATUS_FILE = "/tmp/done.algo-1" +READY_FILE = "/tmp/ready.%s" +DEFAULT_SSH_PORT = 22 + +FAILURE_REASON_PATH = "/opt/ml/output/failure" +FINISHED_STATUS_FILE = "/tmp/done.algo-1" + +logger = get_logger() + + +class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy): + """Class to handle host key policy for SageMaker distributed training SSH connections. + + Example: + >>> client = paramiko.SSHClient() + >>> client.set_missing_host_key_policy(CustomHostKeyPolicy()) + >>> # Will succeed for SageMaker algorithm containers + >>> client.connect('algo-1234.internal') + >>> # Will raise SSHException for other unknown hosts + >>> client.connect('unknown-host') # raises SSHException + """ + + def missing_host_key(self, client, hostname, key): + """Accept host keys for algo-* hostnames, reject others. + + Args: + client: The SSHClient instance + hostname: The hostname attempting to connect + key: The host key + Raises: + paramiko.SSHException: If hostname doesn't match algo-* pattern + """ + if hostname.startswith("algo-"): + client.get_host_keys().add(hostname, key.get_name(), key) + return + raise paramiko.SSHException(f"Unknown host key for {hostname}") + + +def _parse_args(sys_args): + """Parses CLI arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--job_ended", type=str, default="0") + args, _ = parser.parse_known_args(sys_args) + return args + + +def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool: + """Check if the connection to the provided host and port is possible.""" + try: + with paramiko.SSHClient() as client: + client.load_system_host_keys() + client.set_missing_host_key_policy(CustomHostKeyPolicy()) + client.connect(host, port=port) + logger.info("Can connect to host %s", host) + return True + except Exception as e: # pylint: disable=W0703 + logger.info("Cannot connect to host %s", host) + logger.debug("Connection failed with exception: %s", e) + return False + + +def _write_file_to_host(host: str, status_file: str) -> bool: + """Write the a file to the provided host.""" + try: + logger.info("Writing %s to %s", status_file, host) + subprocess.run( + ["ssh", host, "touch", f"{status_file}"], + capture_output=True, + text=True, + check=True, + ) + logger.info("Finished writing status file") + return True + except subprocess.CalledProcessError: + logger.info("Cannot connect to %s", host) + return False + + +def _write_failure_reason_file(failure_msg): + """Create a file 'failure' with failure reason written if bootstrap runtime env failed. + + See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html + Args: + failure_msg: The content of file to be written. + """ + if not os.path.exists(FAILURE_REASON_PATH): + with open(FAILURE_REASON_PATH, "w") as f: + f.write("RuntimeEnvironmentError: " + failure_msg) + + +def _wait_for_master(master_host: str, port: int = DEFAULT_SSH_PORT, timeout: int = 300): + """Worker nodes wait until they can connect to the master node.""" + start_time = time.time() + while True: + logger.info("Worker is attempting to connect to the master node %s...", master_host) + if _can_connect(master_host, port): + logger.info("Worker can connect to master node %s.", master_host) + break + if time.time() - start_time > timeout: + raise TimeoutError("Timed out waiting for master %s to be reachable." % master_host) + + time.sleep(5) # Wait for 5 seconds before trying again + + +def _wait_for_status_file(status_file: str): + """Wait for the status file to be created.""" + logger.info("Waiting for status file %s", status_file) + while not os.path.exists(status_file): + time.sleep(30) + logger.info("Found status file %s", status_file) + + +def _wait_for_workers(worker_hosts: List[str], port: int = DEFAULT_SSH_PORT, timeout: int = 300): + """Master node waits until it can connect to all worker nodes.""" + start_time = time.time() + if not worker_hosts: + logger.info("No worker nodes to connect to.") + return + + while True: + logger.info("Master is attempting to connect to all workers...") + all_workers_connected = all( + _can_connect(worker, port) and os.path.exists(READY_FILE % worker) + for worker in worker_hosts + ) + + if all_workers_connected: + logger.info("Master can connect to all worker nodes.") + break + if time.time() - start_time > timeout: + raise TimeoutError("Timed out waiting for workers to be reachable.") + + time.sleep(5) # Wait for 5 seconds before trying again + + +def bootstrap_master_node(worker_hosts: List[str]): + """Bootstrap the master node.""" + logger.info("Bootstrapping master node...") + _wait_for_workers(worker_hosts) + + +def bootstrap_worker_node( + master_host: str, current_host: str, status_file: str = FINISHED_STATUS_FILE +): + """Bootstrap the worker nodes.""" + logger.info("Bootstrapping worker node...") + _wait_for_master(master_host) + _write_file_to_host(master_host, READY_FILE % current_host) + _wait_for_status_file(status_file) + + +def start_sshd_daemon(): + """Start the SSH daemon on the current node.""" + sshd_executable = "/usr/sbin/sshd" + + if not os.path.exists(sshd_executable): + raise RuntimeError("SSH daemon not found.") + + # Start the sshd in daemon mode (-D) + subprocess.Popen([sshd_executable, "-D"]) + logger.info("Started SSH daemon.") + + +def write_status_file_to_workers(worker_hosts: List[str], status_file: str = FINISHED_STATUS_FILE): + """Write the status file to all worker nodes.""" + for worker in worker_hosts: + retry = 0 + while not _write_file_to_host(worker, status_file): + time.sleep(5) + retry += 1 + if retry > 5: + raise TimeoutError("Timed out waiting for %s to be reachable." % worker) + logger.info("Retrying to write status file to %s", worker) + + +def main(sys_args=None): + """Entry point for bootstrap script""" + try: + args = _parse_args(sys_args) + + job_ended = args.job_ended + + main_host = os.environ["SM_MASTER_ADDR"] + current_host = os.environ["SM_CURRENT_HOST"] + + if job_ended == "0": + logger.info("Job is running, bootstrapping nodes") + + start_sshd_daemon() + + if current_host != main_host: + bootstrap_worker_node(main_host, current_host) + else: + sorted_hosts = json.loads(os.environ["SM_HOSTS"]) + worker_hosts = [host for host in sorted_hosts if host != main_host] + + bootstrap_master_node(worker_hosts) + else: + logger.info("Job ended, writing status file to workers") + + if current_host == main_host: + sorted_hosts = json.loads(os.environ["SM_HOSTS"]) + worker_hosts = [host for host in sorted_hosts if host != main_host] + + write_status_file_to_workers(worker_hosts) + except Exception as e: # pylint: disable=broad-except + logger.exception("Error encountered while bootstrapping runtime environment: %s", e) + + _write_failure_reason_file(str(e)) + + sys.exit(DEFAULT_FAILURE_CODE) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/src/sagemaker/rl/estimator.py b/src/sagemaker/rl/estimator.py index e262604ac3..f1e1407633 100644 --- a/src/sagemaker/rl/estimator.py +++ b/src/sagemaker/rl/estimator.py @@ -120,8 +120,8 @@ def __init__( source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + are preserved when training on Amazon SageMaker. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on diff --git a/src/sagemaker/s3_utils.py b/src/sagemaker/s3_utils.py index e53cdbe02a..f59c8a299f 100644 --- a/src/sagemaker/s3_utils.py +++ b/src/sagemaker/s3_utils.py @@ -45,6 +45,19 @@ def parse_s3_url(url): return parsed_url.netloc, parsed_url.path.lstrip("/") +def is_s3_url(url): + """Returns True if url is an s3 url, False if not + + Args: + url (str): + + Returns: + bool: + """ + parsed_url = urlparse(url) + return parsed_url.scheme == "s3" + + def s3_path_join(*args, with_end_slash: bool = False): """Returns the arguments joined by a slash ("/"), similar to ``os.path.join()`` (on Unix). diff --git a/src/sagemaker/serializer_utils.py b/src/sagemaker/serializer_utils.py new file mode 100644 index 0000000000..96a931084c --- /dev/null +++ b/src/sagemaker/serializer_utils.py @@ -0,0 +1,222 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Placeholder docstring""" +from __future__ import absolute_import + +import logging +import struct +import sys + +import numpy as np + +from sagemaker.amazon.record_pb2 import Record +from sagemaker.utils import DeferredError + + +def _write_feature_tensor(resolved_type, record, vector): + """Placeholder Docstring""" + if resolved_type == "Int32": + record.features["values"].int32_tensor.values.extend(vector) + elif resolved_type == "Float64": + record.features["values"].float64_tensor.values.extend(vector) + elif resolved_type == "Float32": + record.features["values"].float32_tensor.values.extend(vector) + + +def _write_label_tensor(resolved_type, record, scalar): + """Placeholder Docstring""" + if resolved_type == "Int32": + record.label["values"].int32_tensor.values.extend([scalar]) + elif resolved_type == "Float64": + record.label["values"].float64_tensor.values.extend([scalar]) + elif resolved_type == "Float32": + record.label["values"].float32_tensor.values.extend([scalar]) + + +def _write_keys_tensor(resolved_type, record, vector): + """Placeholder Docstring""" + if resolved_type == "Int32": + record.features["values"].int32_tensor.keys.extend(vector) + elif resolved_type == "Float64": + record.features["values"].float64_tensor.keys.extend(vector) + elif resolved_type == "Float32": + record.features["values"].float32_tensor.keys.extend(vector) + + +def _write_shape(resolved_type, record, scalar): + """Placeholder Docstring""" + if resolved_type == "Int32": + record.features["values"].int32_tensor.shape.extend([scalar]) + elif resolved_type == "Float64": + record.features["values"].float64_tensor.shape.extend([scalar]) + elif resolved_type == "Float32": + record.features["values"].float32_tensor.shape.extend([scalar]) + + +def write_numpy_to_dense_tensor(file, array, labels=None): + """Writes a numpy array to a dense tensor + + Args: + file: + array: + labels: + """ + + # Validate shape of array and labels, resolve array and label types + if not len(array.shape) == 2: + raise ValueError("Array must be a Matrix") + if labels is not None: + if not len(labels.shape) == 1: + raise ValueError("Labels must be a Vector") + if labels.shape[0] not in array.shape: + raise ValueError( + "Label shape {} not compatible with array shape {}".format( + labels.shape, array.shape + ) + ) + resolved_label_type = _resolve_type(labels.dtype) + resolved_type = _resolve_type(array.dtype) + + # Write each vector in array into a Record in the file object + record = Record() + for index, vector in enumerate(array): + record.Clear() + _write_feature_tensor(resolved_type, record, vector) + if labels is not None: + _write_label_tensor(resolved_label_type, record, labels[index]) + _write_recordio(file, record.SerializeToString()) + + +def write_spmatrix_to_sparse_tensor(file, array, labels=None): + """Writes a scipy sparse matrix to a sparse tensor + + Args: + file: + array: + labels: + """ + try: + import scipy + except ImportError as e: + logging.warning( + "scipy failed to import. Sparse matrix functions will be impaired or broken." + ) + # Any subsequent attempt to use scipy will raise the ImportError + scipy = DeferredError(e) + + if not scipy.sparse.issparse(array): + raise TypeError("Array must be sparse") + + # Validate shape of array and labels, resolve array and label types + if not len(array.shape) == 2: + raise ValueError("Array must be a Matrix") + if labels is not None: + if not len(labels.shape) == 1: + raise ValueError("Labels must be a Vector") + if labels.shape[0] not in array.shape: + raise ValueError( + "Label shape {} not compatible with array shape {}".format( + labels.shape, array.shape + ) + ) + resolved_label_type = _resolve_type(labels.dtype) + resolved_type = _resolve_type(array.dtype) + + csr_array = array.tocsr() + n_rows, n_cols = csr_array.shape + + record = Record() + for row_idx in range(n_rows): + record.Clear() + row = csr_array.getrow(row_idx) + # Write values + _write_feature_tensor(resolved_type, record, row.data) + # Write keys + _write_keys_tensor(resolved_type, record, row.indices.astype(np.uint64)) + + # Write labels + if labels is not None: + _write_label_tensor(resolved_label_type, record, labels[row_idx]) + + # Write shape + _write_shape(resolved_type, record, n_cols) + + _write_recordio(file, record.SerializeToString()) + + +def read_records(file): + """Eagerly read a collection of amazon Record protobuf objects from file. + + Args: + file: + """ + records = [] + for record_data in read_recordio(file): + record = Record() + record.ParseFromString(record_data) + records.append(record) + return records + + +# MXNet requires recordio records have length in bytes that's a multiple of 4 +# This sets up padding bytes to append to the end of the record, for diferent +# amounts of padding required. +padding = {} +for amount in range(4): + if sys.version_info >= (3,): + padding[amount] = bytes([0x00 for _ in range(amount)]) + else: + padding[amount] = bytearray([0x00 for _ in range(amount)]) + +_kmagic = 0xCED7230A + + +def _write_recordio(f, data): + """Writes a single data point as a RecordIO record to the given file. + + Args: + f: + data: + """ + length = len(data) + f.write(struct.pack("I", _kmagic)) + f.write(struct.pack("I", length)) + pad = (((length + 3) >> 2) << 2) - length + f.write(data) + f.write(padding[pad]) + + +def read_recordio(f): + """Placeholder Docstring""" + while True: + try: + (read_kmagic,) = struct.unpack("I", f.read(4)) + except struct.error: + return + assert read_kmagic == _kmagic + (len_record,) = struct.unpack("I", f.read(4)) + pad = (((len_record + 3) >> 2) << 2) - len_record + yield f.read(len_record) + if pad: + f.read(pad) + + +def _resolve_type(dtype): + """Placeholder Docstring""" + if dtype == np.dtype(int): + return "Int32" + if dtype == np.dtype(float): + return "Float64" + if dtype == np.dtype("float32"): + return "Float32" + raise ValueError("Unsupported dtype {} on array".format(dtype)) diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index ef502dc6f3..be46be0856 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -30,8 +30,10 @@ SparseMatrixSerializer, TorchTensorSerializer, StringSerializer, + RecordSerializer, ) +from sagemaker.deprecations import deprecated_class from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.enums import JumpStartModelType @@ -152,3 +154,6 @@ def retrieve_default( model_type=model_type, config_name=config_name, ) + + +numpy_to_record_serializer = deprecated_class(RecordSerializer, "numpy_to_record_serializer") diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 37a77179cb..86a6875721 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -17,7 +17,7 @@ import re from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import Type, Any, List, Dict, Optional +from typing import Type, Any, List, Dict, Optional, Tuple import logging from botocore.exceptions import ClientError @@ -82,6 +82,7 @@ ModelServer.DJL_SERVING, ModelServer.TGI, } +_JS_MINIMUM_VERSION_IMAGE = "{}:0.31.0-lmi13.0.0-cu124" logger = logging.getLogger(__name__) @@ -829,7 +830,13 @@ def _optimize_for_jumpstart( self.pysdk_model._enable_network_isolation = False if quantization_config or sharding_config or is_compilation: - return create_optimization_job_args + # only apply default image for vLLM usecases. + # vLLM does not support compilation for now so skip on compilation + return ( + create_optimization_job_args + if is_compilation + else self._set_optimization_image_default(create_optimization_job_args) + ) return None def _is_gated_model(self, model=None) -> bool: @@ -986,3 +993,105 @@ def _get_neuron_model_env_vars( ) return job_model.env return None + + def _set_optimization_image_default( + self, create_optimization_job_args: Dict[str, Any] + ) -> Dict[str, Any]: + """Defaults the optimization image to the JumpStart deployment config default + + Args: + create_optimization_job_args (Dict[str, Any]): create optimization job request + + Returns: + Dict[str, Any]: create optimization job request with image uri default + """ + default_image = self._get_default_vllm_image(self.pysdk_model.init_kwargs["image_uri"]) + + # find the latest vLLM image version + for optimization_config in create_optimization_job_args.get("OptimizationConfigs"): + if optimization_config.get("ModelQuantizationConfig"): + model_quantization_config = optimization_config.get("ModelQuantizationConfig") + provided_image = model_quantization_config.get("Image") + if provided_image and self._get_latest_lmi_version_from_list( + default_image, provided_image + ): + default_image = provided_image + if optimization_config.get("ModelShardingConfig"): + model_sharding_config = optimization_config.get("ModelShardingConfig") + provided_image = model_sharding_config.get("Image") + if provided_image and self._get_latest_lmi_version_from_list( + default_image, provided_image + ): + default_image = provided_image + + # default to latest vLLM version + for optimization_config in create_optimization_job_args.get("OptimizationConfigs"): + if optimization_config.get("ModelQuantizationConfig") is not None: + optimization_config.get("ModelQuantizationConfig")["Image"] = default_image + if optimization_config.get("ModelShardingConfig") is not None: + optimization_config.get("ModelShardingConfig")["Image"] = default_image + + logger.info("Defaulting to %s image for optimization job", default_image) + + return create_optimization_job_args + + def _get_default_vllm_image(self, image: str) -> bool: + """Ensures the minimum working image version for vLLM enabled optimization techniques + + Args: + image (str): JumpStart provided default image + + Returns: + str: minimum working image version + """ + dlc_name, _ = image.split(":") + major_version_number, _, _ = self._parse_lmi_version(image) + + if major_version_number < self._parse_lmi_version(_JS_MINIMUM_VERSION_IMAGE)[0]: + minimum_version_default = _JS_MINIMUM_VERSION_IMAGE.format(dlc_name) + return minimum_version_default + return image + + def _get_latest_lmi_version_from_list(self, version: str, version_to_compare: str) -> bool: + """LMI version comparator + + Args: + version (str): current version + version_to_compare (str): version to compare to + + Returns: + bool: if version_to_compare larger or equal to version + """ + parse_lmi_version = self._parse_lmi_version(version) + parse_lmi_version_to_compare = self._parse_lmi_version(version_to_compare) + + # Check major version + if parse_lmi_version_to_compare[0] > parse_lmi_version[0]: + return True + # Check minor version + if parse_lmi_version_to_compare[0] == parse_lmi_version[0]: + if parse_lmi_version_to_compare[1] > parse_lmi_version[1]: + return True + if parse_lmi_version_to_compare[1] == parse_lmi_version[1]: + # Check patch version + if parse_lmi_version_to_compare[2] >= parse_lmi_version[2]: + return True + return False + return False + return False + + def _parse_lmi_version(self, image: str) -> Tuple[int, int, int]: + """Parse out LMI version + + Args: + image (str): image to parse version out of + + Returns: + Tuple[int, int, int]: LMI version split into major, minor, patch + """ + _, dlc_tag = image.split(":") + _, lmi_version, _ = dlc_tag.split("-") + major_version, minor_version, patch_version = lmi_version.split(".") + major_version_number = major_version[3:] + + return (int(major_version_number), int(minor_version), int(patch_version)) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index a7a518105c..5e2175a954 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -11,7 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Holds the ModelBuilder class and the ModelServer enum.""" -from __future__ import absolute_import +from __future__ import absolute_import, annotations import importlib.util import json @@ -24,6 +24,7 @@ from pathlib import Path +from botocore.exceptions import ClientError from sagemaker_core.main.resources import TrainingJob from sagemaker.transformer import Transformer @@ -37,6 +38,7 @@ from sagemaker.s3 import S3Downloader from sagemaker import Session from sagemaker.model import Model +from sagemaker.jumpstart.model import JumpStartModel from sagemaker.base_predictor import PredictorBase from sagemaker.serializers import NumpySerializer, TorchTensorSerializer from sagemaker.deserializers import JSONDeserializer, TorchTensorDeserializer @@ -75,6 +77,7 @@ ) from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata from sagemaker.serve.spec.inference_spec import InferenceSpec +from sagemaker.serve.spec.inference_base import CustomOrchestrator, AsyncCustomOrchestrator from sagemaker.serve.utils import task from sagemaker.serve.utils.exceptions import TaskNotFoundException from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model @@ -102,6 +105,7 @@ _get_model_base, ) from sagemaker.serve.model_server.torchserve.prepare import prepare_for_torchserve +from sagemaker.serve.model_server.smd.prepare import prepare_for_smd from sagemaker.serve.model_server.triton.triton_builder import Triton from sagemaker.serve.utils.telemetry_logger import _capture_telemetry from sagemaker.serve.utils.types import ModelServer, ModelHub @@ -131,6 +135,7 @@ ModelServer.MMS, ModelServer.TGI, ModelServer.TEI, + ModelServer.SMD, } @@ -220,6 +225,18 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, available for providing s3 path to fine-tuned model artifacts. ``FINE_TUNING_JOB_NAME`` is available for providing fine-tuned job name. Both ``FINE_TUNING_MODEL_PATH`` and ``FINE_TUNING_JOB_NAME`` are mutually exclusive. + inference_component_name (Optional[str]): The name for an inference component + created from this ModelBuilder instance. This or ``resource_requirements`` must be set + to denote that this instance refers to an inference component. + modelbuilder_list: Optional[List[ModelBuilder]] = List of ModelBuilder objects which + can be built in bulk and subsequently deployed in bulk. Currently only supports + deployments for inference components. + resource_requirements: Optional[ResourceRequirements] = Defines the compute resources + allocated to run the model assigned to the inference component. This or + ``inference_component_name`` must be set to denote that this instance refers + to an inference component. If ``inference_component_name`` is set but this is not and a + JumpStart model ID is specified, pre-benchmarked deployment configs will attempt to be + retrieved for the model. """ model_path: Optional[str] = field( @@ -233,7 +250,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, default=None, metadata={"help": "Define sagemaker session for execution"} ) name: Optional[str] = field( - default="model-name-" + uuid.uuid1().hex, + default_factory=lambda: "model-name-" + uuid.uuid1().hex, metadata={"help": "Define the model name"}, ) mode: Optional[Mode] = field( @@ -320,6 +337,23 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, "in the Hub, Adding unsupported task types will throw an exception." }, ) + inference_component_name: Optional[str] = field( + default=None, + metadata={ + "help": "Defines the name for an Inference Component created from this ModelBuilder." + }, + ) + modelbuilder_list: Optional[List[ModelBuilder]] = field( + default=None, + metadata={"help": "Defines a list of ModelBuilder objects."}, + ) + resource_requirements: Optional[ResourceRequirements] = field( + default=None, + metadata={ + "help": "Defines the compute resources allocated to run the model assigned" + " to the inference component." + }, + ) def _save_model_inference_spec(self): """Placeholder docstring""" @@ -465,7 +499,7 @@ def _get_client_translators(self): elif self.schema_builder: serializer = self.schema_builder.input_serializer else: - raise Exception("Cannot serialize") + raise Exception("Cannot serialize. Try providing a SchemaBuilder if not present.") deserializer = None if self.accept_type == "application/json": @@ -477,7 +511,7 @@ def _get_client_translators(self): elif self.schema_builder: deserializer = self.schema_builder.output_deserializer else: - raise Exception("Cannot deserialize") + raise Exception("Cannot deserialize. Try providing a SchemaBuilder if not present.") return serializer, deserializer @@ -562,6 +596,83 @@ def _model_builder_deploy_model_package_wrapper(self, *args, **kwargs): self.pysdk_model.model_package_arn = None return predictor + def _deploy_for_ic( + self, + *args, + ic_data: Dict[str, Any], + container_timeout_in_seconds: int = 300, + model_data_download_timeout: int = 3600, + instance_type: Optional[str] = None, + initial_instance_count: Optional[int] = None, + endpoint_name: Optional[str] = None, + **kwargs, + ) -> Predictor: + """Creates an Inference Component from a ModelBuilder.""" + ic_name = ic_data.get("Name", None) + model = ic_data.get("Model", None) + resource_requirements = ic_data.get("ResourceRequirements", {}) + + # Ensure resource requirements are set for non-JumpStart models + if not resource_requirements: + raise ValueError( + f"Cannot create/update inference component {ic_name} without resource requirements." + ) + + # Check if the Inference Component exists + if ic_name and self._does_ic_exist(ic_name=ic_name): + logger.info("Updating Inference Component %s as it already exists.", ic_name) + + # Create spec for updating the IC + startup_parameters = {} + if model_data_download_timeout is not None: + startup_parameters["ModelDataDownloadTimeoutInSeconds"] = ( + model_data_download_timeout + ) + if container_timeout_in_seconds is not None: + startup_parameters["ContainerStartupHealthCheckTimeoutInSeconds"] = ( + container_timeout_in_seconds + ) + compute_rr = resource_requirements.get_compute_resource_requirements() + inference_component_spec = { + "ModelName": self.name, + "StartupParameters": startup_parameters, + "ComputeResourceRequirements": compute_rr, + } + runtime_config = {"CopyCount": resource_requirements.copy_count} + response = self.sagemaker_session.update_inference_component( + inference_component_name=ic_name, + specification=inference_component_spec, + runtime_config=runtime_config, + ) + return Predictor(endpoint_name=response.get("EndpointName"), component_name=ic_name) + else: + kwargs.update( + { + "resources": resource_requirements, + "endpoint_type": EndpointType.INFERENCE_COMPONENT_BASED, + "inference_component_name": ic_name, + "endpoint_logging": False, + } + ) + return model.deploy( + *args, + container_startup_health_check_timeout=container_timeout_in_seconds, + initial_instance_count=initial_instance_count, + instance_type=instance_type, + mode=Mode.SAGEMAKER_ENDPOINT, + endpoint_name=endpoint_name, + **kwargs, + ) + + def _does_ic_exist(self, ic_name: str) -> bool: + """Returns true if an Inference Component exists with the given name.""" + try: + self.sagemaker_session.describe_inference_component(inference_component_name=ic_name) + return True + except ClientError as e: + msg = e.response["Error"]["Message"] + return "Could not find inference component" not in msg + @_capture_telemetry("torchserve.deploy") def _model_builder_deploy_wrapper( self, @@ -615,6 +726,13 @@ def _model_builder_deploy_wrapper( if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True + + if "inference_component_name" not in kwargs and self.inference_component_name: + kwargs["inference_component_name"] = self.inference_component_name + + if "resources" not in kwargs and self.resource_requirements: + kwargs["resources"] = self.resource_requirements + kwargs.pop("mode", None) self.pysdk_model.role = kwargs.pop("role", self.pysdk_model.role) predictor = self._original_deploy( @@ -673,6 +791,24 @@ def _build_for_torchserve(self) -> Type[Model]: self.model = self._create_model() return self.model + def _build_for_smd(self) -> Type[Model]: + """Build the model for SageMaker Distribution""" + self._save_model_inference_spec() + + if self.mode != Mode.IN_PROCESS: + self._auto_detect_container() + + self.secret_key = prepare_for_smd( + model_path=self.model_path, + shared_libs=self.shared_libs, + dependencies=self.dependencies, + inference_spec=self.inference_spec, + ) + + self._prepare_for_mode() + self.model = self._create_model() + return self.model + def _user_agent_decorator(self, func): """Placeholder docstring""" @@ -854,13 +990,225 @@ def _collect_estimator_model_telemetry(self): """Dummy method to collect telemetry for estimator handshake""" return + def build( + self, + mode: Type[Mode] = None, + role_arn: str = None, + sagemaker_session: Optional[Session] = None, + ) -> Union[ModelBuilder, Type[Model]]: + """Creates deployable ``Model`` instances with all provided ``ModelBuilder`` objects. + + Args: + mode (Type[Mode], optional): The mode. Defaults to ``None``. + role_arn (str, optional): The IAM role arn. Defaults to ``None``. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Union[ModelBuilder, Type[Model]]: A deployable ``ModelBuilder`` object if multiple + ``ModelBuilders`` were built, or a deployable ``Model`` object. + """ + if role_arn: + self.role_arn = role_arn + self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() + + deployables = {} + + if not self.modelbuilder_list and not isinstance( + self.inference_spec, (CustomOrchestrator, AsyncCustomOrchestrator) + ): + self.serve_settings = self._get_serve_setting() + return self._build_single_modelbuilder( + mode=mode, + role_arn=self.role_arn, + sagemaker_session=sagemaker_session, + ) + + # Multi-ModelBuilder case: deploy + built_ic_models = [] + if self.modelbuilder_list: + logger.info("Detected ModelBuilders in modelbuilder_list.") + for mb in self.modelbuilder_list: + if mb.mode == Mode.IN_PROCESS or mb.mode == Mode.LOCAL_CONTAINER: + raise ValueError( + "Bulk ModelBuilder building is only supported for SageMaker Endpoint Mode." + ) + + if (not mb.resource_requirements and not mb.inference_component_name) and ( + not mb.inference_spec + or not isinstance( + mb.inference_spec, (CustomOrchestrator, AsyncCustomOrchestrator) + ) + ): + raise ValueError( + "Bulk ModelBuilder building is only supported for Inference Components " + + "and custom orchestrators." + ) + + for mb in self.modelbuilder_list: + # Custom orchestrator definition found in inference_spec + mb.serve_settings = mb._get_serve_setting() + # Build for Inference Component + logger.info("Building ModelBuilder %s.", mb.name) + # Get JS deployment configs if ResourceRequirements not set + + mb = mb._get_ic_resource_requirements(mb=mb) + + built_model = mb._build_single_modelbuilder( + role_arn=self.role_arn, sagemaker_session=self.sagemaker_session + ) + built_ic_models.append( + { + "Name": mb.inference_component_name, + "ResourceRequirements": mb.resource_requirements, + "Model": built_model, + } + ) + logger.info( + "=====================Build for %s complete.===================", + mb.model, + ) + deployables["InferenceComponents"] = built_ic_models + + if isinstance(self.inference_spec, (CustomOrchestrator, AsyncCustomOrchestrator)): + logger.info("Building custom orchestrator.") + if self.mode == Mode.IN_PROCESS or self.mode == Mode.LOCAL_CONTAINER: + raise ValueError( + "Custom orchestrator deployment is only supported for" + "SageMaker Endpoint Mode." + ) + self.serve_settings = self._get_serve_setting() + cpu_or_gpu_instance = self._get_processing_unit() + self.image_uri = self._get_smd_image_uri(processing_unit=cpu_or_gpu_instance) + self.model_server = ModelServer.SMD + built_orchestrator = self._build_single_modelbuilder( + mode=Mode.SAGEMAKER_ENDPOINT, + role_arn=role_arn, + sagemaker_session=sagemaker_session, + ) + if not self.resource_requirements: + logger.info( + "Custom orchestrator resource_requirements not found. " + "Building as a SageMaker Endpoint instead of Inference Component." + ) + deployables["CustomOrchestrator"] = { + "Mode": "Endpoint", + "Model": built_orchestrator, + } + else: + # Network isolation of ICs on an endpoint must be consistent + if built_ic_models: + if ( + self.dependencies["auto"] + or "requirements" in self.dependencies + or "custom" in self.dependencies + ): + logger.warning( + "Custom orchestrator network isolation must be False when dependencies " + "are specified or using autocapture. To enable network isolation, " + "package all dependencies in the container or model artifacts " + "ahead of time." + ) + built_orchestrator._enable_network_isolation = False + for model in built_ic_models: + model["Model"]._enable_network_isolation = False + deployables["CustomOrchestrator"] = { + "Name": self.inference_component_name, + "Mode": "InferenceComponent", + "ResourceRequirements": self.resource_requirements, + "Model": built_orchestrator, + } + + logger.info( + "=====================Custom orchestrator build complete.===================", + ) + + self._deployables = deployables + return self + + def _get_processing_unit(self): + """Detects if the resource requirements are intended for a CPU or GPU instance.""" + # Assume custom orchestrator will be deployed as an endpoint to a CPU instance + if not self.resource_requirements or not self.resource_requirements.num_accelerators: + return "cpu" + for ic in self.modelbuilder_list or []: + if ic.resource_requirements.num_accelerators > 0: + return "gpu" + if self.resource_requirements.num_accelerators > 0: + return "gpu" + + return "cpu" + + def _get_ic_resource_requirements(self, mb: ModelBuilder = None) -> ModelBuilder: + """Attempts fetching pre-benchmarked resource requirements for the MB from JumpStart.""" + if mb._is_jumpstart_model_id() and not mb.resource_requirements: + js_model = JumpStartModel(model_id=mb.model) + deployment_configs = js_model.list_deployment_configs() + if not deployment_configs: + raise ValueError( + "No resource requirements were provided for Inference Component " + f"{mb.inference_component_name} and no default deployment" + " configs were found in JumpStart." + ) + compute_requirements = ( + deployment_configs[0].get("DeploymentArgs").get("ComputeResourceRequirements") + ) + logger.info("Retrieved pre-benchmarked deployment configurations from JumpStart.") + mb.resource_requirements = ResourceRequirements( + requests={ + "memory": compute_requirements["MinMemoryRequiredInMb"], + "num_accelerators": compute_requirements.get( + "NumberOfAcceleratorDevicesRequired", None + ), + "copies": 1, + "num_cpus": compute_requirements.get("NumberOfCpuCoresRequired", None), + }, + limits={"memory": compute_requirements.get("MaxMemoryRequiredInMb", None)}, + ) + + return mb + + @_capture_telemetry("build_custom_orchestrator") + def _get_smd_image_uri(self, processing_unit: str = None) -> str: + """Gets the SMD Inference Image URI. + + Returns: + str: SMD Inference Image URI. + """ + from sagemaker import image_uris + import sys + + self.sagemaker_session = self.sagemaker_session or Session() + from packaging.version import Version + + formatted_py_version = f"py{sys.version_info.major}{sys.version_info.minor}" + if Version(f"{sys.version_info.major}{sys.version_info.minor}") < Version("3.12"): + raise ValueError( + f"Found Python version {formatted_py_version} but" + f"Custom orchestrator deployment requires Python version >= 3.12." + ) + + INSTANCE_TYPES = {"cpu": "ml.c5.xlarge", "gpu": "ml.g5.4xlarge"} + + logger.info("Finding SMD inference image URI for a %s instance.", processing_unit) + + smd_uri = image_uris.retrieve( + framework="sagemaker-distribution", + image_scope="inference", + instance_type=INSTANCE_TYPES[processing_unit], + region=self.sagemaker_session.boto_region_name, + ) + logger.info("Found compatible image %s", smd_uri) + return smd_uri + # Model Builder is a class to build the model for deployment. # It supports three modes of deployment # 1/ SageMaker Endpoint # 2/ Local launch with container # 3/ In process mode with Transformers server in beta release @_capture_telemetry("ModelBuilder.build") - def build( # pylint: disable=R0911 + def _build_single_modelbuilder( # pylint: disable=R0911 self, mode: Type[Mode] = None, role_arn: str = None, @@ -1039,6 +1387,9 @@ def _build_for_model_server(self): # pylint: disable=R0911, R1710 if self.model_server == ModelServer.MMS: return self._build_for_transformers() + if self.model_server == ModelServer.SMD: + return self._build_for_smd() + @_capture_telemetry("ModelBuilder.save") def save( self, @@ -1593,6 +1944,8 @@ def _optimize_prepare_for_hf(self): def deploy( self, endpoint_name: str = None, + container_timeout_in_second: int = 300, + instance_type: str = None, initial_instance_count: Optional[int] = 1, inference_config: Optional[ Union[ @@ -1602,7 +1955,10 @@ def deploy( ResourceRequirements, ] ] = None, - ) -> Union[Predictor, Transformer]: + custom_orchestrator_instance_type: str = None, + custom_orchestrator_initial_instance_count: int = None, + **kwargs, + ) -> Union[Predictor, Transformer, List[Predictor]]: """Deploys the built Model. Depending on the type of config provided, this function will call deployment accordingly. @@ -1619,50 +1975,106 @@ def deploy( Transformer for Batch Deployments Predictors for all others """ - if not hasattr(self, "built_model"): - raise ValueError("Model Needs to be built before deploying") - endpoint_name = unique_name_from_base(endpoint_name) - if not inference_config: # Real-time Deployment - return self.built_model.deploy( - instance_type=self.instance_type, - initial_instance_count=initial_instance_count, - endpoint_name=endpoint_name, - ) + if not hasattr(self, "built_model") and not hasattr(self, "_deployables"): + raise ValueError("Model needs to be built before deploying") + endpoint_name = endpoint_name or unique_name_from_base("endpoint-name") + + if not hasattr(self, "_deployables"): + if not inference_config: # Real-time Deployment + return self.built_model.deploy( + instance_type=self.instance_type, + initial_instance_count=initial_instance_count, + endpoint_name=endpoint_name, + ) - if isinstance(inference_config, ServerlessInferenceConfig): - return self.built_model.deploy( - serverless_inference_config=inference_config, - endpoint_name=endpoint_name, - ) + if isinstance(inference_config, ServerlessInferenceConfig): + return self.built_model.deploy( + serverless_inference_config=inference_config, + endpoint_name=endpoint_name, + ) - if isinstance(inference_config, AsyncInferenceConfig): - return self.built_model.deploy( - instance_type=self.instance_type, - initial_instance_count=initial_instance_count, - async_inference_config=inference_config, - endpoint_name=endpoint_name, - ) + if isinstance(inference_config, AsyncInferenceConfig): + return self.built_model.deploy( + instance_type=self.instance_type, + initial_instance_count=initial_instance_count, + async_inference_config=inference_config, + endpoint_name=endpoint_name, + ) - if isinstance(inference_config, BatchTransformInferenceConfig): - transformer = self.built_model.transformer( - instance_type=inference_config.instance_type, - output_path=inference_config.output_path, - instance_count=inference_config.instance_count, - ) - return transformer + if isinstance(inference_config, BatchTransformInferenceConfig): + transformer = self.built_model.transformer( + instance_type=inference_config.instance_type, + output_path=inference_config.output_path, + instance_count=inference_config.instance_count, + ) + return transformer + + if isinstance(inference_config, ResourceRequirements): + # Multi Model and MultiContainer endpoints with Inference Component + return self.built_model.deploy( + instance_type=self.instance_type, + mode=Mode.SAGEMAKER_ENDPOINT, + endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, + resources=inference_config, + initial_instance_count=initial_instance_count, + role=self.role_arn, + ) - if isinstance(inference_config, ResourceRequirements): - # Multi Model and MultiContainer endpoints with Inference Component - return self.built_model.deploy( - instance_type=self.instance_type, - mode=Mode.SAGEMAKER_ENDPOINT, - endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, - resources=inference_config, - initial_instance_count=initial_instance_count, - role=self.role_arn, + raise ValueError("Deployment Options not supported") + + # Iterate through deployables for a custom orchestrator deployment. + # Create all Inference Components first before deploying custom orchestrator if present. + predictors = [] + for inference_component in self._deployables.get("InferenceComponents", []): + predictors.append( + self._deploy_for_ic( + ic_data=inference_component, + container_timeout_in_seconds=container_timeout_in_second, + instance_type=instance_type, + initial_instance_count=initial_instance_count, + endpoint_name=endpoint_name, + **kwargs, + ) ) + if self._deployables.get("CustomOrchestrator", None): + custom_orchestrator = self._deployables.get("CustomOrchestrator") + if not custom_orchestrator_instance_type and not instance_type: + logger.warning( + "Deploying custom orchestrator as an endpoint but no instance type was " + "set. Defaulting to `ml.c5.xlarge`." + ) + custom_orchestrator_instance_type = "ml.c5.xlarge" + custom_orchestrator_initial_instance_count = 1 + if custom_orchestrator["Mode"] == "Endpoint": + logger.info( + "Deploying custom orchestrator on instance type %s.", + custom_orchestrator_instance_type, + ) + predictors.append( + custom_orchestrator["Model"].deploy( + instance_type=custom_orchestrator_instance_type, + initial_instance_count=custom_orchestrator_initial_instance_count, + **kwargs, + ) + ) + elif custom_orchestrator["Mode"] == "InferenceComponent": + logger.info( + "Deploying custom orchestrator as an inference component " + f"to endpoint {endpoint_name}" + ) + predictors.append( + self._deploy_for_ic( + ic_data=custom_orchestrator, + container_timeout_in_seconds=container_timeout_in_second, + instance_type=custom_orchestrator_instance_type or instance_type, + initial_instance_count=custom_orchestrator_initial_instance_count + or initial_instance_count, + endpoint_name=endpoint_name, + **kwargs, + ) + ) - raise ValueError("Deployment Options not supported") + return predictors def display_benchmark_metrics(self, **kwargs): """Display Markdown Benchmark Metrics for deployment configs.""" diff --git a/src/sagemaker/serve/builder/schema_builder.py b/src/sagemaker/serve/builder/schema_builder.py index 3fd1816d0e..7f70e98747 100644 --- a/src/sagemaker/serve/builder/schema_builder.py +++ b/src/sagemaker/serve/builder/schema_builder.py @@ -4,6 +4,7 @@ import io import logging from pathlib import Path +from typing import Callable import numpy as np from pandas import DataFrame @@ -286,7 +287,7 @@ def _is_path_to_file(data: object) -> bool: def _validate_translations( - payload: object, serialize_callable: callable, deserialize_callable: callable + payload: object, serialize_callable: Callable, deserialize_callable: Callable ) -> None: """Placeholder docstring""" try: diff --git a/src/sagemaker/serve/detector/dependency_manager.py b/src/sagemaker/serve/detector/dependency_manager.py index e72a84da30..8ff37c9185 100644 --- a/src/sagemaker/serve/detector/dependency_manager.py +++ b/src/sagemaker/serve/detector/dependency_manager.py @@ -34,22 +34,34 @@ def capture_dependencies(dependencies: dict, work_dir: Path, capture_all: bool = """Placeholder docstring""" path = work_dir.joinpath("requirements.txt") if "auto" in dependencies and dependencies["auto"]: + import site + + pkl_path = work_dir.joinpath(PKL_FILE_NAME) + dest_path = path + site_packages_dir = site.getsitepackages()[0] + pickle_command_dir = "/sagemaker/serve/detector" + command = [ sys.executable, - Path(__file__).parent.joinpath("pickle_dependencies.py"), - "--pkl_path", - work_dir.joinpath(PKL_FILE_NAME), - "--dest", - path, + "-c", ] if capture_all: - command.append("--capture_all") + command.append( + f"from pickle_dependencies import get_all_requirements;" + f'get_all_requirements("{dest_path}")' + ) + else: + command.append( + f"from pickle_dependencies import get_requirements_for_pkl_file;" + f'get_requirements_for_pkl_file("{pkl_path}", "{dest_path}")' + ) subprocess.run( command, env={"SETUPTOOLS_USE_DISTUTILS": "stdlib"}, check=True, + cwd=site_packages_dir + pickle_command_dir, ) with open(path, "r") as f: diff --git a/src/sagemaker/serve/detector/pickle_dependencies.py b/src/sagemaker/serve/detector/pickle_dependencies.py index 5a1cd43869..8f9da917fd 100644 --- a/src/sagemaker/serve/detector/pickle_dependencies.py +++ b/src/sagemaker/serve/detector/pickle_dependencies.py @@ -3,7 +3,6 @@ from __future__ import absolute_import from pathlib import Path from typing import List -import argparse import email.parser import email.policy import json @@ -129,32 +128,3 @@ def get_all_requirements(dest: Path): version = package_info.get("version") out.write(f"{name}=={version}\n") - - -def parse_args(): - """Placeholder docstring""" - parser = argparse.ArgumentParser( - prog="pkl_requirements", description="Generates a requirements.txt for a cloudpickle file" - ) - parser.add_argument("--pkl_path", required=True, help="path of the pkl file") - parser.add_argument("--dest", required=True, help="path of the destination requirements.txt") - parser.add_argument( - "--capture_all", - action="store_true", - help="capture all dependencies in current environment", - ) - args = parser.parse_args() - return (Path(args.pkl_path), Path(args.dest), args.capture_all) - - -def main(): - """Placeholder docstring""" - pkl_path, dest, capture_all = parse_args() - if capture_all: - get_all_requirements(dest) - else: - get_requirements_for_pkl_file(pkl_path, dest) - - -if __name__ == "__main__": - main() diff --git a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py index 2f09d3d572..2b4473a706 100644 --- a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py +++ b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py @@ -16,10 +16,13 @@ from sagemaker.serve.model_server.djl_serving.server import SageMakerDjlServing from sagemaker.serve.model_server.tgi.server import SageMakerTgiServing from sagemaker.serve.model_server.multi_model_server.server import SageMakerMultiModelServer +from sagemaker.serve.model_server.smd.server import SageMakerSmdServer + logger = logging.getLogger(__name__) +# pylint: disable=R0901 class SageMakerEndpointMode( SageMakerTorchServe, SageMakerTritonServer, @@ -27,6 +30,7 @@ class SageMakerEndpointMode( SageMakerTgiServing, SageMakerMultiModelServer, SageMakerTensorflowServing, + SageMakerSmdServer, ): """Holds the required method to deploy a model to a SageMaker Endpoint""" @@ -144,6 +148,16 @@ def prepare( should_upload_artifacts=should_upload_artifacts, ) + if self.model_server == ModelServer.SMD: + upload_artifacts = self._upload_smd_artifacts( + model_path=model_path, + sagemaker_session=sagemaker_session, + secret_key=secret_key, + s3_model_data_url=s3_model_data_url, + image=image, + should_upload_artifacts=True, + ) + if upload_artifacts or isinstance(self.model_server, ModelServer): return upload_artifacts diff --git a/src/sagemaker/serve/model_format/mlflow/constants.py b/src/sagemaker/serve/model_format/mlflow/constants.py index d7ddcd9ef0..ff7553ea5f 100644 --- a/src/sagemaker/serve/model_format/mlflow/constants.py +++ b/src/sagemaker/serve/model_format/mlflow/constants.py @@ -18,6 +18,7 @@ "py38": "1.12.1", "py39": "1.13.1", "py310": "2.2.0", + "py311": "2.3.0", } MODEL_PACKAGE_ARN_REGEX = ( r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/(.*?)(?:/(\d+))?$" diff --git a/src/sagemaker/serve/model_server/smd/custom_execution_inference.py b/src/sagemaker/serve/model_server/smd/custom_execution_inference.py new file mode 100644 index 0000000000..f53677fc69 --- /dev/null +++ b/src/sagemaker/serve/model_server/smd/custom_execution_inference.py @@ -0,0 +1,72 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module is for SageMaker inference.py.""" + +from __future__ import absolute_import +import asyncio +import os +import platform +import cloudpickle +import logging +from pathlib import Path +from sagemaker.serve.validations.check_integrity import perform_integrity_check + +logger = LOGGER = logging.getLogger("sagemaker") + + +def initialize_custom_orchestrator(): + """Initializes the custom orchestrator.""" + code_dir = os.getenv("SAGEMAKER_INFERENCE_CODE_DIRECTORY", None) + serve_path = Path(code_dir).joinpath("serve.pkl") + with open(str(serve_path), mode="rb") as pkl_file: + return cloudpickle.load(pkl_file) + + +def _run_preflight_diagnostics(): + _py_vs_parity_check() + _pickle_file_integrity_check() + + +def _py_vs_parity_check(): + container_py_vs = platform.python_version() + local_py_vs = os.getenv("LOCAL_PYTHON") + + if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]: + logger.warning( + f"The local python version {local_py_vs} differs from the python version " + f"{container_py_vs} on the container. Please align the two to avoid unexpected behavior" + ) + + +def _pickle_file_integrity_check(): + with open("/opt/ml/model/code/serve.pkl", "rb") as f: + buffer = f.read() + + metadata_path = Path("/opt/ml/model/code/metadata.json") + perform_integrity_check(buffer=buffer, metadata_path=metadata_path) + + +_run_preflight_diagnostics() +custom_orchestrator, _ = initialize_custom_orchestrator() + + +async def handler(request): + """Custom service entry point function. + + :param request: raw input from request + :return: outputs to be send back to client + """ + if asyncio.iscoroutinefunction(custom_orchestrator.handle): + return await custom_orchestrator.handle(request.body) + else: + return custom_orchestrator.handle(request.body) diff --git a/src/sagemaker/serve/model_server/smd/prepare.py b/src/sagemaker/serve/model_server/smd/prepare.py new file mode 100644 index 0000000000..6461e4023f --- /dev/null +++ b/src/sagemaker/serve/model_server/smd/prepare.py @@ -0,0 +1,74 @@ +"""Summary of MyModule. + +Extended discussion of my module. +""" + +from __future__ import absolute_import +import os +from pathlib import Path +import shutil +from typing import List + +from sagemaker.serve.spec.inference_spec import InferenceSpec +from sagemaker.serve.detector.dependency_manager import capture_dependencies +from sagemaker.serve.validations.check_integrity import ( + generate_secret_key, + compute_hash, +) +from sagemaker.remote_function.core.serialization import _MetaData +from sagemaker.serve.spec.inference_base import CustomOrchestrator, AsyncCustomOrchestrator + + +def prepare_for_smd( + model_path: str, + shared_libs: List[str], + dependencies: dict, + inference_spec: InferenceSpec = None, +) -> str: + """Prepares artifacts for SageMaker model deployment. + + Args:to + model_path (str) : Argument + shared_libs (List[]) : Argument + dependencies (dict) : Argument + inference_spec (InferenceSpec, optional) : Argument + (default is None) + + Returns: + ( str ) : + + """ + model_path = Path(model_path) + if not model_path.exists(): + model_path.mkdir() + elif not model_path.is_dir(): + raise Exception("model_dir is not a valid directory") + + if inference_spec and isinstance(inference_spec, InferenceSpec): + inference_spec.prepare(str(model_path)) + + code_dir = model_path.joinpath("code") + code_dir.mkdir(exist_ok=True) + + if inference_spec and isinstance(inference_spec, (CustomOrchestrator, AsyncCustomOrchestrator)): + shutil.copy2(Path(__file__).parent.joinpath("custom_execution_inference.py"), code_dir) + os.rename( + str(code_dir.joinpath("custom_execution_inference.py")), + str(code_dir.joinpath("inference.py")), + ) + + shared_libs_dir = model_path.joinpath("shared_libs") + shared_libs_dir.mkdir(exist_ok=True) + for shared_lib in shared_libs: + shutil.copy2(Path(shared_lib), shared_libs_dir) + + capture_dependencies(dependencies=dependencies, work_dir=code_dir) + + secret_key = generate_secret_key() + with open(str(code_dir.joinpath("serve.pkl")), "rb") as f: + buffer = f.read() + hash_value = compute_hash(buffer=buffer, secret_key=secret_key) + with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata: + metadata.write(_MetaData(hash_value).to_json()) + + return secret_key diff --git a/src/sagemaker/serve/model_server/smd/server.py b/src/sagemaker/serve/model_server/smd/server.py new file mode 100644 index 0000000000..c700c39727 --- /dev/null +++ b/src/sagemaker/serve/model_server/smd/server.py @@ -0,0 +1,59 @@ +"""Module for SMD Server""" + +from __future__ import absolute_import + +import logging +import platform +from sagemaker.serve.utils.optimize_utils import _is_s3_uri +from sagemaker.session import Session +from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url +from sagemaker import fw_utils +from sagemaker.serve.utils.uploader import upload + +logger = logging.getLogger(__name__) + + +class SageMakerSmdServer: + """Placeholder docstring""" + + def _upload_smd_artifacts( + self, + model_path: str, + sagemaker_session: Session, + secret_key: str, + s3_model_data_url: str = None, + image: str = None, + should_upload_artifacts: bool = False, + ): + """Tar the model artifact and upload to S3 bucket, then prepare for the environment variables""" + s3_upload_path = None + if _is_s3_uri(model_path): + s3_upload_path = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) + + logger.debug( + "Uploading the model resources to bucket=%s, key_prefix=%s.", + bucket, + code_key_prefix, + ) + s3_upload_path = upload(sagemaker_session, model_path, bucket, code_key_prefix) + logger.debug("Model resources uploaded to: %s", s3_upload_path) + + env_vars = { + "SAGEMAKER_INFERENCE_CODE_DIRECTORY": "/opt/ml/model/code", + "SAGEMAKER_INFERENCE_CODE": "inference.handler", + "SAGEMAKER_REGION": sagemaker_session.boto_region_name, + "SAGEMAKER_SERVE_SECRET_KEY": secret_key, + "LOCAL_PYTHON": platform.python_version(), + } + return s3_upload_path, env_vars diff --git a/src/sagemaker/serve/spec/inference_base.py b/src/sagemaker/serve/spec/inference_base.py new file mode 100644 index 0000000000..23ea6cb01d --- /dev/null +++ b/src/sagemaker/serve/spec/inference_base.py @@ -0,0 +1,45 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Holds templated classes to enable users to provide custom inference scripting capabilities""" +from __future__ import absolute_import +from abc import ABC, abstractmethod + + +class CustomOrchestrator(ABC): + """Templated class to standardize sync entrypoint-based inference scripts""" + + def __init__(self): + self._client = None + + @property + def client(self): + """Boto3 SageMaker runtime client to use with custom orchestrator""" + if not hasattr(self, "_client") or not self._client: + from boto3 import Session + + self._client = Session().client("sagemaker-runtime") + return self._client + + @abstractmethod + def handle(self, data, context=None): + """Abstract class for defining an entrypoint for the model server""" + return NotImplemented + + +class AsyncCustomOrchestrator(ABC): + """Templated class to standardize async entrypoint-based inference scripts""" + + @abstractmethod + async def handle(self, data, context=None): + """Abstract class for defining an aynchronous entrypoint for the model server""" + return NotImplemented diff --git a/src/sagemaker/serve/utils/telemetry_logger.py b/src/sagemaker/serve/utils/telemetry_logger.py index a1a0408718..1534d960fe 100644 --- a/src/sagemaker/serve/utils/telemetry_logger.py +++ b/src/sagemaker/serve/utils/telemetry_logger.py @@ -64,6 +64,7 @@ str(ModelServer.TRITON): 5, str(ModelServer.TGI): 6, str(ModelServer.TEI): 7, + str(ModelServer.SMD): 8, } MLFLOW_MODEL_PATH_CODE = { diff --git a/src/sagemaker/serve/utils/tuning.py b/src/sagemaker/serve/utils/tuning.py index b93c01b522..5a63cfe508 100644 --- a/src/sagemaker/serve/utils/tuning.py +++ b/src/sagemaker/serve/utils/tuning.py @@ -7,6 +7,7 @@ import collections from multiprocessing.pool import ThreadPool from math import ceil +from typing import Callable import pandas as pd from numpy import percentile, std from sagemaker.serve.model_server.djl_serving.utils import _tokens_from_chars, _tokens_from_words @@ -152,7 +153,7 @@ def _tokens_per_second(generated_text: str, max_token_length: int, latency: floa return min(est_tokens, max_token_length) / latency -def _timed_invoke(predict: callable, sample_input: object) -> tuple: +def _timed_invoke(predict: Callable, sample_input: object) -> tuple: """Placeholder docstring""" start_timer = perf_counter() response = predict(sample_input) diff --git a/src/sagemaker/serve/utils/types.py b/src/sagemaker/serve/utils/types.py index e50be62440..b405d85b21 100644 --- a/src/sagemaker/serve/utils/types.py +++ b/src/sagemaker/serve/utils/types.py @@ -19,6 +19,7 @@ def __str__(self): TRITON = 5 TGI = 6 TEI = 7 + SMD = 8 class HardwareType(Enum): diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 04a7326557..b2398e03d1 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4347,11 +4347,59 @@ def submit(request): if model_package_group_name is not None and not model_package_group_name.startswith( "arn:" ): - _create_resource( - lambda: self.sagemaker_client.create_model_package_group( - ModelPackageGroupName=request["ModelPackageGroupName"] + is_model_package_group_present = False + try: + model_package_groups_response = self.search( + resource="ModelPackageGroup", + search_expression={ + "Filters": [ + { + "Name": "ModelPackageGroupName", + "Value": request["ModelPackageGroupName"], + "Operator": "Equals", + } + ], + }, + ) + if len(model_package_groups_response.get("Results")) > 0: + is_model_package_group_present = True + except Exception: # pylint: disable=W0703 + model_package_groups = [] + model_package_groups_response = self.sagemaker_client.list_model_package_groups( + NameContains=request["ModelPackageGroupName"], + ) + model_package_groups = ( + model_package_groups + + model_package_groups_response["ModelPackageGroupSummaryList"] + ) + next_token = model_package_groups_response.get("NextToken") + + while next_token is not None and next_token != "": + model_package_groups_response = ( + self.sagemaker_client.list_model_package_groups( + NameContains=request["ModelPackageGroupName"], NextToken=next_token + ) + ) + model_package_groups = ( + model_package_groups + + model_package_groups_response["ModelPackageGroupSummaryList"] + ) + next_token = model_package_groups_response.get("NextToken") + + filtered_model_package_group = list( + filter( + lambda mpg: mpg.get("ModelPackageGroupName") + == request["ModelPackageGroupName"], + model_package_groups, + ) + ) + is_model_package_group_present = len(filtered_model_package_group) > 0 + if not is_model_package_group_present: + _create_resource( + lambda: self.sagemaker_client.create_model_package_group( + ModelPackageGroupName=request["ModelPackageGroupName"] + ) ) - ) if "SourceUri" in request and request["SourceUri"] is not None: # Remove inference spec from request if the # given source uri can lead to auto-population of it @@ -5286,7 +5334,7 @@ def get_tagging_resources(self, tag_filters, resource_type_filters): resource_tag_response = self.resource_group_tagging_client.get_resources( TagFilters=tag_filters, ResourceTypeFilters=resource_type_filters, - NextToken=next_token, + PaginationToken=next_token, ) resource_list = resource_list + resource_tag_response["ResourceTagMappingList"] next_token = resource_tag_response.get("PaginationToken") diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py index ae66bc8338..586e50da88 100644 --- a/src/sagemaker/sklearn/estimator.py +++ b/src/sagemaker/sklearn/estimator.py @@ -83,8 +83,8 @@ def __init__( source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + are preserved when training on Amazon SageMaker. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index c3727b2fb5..a9b0e2e8f0 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Union, Optional, List, Dict +from typing import Callable, Union, Optional, List, Dict import sagemaker from sagemaker import image_uris, ModelMetrics @@ -92,7 +92,7 @@ def __init__( framework_version: Optional[str] = None, py_version: str = "py3", image_uri: Optional[Union[str, PipelineVariable]] = None, - predictor_cls: callable = SKLearnPredictor, + predictor_cls: Optional[Callable] = SKLearnPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs, ): @@ -122,7 +122,7 @@ def __init__( If ``framework_version`` or ``py_version`` are ``None``, then ``image_uri`` is required. If ``image_uri`` is also ``None``, then a ``ValueError`` will be raised. - predictor_cls (callable[str, sagemaker.session.Session]): A function + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. diff --git a/src/sagemaker/telemetry/constants.py b/src/sagemaker/telemetry/constants.py index 2108ff9fd6..cb83a78279 100644 --- a/src/sagemaker/telemetry/constants.py +++ b/src/sagemaker/telemetry/constants.py @@ -42,3 +42,40 @@ class Status(Enum): def __str__(self): # pylint: disable=E0307 """Return the status name.""" return self.name + + +class Region(str, Enum): + """Telemetry: List of all supported AWS regions.""" + + # Classic + US_EAST_1 = "us-east-1" # IAD + US_EAST_2 = "us-east-2" # CMH + US_WEST_1 = "us-west-1" # SFO + US_WEST_2 = "us-west-2" # PDX + AP_NORTHEAST_1 = "ap-northeast-1" # NRT + AP_NORTHEAST_2 = "ap-northeast-2" # ICN + AP_NORTHEAST_3 = "ap-northeast-3" # KIX + AP_SOUTH_1 = "ap-south-1" # BOM + AP_SOUTHEAST_1 = "ap-southeast-1" # SIN + AP_SOUTHEAST_2 = "ap-southeast-2" # SYD + CA_CENTRAL_1 = "ca-central-1" # YUL + EU_CENTRAL_1 = "eu-central-1" # FRA + EU_NORTH_1 = "eu-north-1" # ARN + EU_WEST_1 = "eu-west-1" # DUB + EU_WEST_2 = "eu-west-2" # LHR + EU_WEST_3 = "eu-west-3" # CDG + SA_EAST_1 = "sa-east-1" # GRU + # Opt-in + AP_EAST_1 = "ap-east-1" # HKG + AP_SOUTHEAST_3 = "ap-southeast-3" # CGK + AF_SOUTH_1 = "af-south-1" # CPT + EU_SOUTH_1 = "eu-south-1" # MXP + ME_SOUTH_1 = "me-south-1" # BAH + MX_CENTRAL_1 = "mx-central-1" # QRO + AP_SOUTHEAST_7 = "ap-southeast-7" # BKK + AP_SOUTH_2 = "ap-south-2" # HYD + AP_SOUTHEAST_4 = "ap-southeast-4" # MEL + EU_CENTRAL_2 = "eu-central-2" # ZRH + EU_SOUTH_2 = "eu-south-2" # ZAZ + IL_CENTRAL_1 = "il-central-1" # TLV + ME_CENTRAL_1 = "me-central-1" # DXB diff --git a/src/sagemaker/telemetry/telemetry_logging.py b/src/sagemaker/telemetry/telemetry_logging.py index b45550b2c2..b0ecedee4c 100644 --- a/src/sagemaker/telemetry/telemetry_logging.py +++ b/src/sagemaker/telemetry/telemetry_logging.py @@ -27,6 +27,7 @@ from sagemaker.telemetry.constants import ( Feature, Status, + Region, DEFAULT_AWS_REGION, ) from sagemaker.user_agent import SDK_VERSION, process_studio_metadata_file @@ -189,8 +190,16 @@ def _send_telemetry_request( """Make GET request to an empty object in S3 bucket""" try: accountId = _get_accountId(session) if session else "NotAvailable" - # telemetry will be sent to us-west-2 if no session availale - region = _get_region_or_default(session) if session else DEFAULT_AWS_REGION + region = _get_region_or_default(session) + + try: + Region(region) # Validate the region + except ValueError: + logger.warning( + "Region not found in supported regions. Telemetry request will not be emitted." + ) + return + url = _construct_url( accountId, region, @@ -268,6 +277,7 @@ def _get_region_or_default(session): def _get_default_sagemaker_session(): """Return the default sagemaker session""" + boto_session = boto3.Session(region_name=DEFAULT_AWS_REGION) sagemaker_session = Session(boto_session=boto_session) diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index fe20994e20..c7f624114f 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Union, Optional, List, Dict +from typing import Callable, Union, Optional, List, Dict import sagemaker from sagemaker import image_uris, s3, ModelMetrics @@ -62,9 +62,9 @@ def __init__( manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. - serializer (callable): Optional. Default serializes input data to + serializer (Callable): Optional. Default serializes input data to json. Handles dicts, lists, and numpy arrays. - deserializer (callable): Optional. Default parses the response using + deserializer (Callable): Optional. Default parses the response using ``json.load(...)``. model_name (str): Optional. The name of the SavedModel model that should handle the request. If not specified, the endpoint's @@ -146,7 +146,7 @@ def __init__( image_uri: Optional[Union[str, PipelineVariable]] = None, framework_version: Optional[str] = None, container_log_level: Optional[int] = None, - predictor_cls: callable = TensorFlowPredictor, + predictor_cls: Optional[Callable] = TensorFlowPredictor, **kwargs, ): """Initialize a Model. @@ -174,7 +174,7 @@ def __init__( container_log_level (int): Log level to use within the container (default: logging.ERROR). Valid values are defined in the Python logging module. - predictor_cls (callable[str, sagemaker.session.Session]): A function + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 4b0f38f36f..fa8f9b8555 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -18,21 +18,20 @@ import inspect import json import logging - from enum import Enum -from typing import Union, Dict, Optional, List, Set +from typing import Dict, List, Optional, Set, Union import sagemaker from sagemaker.amazon.amazon_estimator import ( - RecordSet, AmazonAlgorithmEstimatorBase, FileSystemRecordSet, + RecordSet, ) from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.analytics import HyperparameterTuningJobAnalytics from sagemaker.deprecations import removed_function -from sagemaker.estimator import Framework, EstimatorBase -from sagemaker.inputs import TrainingInput, FileSystemInput +from sagemaker.estimator import EstimatorBase, Framework +from sagemaker.inputs import FileSystemInput, TrainingInput from sagemaker.job import _Job from sagemaker.jumpstart.utils import ( add_jumpstart_uri_tags, @@ -44,18 +43,17 @@ IntegerParameter, ParameterRange, ) -from sagemaker.workflow.entities import PipelineVariable -from sagemaker.workflow.pipeline_context import runnable_by_pipeline - from sagemaker.session import Session from sagemaker.utils import ( + Tags, base_from_name, base_name_from_image, + format_tags, name_from_base, to_string, - format_tags, - Tags, ) +from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow.pipeline_context import runnable_by_pipeline AMAZON_ESTIMATOR_MODULE = "sagemaker" AMAZON_ESTIMATOR_CLS_NAMES = { @@ -133,15 +131,12 @@ def __init__( if warm_start_type not in list(WarmStartTypes): raise ValueError( - "Invalid type: {}, valid warm start types are: {}".format( - warm_start_type, list(WarmStartTypes) - ) + f"Invalid type: {warm_start_type}, " + f"valid warm start types are: {list(WarmStartTypes)}" ) if not parents: - raise ValueError( - "Invalid parents: {}, parents should not be None/empty".format(parents) - ) + raise ValueError(f"Invalid parents: {parents}, parents should not be None/empty") self.type = warm_start_type self.parents = set(parents) @@ -1455,9 +1450,7 @@ def _get_best_training_job(self): return tuning_job_describe_result["BestTrainingJob"] except KeyError: raise Exception( - "Best training job not available for tuning job: {}".format( - self.latest_tuning_job.name - ) + f"Best training job not available for tuning job: {self.latest_tuning_job.name}" ) def _ensure_last_tuning_job(self): @@ -1920,8 +1913,11 @@ def create( :meth:`~sagemaker.tuner.HyperparameterTuner.fit` method launches. If not specified, a default job name is generated, based on the training image name and current timestamp. - strategy (str): Strategy to be used for hyperparameter estimations - (default: 'Bayesian'). + strategy (str or PipelineVariable): Strategy to be used for hyperparameter estimations. + More information about different strategies: + https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html. + Available options are: 'Bayesian', 'Random', 'Hyperband', + 'Grid' (default: 'Bayesian') strategy_config (dict): The configuration for a training job launched by a hyperparameter tuning job. completion_criteria_config (dict): The configuration for tuning job completion criteria. @@ -2080,21 +2076,19 @@ def _validate_dict_argument(cls, name, value, allowed_keys, require_same_keys=Fa return if not isinstance(value, dict): - raise ValueError( - "Argument '{}' must be a dictionary using {} as keys".format(name, allowed_keys) - ) + raise ValueError(f"Argument '{name}' must be a dictionary using {allowed_keys} as keys") value_keys = sorted(value.keys()) if require_same_keys: if value_keys != allowed_keys: raise ValueError( - "The keys of argument '{}' must be the same as {}".format(name, allowed_keys) + f"The keys of argument '{name}' must be the same as {allowed_keys}" ) else: if not set(value_keys).issubset(set(allowed_keys)): raise ValueError( - "The keys of argument '{}' must be a subset of {}".format(name, allowed_keys) + f"The keys of argument '{name}' must be a subset of {allowed_keys}" ) def _add_estimator( diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index e8602de8d7..1a75a3a5cc 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -397,8 +397,7 @@ def download_folder(bucket_name, prefix, target, sagemaker_session): sagemaker_session (sagemaker.session.Session): a sagemaker session to interact with S3. """ - boto_session = sagemaker_session.boto_session - s3 = boto_session.resource("s3", region_name=boto_session.region_name) + s3 = sagemaker_session.s3_resource prefix = prefix.lstrip("/") @@ -726,7 +725,7 @@ def retry_with_backoff(callable_func, num_attempts=8, botocore_client_error_code """Retry with backoff until maximum attempts are reached Args: - callable_func (callable): The callable function to retry. + callable_func (Callable): The callable function to retry. num_attempts (int): The maximum number of attempts to retry.(Default: 8) botocore_client_error_code (str): The specific Botocore ClientError exception error code on which to retry on. diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index a80b5440c7..f49e457bc6 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -645,6 +645,7 @@ def arguments(self) -> RequestType: request_dict = self.step_args else: if isinstance(self.model, PipelineModel): + self.model._init_sagemaker_session_if_does_not_exist() request_dict = self.model.sagemaker_session._create_model_request( name="", role=self.model.role, @@ -653,6 +654,7 @@ def arguments(self) -> RequestType: enable_network_isolation=self.model.enable_network_isolation, ) else: + self.model._init_sagemaker_session_if_does_not_exist() request_dict = self.model.sagemaker_session._create_model_request( name="", role=self.model.role, diff --git a/src/sagemaker/xgboost/estimator.py b/src/sagemaker/xgboost/estimator.py index 2921dbc2db..9385acf745 100644 --- a/src/sagemaker/xgboost/estimator.py +++ b/src/sagemaker/xgboost/estimator.py @@ -78,8 +78,8 @@ def __init__( source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + are preserved when training on Amazon SageMaker. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index ea532b4c39..f4797c79e7 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Optional, Union, List, Dict +from typing import Callable, Optional, Union, List, Dict import sagemaker from sagemaker import image_uris, ModelMetrics @@ -91,7 +91,7 @@ def __init__( framework_version: str = None, image_uri: Optional[Union[str, PipelineVariable]] = None, py_version: str = "py3", - predictor_cls: callable = XGBoostPredictor, + predictor_cls: Optional[Callable] = XGBoostPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs, ): @@ -113,8 +113,8 @@ def __init__( (default: 'py3'). framework_version (str): XGBoost version you want to use for executing your model training code. - predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create - a predictor with an endpoint name and SageMaker ``Session``. + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call + to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. model_server_workers (int or PipelineVariable): Optional. The number of worker processes diff --git a/tests/conftest.py b/tests/conftest.py index db890d1a14..2c8dc2689f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -293,6 +293,8 @@ def huggingface_pytorch_training_version(huggingface_training_version): @pytest.fixture(scope="module") def huggingface_pytorch_training_py_version(huggingface_pytorch_training_version): + if Version(huggingface_pytorch_training_version) >= Version("2.3"): + return "py311" if Version(huggingface_pytorch_training_version) >= Version("2.0"): return "py310" elif Version(huggingface_pytorch_training_version) >= Version("1.13"): @@ -355,6 +357,8 @@ def huggingface_training_compiler_pytorch_py_version( def huggingface_pytorch_latest_training_py_version( huggingface_training_pytorch_latest_version, ): + if Version(huggingface_training_pytorch_latest_version) >= Version("2.3"): + return "py311" if Version(huggingface_training_pytorch_latest_version) >= Version("2.0"): return "py310" elif Version(huggingface_training_pytorch_latest_version) >= Version("1.13"): diff --git a/tests/data/modules/custom_drivers/driver.py b/tests/data/modules/custom_drivers/driver.py new file mode 100644 index 0000000000..3395b80da9 --- /dev/null +++ b/tests/data/modules/custom_drivers/driver.py @@ -0,0 +1,34 @@ +import json +import os +import subprocess +import sys + + +def main(): + driver_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) + process_count_per_node = driver_config["process_count_per_node"] + assert process_count_per_node != None + + hps = json.loads(os.environ["SM_HPS"]) + assert hps != None + assert isinstance(hps, dict) + + source_dir = os.environ["SM_SOURCE_DIR"] + assert source_dir == "/opt/ml/input/data/code" + sm_drivers_dir = os.environ["SM_DISTRIBUTED_DRIVER_DIR"] + assert sm_drivers_dir == "/opt/ml/input/data/sm_drivers/distributed_drivers" + + entry_script = os.environ["SM_ENTRY_SCRIPT"] + assert entry_script != None + + python = sys.executable + + command = [python, entry_script] + print(f"Running command: {command}") + subprocess.run(command, check=True) + + +if __name__ == "__main__": + print("Running custom driver script") + main() + print("Finished running custom driver script") diff --git a/tests/data/modules/params_script/hyperparameters.json b/tests/data/modules/params_script/hyperparameters.json new file mode 100644 index 0000000000..f637288dbe --- /dev/null +++ b/tests/data/modules/params_script/hyperparameters.json @@ -0,0 +1,15 @@ +{ + "integer": 1, + "boolean": true, + "float": 3.14, + "string": "Hello World", + "list": [1, 2, 3], + "dict": { + "string": "value", + "integer": 3, + "float": 3.14, + "list": [1, 2, 3], + "dict": {"key": "value"}, + "boolean": true + } +} \ No newline at end of file diff --git a/tests/data/modules/params_script/hyperparameters.yaml b/tests/data/modules/params_script/hyperparameters.yaml new file mode 100644 index 0000000000..9e3011daf2 --- /dev/null +++ b/tests/data/modules/params_script/hyperparameters.yaml @@ -0,0 +1,19 @@ +integer: 1 +boolean: true +float: 3.14 +string: "Hello World" +list: + - 1 + - 2 + - 3 +dict: + string: value + integer: 3 + float: 3.14 + list: + - 1 + - 2 + - 3 + dict: + key: value + boolean: true \ No newline at end of file diff --git a/tests/data/modules/params_script/requirements.txt b/tests/data/modules/params_script/requirements.txt new file mode 100644 index 0000000000..3d2e72e354 --- /dev/null +++ b/tests/data/modules/params_script/requirements.txt @@ -0,0 +1 @@ +omegaconf diff --git a/tests/data/modules/params_script/train.py b/tests/data/modules/params_script/train.py index 8d3924a325..9b8cb2c82f 100644 --- a/tests/data/modules/params_script/train.py +++ b/tests/data/modules/params_script/train.py @@ -16,6 +16,9 @@ import argparse import json import os +from typing import List, Dict, Any +from dataclasses import dataclass +from omegaconf import OmegaConf EXPECTED_HYPERPARAMETERS = { "integer": 1, @@ -26,6 +29,7 @@ "dict": { "string": "value", "integer": 3, + "float": 3.14, "list": [1, 2, 3], "dict": {"key": "value"}, "boolean": True, @@ -117,7 +121,7 @@ def main(): assert isinstance(params["dict"], dict) params = json.loads(os.environ["SM_TRAINING_ENV"])["hyperparameters"] - print(params) + print(f"SM_TRAINING_ENV -> hyperparameters: {params}") assert params["string"] == EXPECTED_HYPERPARAMETERS["string"] assert params["integer"] == EXPECTED_HYPERPARAMETERS["integer"] assert params["boolean"] == EXPECTED_HYPERPARAMETERS["boolean"] @@ -132,9 +136,96 @@ def main(): assert isinstance(params["float"], float) assert isinstance(params["list"], list) assert isinstance(params["dict"], dict) - print(f"SM_TRAINING_ENV -> hyperparameters: {params}") - print("Test passed.") + # Local JSON - DictConfig OmegaConf + params = OmegaConf.load("hyperparameters.json") + + print(f"Local hyperparameters.json: {params}") + assert params.string == EXPECTED_HYPERPARAMETERS["string"] + assert params.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert params.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert params.float == EXPECTED_HYPERPARAMETERS["float"] + assert params.list == EXPECTED_HYPERPARAMETERS["list"] + assert params.dict == EXPECTED_HYPERPARAMETERS["dict"] + assert params.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"] + assert params.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"] + assert params.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"] + assert params.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"] + assert params.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"] + assert params.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"] + + @dataclass + class DictConfig: + string: str + integer: int + boolean: bool + float: float + list: List[int] + dict: Dict[str, Any] + + @dataclass + class HPConfig: + string: str + integer: int + boolean: bool + float: float + list: List[int] + dict: DictConfig + + # Local JSON - Structured OmegaConf + hp_config: HPConfig = OmegaConf.merge( + OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.json") + ) + print(f"Local hyperparameters.json - Structured: {hp_config}") + assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"] + assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"] + assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"] + assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"] + assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"] + assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"] + assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"] + assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"] + assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"] + assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"] + + # Local YAML - Structured OmegaConf + hp_config: HPConfig = OmegaConf.merge( + OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.yaml") + ) + print(f"Local hyperparameters.yaml - Structured: {hp_config}") + assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"] + assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"] + assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"] + assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"] + assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"] + assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"] + assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"] + assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"] + assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"] + assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"] + print(f"hyperparameters.yaml -> hyperparameters: {hp_config}") + + # HP Dict - Structured OmegaConf + hp_dict = json.loads(os.environ["SM_HPS"]) + hp_config: HPConfig = OmegaConf.merge(OmegaConf.structured(HPConfig), OmegaConf.create(hp_dict)) + print(f"SM_HPS - Structured: {hp_config}") + assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"] + assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"] + assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"] + assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"] + assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"] + assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"] + assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"] + assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"] + assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"] + assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"] + print(f"SM_HPS -> hyperparameters: {hp_config}") if __name__ == "__main__": diff --git a/tests/data/modules/scripts/entry_script.py b/tests/data/modules/scripts/entry_script.py new file mode 100644 index 0000000000..3c972bd956 --- /dev/null +++ b/tests/data/modules/scripts/entry_script.py @@ -0,0 +1,19 @@ +import json +import os +import time + + +def main(): + hps = json.loads(os.environ["SM_HPS"]) + assert hps != None + print(f"Hyperparameters: {hps}") + + print("Running pseudo training script") + for epochs in range(hps["epochs"]): + print(f"Epoch: {epochs}") + time.sleep(1) + print("Finished running pseudo training script") + + +if __name__ == "__main__": + main() diff --git a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py index 751162d2e6..a64db4a97d 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py @@ -82,6 +82,23 @@ def test_jumpstart_hub_model(setup, add_model_references): assert sagemaker_session.endpoint_in_service_or_not(predictor.endpoint_name) +def test_jumpstart_hub_model_with_default_session(setup, add_model_references): + model_version = "*" + hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME] + + model_id = "catboost-classification-model" + + sagemaker_session = get_sm_session() + + model = JumpStartModel(model_id=model_id, model_version=model_version, hub_name=hub_name) + + predictor = model.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + assert sagemaker_session.endpoint_in_service_or_not(predictor.endpoint_name) + + def test_jumpstart_hub_gated_model(setup, add_model_references): model_id = "meta-textgeneration-llama-3-2-1b" @@ -105,9 +122,10 @@ def test_jumpstart_hub_gated_model(setup, add_model_references): assert response is not None +@pytest.mark.skip(reason="blocking PR checks and release pipeline.") def test_jumpstart_gated_model_inference_component_enabled(setup, add_model_references): - model_id = "meta-textgeneration-llama-2-7b" + model_id = "meta-textgeneration-llama-3-2-1b" hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME] diff --git a/tests/integ/sagemaker/modules/train/test_model_trainer.py b/tests/integ/sagemaker/modules/train/test_model_trainer.py index cd298402b2..a1e3106553 100644 --- a/tests/integ/sagemaker/modules/train/test_model_trainer.py +++ b/tests/integ/sagemaker/modules/train/test_model_trainer.py @@ -17,7 +17,7 @@ from sagemaker.modules.train import ModelTrainer from sagemaker.modules.configs import SourceCode, Compute -from sagemaker.modules.distributed import MPI, Torchrun +from sagemaker.modules.distributed import MPI, Torchrun, DistributedConfig EXPECTED_HYPERPARAMETERS = { "integer": 1, @@ -28,26 +28,29 @@ "dict": { "string": "value", "integer": 3, + "float": 3.14, "list": [1, 2, 3], "dict": {"key": "value"}, "boolean": True, }, } +PARAM_SCRIPT_SOURCE_DIR = f"{DATA_DIR}/modules/params_script" +PARAM_SCRIPT_SOURCE_CODE = SourceCode( + source_dir=PARAM_SCRIPT_SOURCE_DIR, + requirements="requirements.txt", + entry_script="train.py", +) + DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310" def test_hp_contract_basic_py_script(modules_sagemaker_session): - source_code = SourceCode( - source_dir=f"{DATA_DIR}/modules/params_script", - entry_script="train.py", - ) - model_trainer = ModelTrainer( sagemaker_session=modules_sagemaker_session, training_image=DEFAULT_CPU_IMAGE, hyperparameters=EXPECTED_HYPERPARAMETERS, - source_code=source_code, + source_code=PARAM_SCRIPT_SOURCE_CODE, base_job_name="hp-contract-basic-py-script", ) @@ -57,6 +60,7 @@ def test_hp_contract_basic_py_script(modules_sagemaker_session): def test_hp_contract_basic_sh_script(modules_sagemaker_session): source_code = SourceCode( source_dir=f"{DATA_DIR}/modules/params_script", + requirements="requirements.txt", entry_script="train.sh", ) model_trainer = ModelTrainer( @@ -71,17 +75,13 @@ def test_hp_contract_basic_sh_script(modules_sagemaker_session): def test_hp_contract_mpi_script(modules_sagemaker_session): - source_code = SourceCode( - source_dir=f"{DATA_DIR}/modules/params_script", - entry_script="train.py", - ) compute = Compute(instance_type="ml.m5.xlarge", instance_count=2) model_trainer = ModelTrainer( sagemaker_session=modules_sagemaker_session, training_image=DEFAULT_CPU_IMAGE, compute=compute, hyperparameters=EXPECTED_HYPERPARAMETERS, - source_code=source_code, + source_code=PARAM_SCRIPT_SOURCE_CODE, distributed=MPI(), base_job_name="hp-contract-mpi-script", ) @@ -90,19 +90,71 @@ def test_hp_contract_mpi_script(modules_sagemaker_session): def test_hp_contract_torchrun_script(modules_sagemaker_session): - source_code = SourceCode( - source_dir=f"{DATA_DIR}/modules/params_script", - entry_script="train.py", - ) compute = Compute(instance_type="ml.m5.xlarge", instance_count=2) model_trainer = ModelTrainer( sagemaker_session=modules_sagemaker_session, training_image=DEFAULT_CPU_IMAGE, compute=compute, hyperparameters=EXPECTED_HYPERPARAMETERS, - source_code=source_code, + source_code=PARAM_SCRIPT_SOURCE_CODE, distributed=Torchrun(), base_job_name="hp-contract-torchrun-script", ) model_trainer.train() + + +def test_hp_contract_hyperparameter_json(modules_sagemaker_session): + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + hyperparameters=f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.json", + source_code=PARAM_SCRIPT_SOURCE_CODE, + base_job_name="hp-contract-hyperparameter-json", + ) + assert model_trainer.hyperparameters == EXPECTED_HYPERPARAMETERS + model_trainer.train() + + +def test_hp_contract_hyperparameter_yaml(modules_sagemaker_session): + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + hyperparameters=f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.yaml", + source_code=PARAM_SCRIPT_SOURCE_CODE, + base_job_name="hp-contract-hyperparameter-yaml", + ) + assert model_trainer.hyperparameters == EXPECTED_HYPERPARAMETERS + model_trainer.train() + + +def test_custom_distributed_driver(modules_sagemaker_session): + class CustomDriver(DistributedConfig): + process_count_per_node: int = None + + @property + def driver_dir(self) -> str: + return f"{DATA_DIR}/modules/custom_drivers" + + @property + def driver_script(self) -> str: + return "driver.py" + + source_code = SourceCode( + source_dir=f"{DATA_DIR}/modules/scripts", + entry_script="entry_script.py", + ) + + hyperparameters = {"epochs": 10} + + custom_driver = CustomDriver(process_count_per_node=2) + + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + hyperparameters=hyperparameters, + source_code=source_code, + distributed=custom_driver, + base_job_name="custom-distributed-driver", + ) + model_trainer.train() diff --git a/tests/integ/sagemaker/remote_function/test_decorator.py b/tests/integ/sagemaker/remote_function/test_decorator.py index 680bfc01df..fa55d7dfa7 100644 --- a/tests/integ/sagemaker/remote_function/test_decorator.py +++ b/tests/integ/sagemaker/remote_function/test_decorator.py @@ -825,6 +825,7 @@ def test_decorator_torchrun( dummy_container_without_error, gpu_instance_type, use_torchrun=False, + use_mpirun=False, ): @remote( role=ROLE, @@ -833,6 +834,7 @@ def test_decorator_torchrun( sagemaker_session=sagemaker_session, keep_alive_period_in_seconds=60, use_torchrun=use_torchrun, + use_mpirun=use_mpirun, ) def divide(x, y): return x / y diff --git a/tests/integ/sagemaker/serve/constants.py b/tests/integ/sagemaker/serve/constants.py index d5e7a56f83..3f25f6a575 100644 --- a/tests/integ/sagemaker/serve/constants.py +++ b/tests/integ/sagemaker/serve/constants.py @@ -25,6 +25,7 @@ PYTHON_VERSION_IS_NOT_38 = platform.python_version_tuple()[1] != "8" PYTHON_VERSION_IS_NOT_310 = platform.python_version_tuple()[1] != "10" +PYTHON_VERSION_IS_NOT_312 = platform.python_version_tuple()[1] != "12" XGB_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "xgboost") PYTORCH_SQUEEZENET_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "pytorch") diff --git a/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py b/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py index 10f338c4b5..80f9c50e4b 100644 --- a/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py +++ b/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py @@ -12,38 +12,72 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -import pytest - -from sagemaker import get_execution_role -from sklearn.datasets import load_iris -from sklearn.model_selection import train_test_split - import os +import uuid +from typing import Generator +import numpy as np +import pandas as pd +import pytest +from sagemaker_core.main.resources import TrainingJob from sagemaker_core.main.shapes import ( AlgorithmSpecification, Channel, DataSource, - S3DataSource, OutputDataConfig, ResourceConfig, + S3DataSource, StoppingCondition, ) -import uuid -from sagemaker.serve.builder.model_builder import ModelBuilder -import pandas as pd -import numpy as np -from sagemaker.serve import InferenceSpec, SchemaBuilder -from sagemaker_core.main.resources import TrainingJob +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split from xgboost import XGBClassifier -from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig - -from sagemaker.s3_utils import s3_path_join +from sagemaker import get_execution_role from sagemaker.async_inference import AsyncInferenceConfig +from sagemaker.s3_utils import s3_path_join +from sagemaker.serve import InferenceSpec, SchemaBuilder +from sagemaker.serve.builder.model_builder import ModelBuilder +from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig from tests.integ.utils import cleanup_model_resources +@pytest.fixture(autouse=True) +def cleanup_endpoints(mb_sagemaker_session) -> Generator[None, None, None]: + """Clean up any existing endpoints before and after tests.""" + sagemaker_client = mb_sagemaker_session.sagemaker_client + + # Pre-test cleanup + try: + endpoints = sagemaker_client.list_endpoints() + for endpoint in endpoints["Endpoints"]: + try: + sagemaker_client.delete_endpoint(EndpointName=endpoint["EndpointName"]) + sagemaker_client.delete_endpoint_config( + EndpointConfigName=endpoint["EndpointConfigName"] + ) + except Exception as e: + print(f"Error cleaning up endpoint {endpoint['EndpointName']}: {e}") + except Exception as e: + print(f"Error listing endpoints: {e}") + + yield + + # Post-test cleanup + try: + endpoints = sagemaker_client.list_endpoints() + for endpoint in endpoints["Endpoints"]: + try: + sagemaker_client.delete_endpoint(EndpointName=endpoint["EndpointName"]) + sagemaker_client.delete_endpoint_config( + EndpointConfigName=endpoint["EndpointConfigName"] + ) + except Exception as e: + print(f"Error cleaning up endpoint {endpoint['EndpointName']}: {e}") + except Exception as e: + print(f"Error listing endpoints: {e}") + + @pytest.fixture(scope="module") def xgboost_model_builder(mb_sagemaker_session): sagemaker_session = mb_sagemaker_session diff --git a/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py b/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py index 348c57745f..e13e672bec 100644 --- a/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py +++ b/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py @@ -32,6 +32,8 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e iam_client = sagemaker_session.boto_session.client("iam") role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"] + sagemaker_session.sagemaker_client.create_optimization_job = MagicMock() + schema_builder = SchemaBuilder("test", "test") model_builder = ModelBuilder( model="meta-textgeneration-llama-3-1-8b-instruct", @@ -50,6 +52,8 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e accept_eula=True, ) + assert not sagemaker_session.sagemaker_client.create_optimization_job.called + optimized_model.deploy() mock_create_model.assert_called_once_with( @@ -126,6 +130,13 @@ def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_ accept_eula=True, ) + assert ( + sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][ + "OptimizationConfigs" + ][0]["ModelShardingConfig"]["Image"] + is not None + ) + optimized_model.deploy( resources=ResourceRequirements(requests={"memory": 196608, "num_accelerators": 8}) ) @@ -206,6 +217,13 @@ def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are accept_eula=True, ) + assert ( + sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][ + "OptimizationConfigs" + ][0]["ModelQuantizationConfig"]["Image"] + is not None + ) + optimized_model.deploy() mock_create_model.assert_called_once_with( diff --git a/tests/integ/sagemaker/serve/test_serve_model_builder_inference_component_happy.py b/tests/integ/sagemaker/serve/test_serve_model_builder_inference_component_happy.py new file mode 100644 index 0000000000..84649bb565 --- /dev/null +++ b/tests/integ/sagemaker/serve/test_serve_model_builder_inference_component_happy.py @@ -0,0 +1,147 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import tests.integ + +from botocore.exceptions import ClientError +from sagemaker.predictor import Predictor +from sagemaker.serve.builder.model_builder import ModelBuilder +from sagemaker.serve.builder.schema_builder import SchemaBuilder +from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements +from sagemaker.utils import unique_name_from_base + +from tests.integ.sagemaker.serve.constants import ( + SERVE_SAGEMAKER_ENDPOINT_TIMEOUT, +) +from tests.integ.timeout import timeout +import logging + +logger = logging.getLogger(__name__) + +sample_input = {"inputs": "What are falcons?", "parameters": {"max_new_tokens": 32}} + +sample_output = [ + { + "generated_text": "Falcons are small to medium-sized birds of prey related to hawks and eagles." + } +] + +LLAMA_2_7B_JS_ID = "meta-textgeneration-llama-2-7b" +LLAMA_IC_NAME = "llama2-mb-ic" + + +@pytest.fixture +def model_builder_llama_inference_component(): + return ModelBuilder( + model=LLAMA_2_7B_JS_ID, + schema_builder=SchemaBuilder(sample_input, sample_output), + resource_requirements=ResourceRequirements( + requests={"memory": 98304, "num_accelerators": 4, "copies": 1, "num_cpus": 40} + ), + ) + + +@pytest.mark.skipif( + tests.integ.test_region() not in "us-west-2", + reason="G5 capacity available in PDX.", +) +def test_model_builder_ic_sagemaker_endpoint( + sagemaker_session, + model_builder_llama_inference_component, +): + logger.info("Running in SAGEMAKER_ENDPOINT mode...") + caught_ex = None + + model_builder_llama_inference_component.sagemaker_session = sagemaker_session + + model_builder_llama_inference_component.inference_component_name = unique_name_from_base( + LLAMA_IC_NAME + ) + + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] + + chain = ModelBuilder( + modelbuilder_list=[ + model_builder_llama_inference_component, + ], + role_arn=role_arn, + sagemaker_session=sagemaker_session, + ) + + chain.build() + + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + try: + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") + endpoint_name = "llama-ic-endpoint-name" + predictors = chain.deploy( + instance_type="ml.g5.24xlarge", + initial_instance_count=1, + accept_eula=True, + endpoint_name=endpoint_name, + ) + logger.info("Inference components successfully deployed.") + predictors[0].predict(sample_input) + assert len(predictors) == 1 + except Exception as e: + caught_ex = e + finally: + if caught_ex: + logger.exception(caught_ex) + cleanup_resources(sagemaker_session, [LLAMA_IC_NAME]) + assert False, f"{caught_ex} thrown when running mb-IC deployment test." + + cleanup_resources(sagemaker_session, [LLAMA_IC_NAME]) + + +def cleanup_resources(sagemaker_session, ic_base_names): + sm_client = sagemaker_session.sagemaker_client + + endpoint_names = set() + for ic_base_name in ic_base_names: + response = sm_client.list_inference_components( + NameContains=ic_base_name, StatusEquals="InService" + ) + ics = response["InferenceComponents"] + + logger.info(f"Cleaning up {len(ics)} ICs with base name {ic_base_name}.") + for ic in ics: + ic_name = ic["InferenceComponentName"] + ep_name = ic["EndpointName"] + + try: + logger.info(f"Deleting IC with name {ic_name}") + Predictor( + endpoint_name=ep_name, + component_name=ic_name, + sagemaker_session=sagemaker_session, + ).delete_predictor() + sagemaker_session.wait_for_inference_component_deletion( + inference_component_name=ic_name, + poll=10, + ) + endpoint_names.add(ep_name) + except ClientError as e: + logger.warning(e) + + for endpoint_name in endpoint_names: + logger.info(f"Deleting endpoint with name {endpoint_name}") + try: + Predictor( + endpoint_name=endpoint_name, sagemaker_session=sagemaker_session + ).delete_endpoint() + except ClientError as e: + logger.warning(e) diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index 7f85c0066c..8f98cd076d 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -26,7 +26,6 @@ import pytest from packaging.version import Version -from packaging.specifiers import SpecifierSet from sagemaker.model_card.model_card import ModelCard, ModelOverview, ModelPackageModelCard from sagemaker.model_card.schema_constraints import ModelCardStatusEnum @@ -1422,7 +1421,7 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model( pipeline_name, region_name, ): - if Version(tf_full_version) in SpecifierSet("==2.16.*"): + if Version(tf_full_version) >= Version("2.16"): pytest.skip( "This test is failing in TensorFlow 2.16 beacuse of an upstream bug: " "https://github.com/tensorflow/io/issues/2039" diff --git a/tests/integ/sagemaker/workflow/test_model_steps.py b/tests/integ/sagemaker/workflow/test_model_steps.py index 089cdaf08f..02f7613f85 100644 --- a/tests/integ/sagemaker/workflow/test_model_steps.py +++ b/tests/integ/sagemaker/workflow/test_model_steps.py @@ -18,7 +18,6 @@ import pytest from packaging.version import Version -from packaging.specifiers import SpecifierSet from tests.integ.sagemaker.workflow.helpers import wait_pipeline_execution from sagemaker.workflow.fail_step import FailStep @@ -592,7 +591,7 @@ def test_model_registration_with_drift_check_baselines_and_model_metrics( def test_model_registration_with_tensorflow_model_with_pipeline_model( pipeline_session, role, tf_full_version, tf_full_py_version, pipeline_name ): - if Version(tf_full_version) in SpecifierSet("==2.16.*"): + if Version(tf_full_version) >= Version("2.16"): pytest.skip( "This test is failing in TensorFlow 2.16 beacuse of an upstream bug: " "https://github.com/tensorflow/io/issues/2039" diff --git a/tests/integ/sagemaker/workflow/test_training_steps.py b/tests/integ/sagemaker/workflow/test_training_steps.py index bcff221afe..4b442c6d93 100644 --- a/tests/integ/sagemaker/workflow/test_training_steps.py +++ b/tests/integ/sagemaker/workflow/test_training_steps.py @@ -19,7 +19,6 @@ import pytest from packaging.version import Version -from packaging.specifiers import SpecifierSet from tests.integ.sagemaker.workflow.helpers import wait_pipeline_execution from sagemaker import TrainingInput, get_execution_role, utils, image_uris @@ -238,7 +237,7 @@ def test_training_step_with_output_path_as_join( def test_tensorflow_training_step_with_parameterized_code_input( pipeline_session, role, tf_full_version, tf_full_py_version, pipeline_name ): - if Version(tf_full_version) in SpecifierSet("==2.16.*"): + if Version(tf_full_version) >= Version("2.16"): pytest.skip( "This test is failing in TensorFlow 2.16 beacuse of an upstream bug: " "https://github.com/tensorflow/io/issues/2039" diff --git a/tests/integ/test_collection.py b/tests/integ/test_collection.py index 2ee1d90e34..9a6db645cf 100644 --- a/tests/integ/test_collection.py +++ b/tests/integ/test_collection.py @@ -19,20 +19,22 @@ def test_create_collection_root_success(sagemaker_session): collection = Collection(sagemaker_session) collection_name = unique_name_from_base("test-collection") - collection.create(collection_name) - collection_filter = [ - { - "Name": "resource-type", - "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], - }, - ] - collection_details = sagemaker_session.list_group_resources( - group=collection_name, filters=collection_filter - ) - assert collection_details["ResponseMetadata"]["HTTPStatusCode"] == 200 - delete_response = collection.delete([collection_name]) - assert len(delete_response["deleted_collections"]) == 1 - assert len(delete_response["delete_collection_failures"]) == 0 + try: + collection.create(collection_name) + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + assert collection_details["ResponseMetadata"]["HTTPStatusCode"] == 200 + finally: + delete_response = collection.delete([collection_name]) + assert len(delete_response["deleted_collections"]) == 1 + assert len(delete_response["delete_collection_failures"]) == 0 def test_create_collection_nested_success(sagemaker_session): @@ -41,25 +43,27 @@ def test_create_collection_nested_success(sagemaker_session): child_collection_name = unique_name_from_base("test-collection-2") collection.create(collection_name) collection.create(collection_name=child_collection_name, parent_collection_name=collection_name) - collection_filter = [ - { - "Name": "resource-type", - "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], - }, - ] - collection_details = sagemaker_session.list_group_resources( - group=collection_name, filters=collection_filter - ) - # has one child i.e child collection - assert len(collection_details["Resources"]) == 1 - - collection_details = sagemaker_session.list_group_resources( - group=child_collection_name, filters=collection_filter - ) - collection_details["ResponseMetadata"]["HTTPStatusCode"] - delete_response = collection.delete([child_collection_name, collection_name]) - assert len(delete_response["deleted_collections"]) == 2 - assert len(delete_response["delete_collection_failures"]) == 0 + try: + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + # has one child i.e child collection + assert len(collection_details["Resources"]) == 1 + + collection_details = sagemaker_session.list_group_resources( + group=child_collection_name, filters=collection_filter + ) + collection_details["ResponseMetadata"]["HTTPStatusCode"] + finally: + delete_response = collection.delete([child_collection_name, collection_name]) + assert len(delete_response["deleted_collections"]) == 2 + assert len(delete_response["delete_collection_failures"]) == 0 def test_add_remove_model_groups_in_collection_success(sagemaker_session): @@ -70,40 +74,42 @@ def test_add_remove_model_groups_in_collection_success(sagemaker_session): collection = Collection(sagemaker_session) collection_name = unique_name_from_base("test-collection") collection.create(collection_name) - model_groups = [] - model_groups.append(model_group_name) - add_response = collection.add_model_groups( - collection_name=collection_name, model_groups=model_groups - ) - collection_filter = [ - { - "Name": "resource-type", - "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], - }, - ] - collection_details = sagemaker_session.list_group_resources( - group=collection_name, filters=collection_filter - ) - - assert len(add_response["failure"]) == 0 - assert len(add_response["added_groups"]) == 1 - assert len(collection_details["Resources"]) == 1 - - remove_response = collection.remove_model_groups( - collection_name=collection_name, model_groups=model_groups - ) - collection_details = sagemaker_session.list_group_resources( - group=collection_name, filters=collection_filter - ) - assert len(remove_response["failure"]) == 0 - assert len(remove_response["removed_groups"]) == 1 - assert len(collection_details["Resources"]) == 0 - - delete_response = collection.delete([collection_name]) - assert len(delete_response["deleted_collections"]) == 1 - sagemaker_session.sagemaker_client.delete_model_package_group( - ModelPackageGroupName=model_group_name - ) + try: + model_groups = [] + model_groups.append(model_group_name) + add_response = collection.add_model_groups( + collection_name=collection_name, model_groups=model_groups + ) + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + + assert len(add_response["failure"]) == 0 + assert len(add_response["added_groups"]) == 1 + assert len(collection_details["Resources"]) == 1 + + remove_response = collection.remove_model_groups( + collection_name=collection_name, model_groups=model_groups + ) + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + assert len(remove_response["failure"]) == 0 + assert len(remove_response["removed_groups"]) == 1 + assert len(collection_details["Resources"]) == 0 + + finally: + delete_response = collection.delete([collection_name]) + assert len(delete_response["deleted_collections"]) == 1 + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) def test_move_model_groups_in_collection_success(sagemaker_session): @@ -116,56 +122,58 @@ def test_move_model_groups_in_collection_success(sagemaker_session): destination_collection_name = unique_name_from_base("test-collection-destination") collection.create(source_collection_name) collection.create(destination_collection_name) - model_groups = [] - model_groups.append(model_group_name) - add_response = collection.add_model_groups( - collection_name=source_collection_name, model_groups=model_groups - ) - collection_filter = [ - { - "Name": "resource-type", - "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], - }, - ] - collection_details = sagemaker_session.list_group_resources( - group=source_collection_name, filters=collection_filter - ) - - assert len(add_response["failure"]) == 0 - assert len(add_response["added_groups"]) == 1 - assert len(collection_details["Resources"]) == 1 - - move_response = collection.move_model_group( - source_collection_name=source_collection_name, - model_group=model_group_name, - destination_collection_name=destination_collection_name, - ) - - assert move_response["moved_success"] == model_group_name - - collection_details = sagemaker_session.list_group_resources( - group=destination_collection_name, filters=collection_filter - ) - - assert len(collection_details["Resources"]) == 1 - - collection_details = sagemaker_session.list_group_resources( - group=source_collection_name, filters=collection_filter - ) - assert len(collection_details["Resources"]) == 0 - - remove_response = collection.remove_model_groups( - collection_name=destination_collection_name, model_groups=model_groups - ) - - assert len(remove_response["failure"]) == 0 - assert len(remove_response["removed_groups"]) == 1 - - delete_response = collection.delete([source_collection_name, destination_collection_name]) - assert len(delete_response["deleted_collections"]) == 2 - sagemaker_session.sagemaker_client.delete_model_package_group( - ModelPackageGroupName=model_group_name - ) + try: + model_groups = [] + model_groups.append(model_group_name) + add_response = collection.add_model_groups( + collection_name=source_collection_name, model_groups=model_groups + ) + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=source_collection_name, filters=collection_filter + ) + + assert len(add_response["failure"]) == 0 + assert len(add_response["added_groups"]) == 1 + assert len(collection_details["Resources"]) == 1 + + move_response = collection.move_model_group( + source_collection_name=source_collection_name, + model_group=model_group_name, + destination_collection_name=destination_collection_name, + ) + + assert move_response["moved_success"] == model_group_name + + collection_details = sagemaker_session.list_group_resources( + group=destination_collection_name, filters=collection_filter + ) + + assert len(collection_details["Resources"]) == 1 + + collection_details = sagemaker_session.list_group_resources( + group=source_collection_name, filters=collection_filter + ) + assert len(collection_details["Resources"]) == 0 + + remove_response = collection.remove_model_groups( + collection_name=destination_collection_name, model_groups=model_groups + ) + + assert len(remove_response["failure"]) == 0 + assert len(remove_response["removed_groups"]) == 1 + + finally: + delete_response = collection.delete([source_collection_name, destination_collection_name]) + assert len(delete_response["deleted_collections"]) == 2 + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) def test_list_collection_success(sagemaker_session): @@ -176,23 +184,27 @@ def test_list_collection_success(sagemaker_session): collection = Collection(sagemaker_session) collection_name = unique_name_from_base("test-collection") collection.create(collection_name) - model_groups = [] - model_groups.append(model_group_name) - collection.add_model_groups(collection_name=collection_name, model_groups=model_groups) - child_collection_name = unique_name_from_base("test-collection") - collection.create(parent_collection_name=collection_name, collection_name=child_collection_name) - root_collections = collection.list_collection() - is_collection_found = False - for root_collection in root_collections: - if root_collection["Name"] == collection_name: - is_collection_found = True - assert is_collection_found - - collection_content = collection.list_collection(collection_name) - assert len(collection_content) == 2 - - collection.remove_model_groups(collection_name=collection_name, model_groups=model_groups) - collection.delete([child_collection_name, collection_name]) - sagemaker_session.sagemaker_client.delete_model_package_group( - ModelPackageGroupName=model_group_name - ) + try: + model_groups = [] + model_groups.append(model_group_name) + collection.add_model_groups(collection_name=collection_name, model_groups=model_groups) + child_collection_name = unique_name_from_base("test-collection") + collection.create( + parent_collection_name=collection_name, collection_name=child_collection_name + ) + root_collections = collection.list_collection() + is_collection_found = False + for root_collection in root_collections: + if root_collection["Name"] == collection_name: + is_collection_found = True + assert is_collection_found + + collection_content = collection.list_collection(collection_name) + assert len(collection_content) == 2 + + collection.remove_model_groups(collection_name=collection_name, model_groups=model_groups) + finally: + collection.delete([child_collection_name, collection_name]) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) diff --git a/tests/integ/test_horovod.py b/tests/integ/test_horovod.py index 2ddcdc92e0..78314c2ade 100644 --- a/tests/integ/test_horovod.py +++ b/tests/integ/test_horovod.py @@ -62,11 +62,8 @@ def test_hvd_gpu( tmpdir, **kwargs, ): - if ( - Version(tensorflow_training_latest_version) >= Version("2.12") - and kwargs["instance_type"] == "ml.p2.xlarge" - ): - pytest.skip("P2 instances have been deprecated for sagemaker jobs starting TensorFlow 2.12") + if kwargs["instance_type"] == "ml.p2.xlarge": + pytest.skip("Instance type ml.p2.xlarge has been deprecated") if Version(tensorflow_training_latest_version) >= Version("2.13"): pytest.skip("Horovod is deprecated in TensorFlow 2.13 and above") diff --git a/tests/integ/test_horovod_mx.py b/tests/integ/test_horovod_mx.py index 7bd6a641e0..a238966dd3 100644 --- a/tests/integ/test_horovod_mx.py +++ b/tests/integ/test_horovod_mx.py @@ -58,6 +58,9 @@ def test_hvd_gpu( tmpdir, **kwargs, ): + if kwargs["instance_type"] == "ml.p2.xlarge": + pytest.skip("Instance type ml.p2.xlarge has been deprecated") + _create_and_fit_estimator( mxnet_training_latest_version, mxnet_training_latest_py_version, diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index 8c99854d14..0d03aee8ea 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -19,7 +19,6 @@ import pytest from packaging.version import Version -from packaging.specifiers import SpecifierSet from sagemaker import KMeans, s3, get_execution_role from sagemaker.mxnet import MXNet @@ -556,7 +555,7 @@ def test_transform_mxnet_logs( def test_transform_tf_kms_network_isolation( sagemaker_session, cpu_instance_type, tmpdir, tf_full_version, tf_full_py_version ): - if Version(tf_full_version) in SpecifierSet("==2.16.*"): + if Version(tf_full_version) >= Version("2.16"): pytest.skip( "This test is failing in TensorFlow 2.16 beacuse of an upstream bug: " "https://github.com/tensorflow/io/issues/2039" diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py index 4c93e18939..5d32030580 100644 --- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py @@ -75,12 +75,12 @@ def test_constructor_node_should_be_modified(src, expected): ("sagemaker.predictor._NumpyDeserializer()", "deserializers.NumpyDeserializer()"), ("sagemaker.predictor._JsonDeserializer()", "deserializers.JSONDeserializer()"), ( - "sagemaker.amazon.common.numpy_to_record_serializer()", - "sagemaker.amazon.common.RecordSerializer()", + "sagemaker.serializers.numpy_to_record_serializer()", + "sagemaker.serializers.RecordSerializer()", ), ( - "sagemaker.amazon.common.record_deserializer()", - "sagemaker.amazon.common.RecordDeserializer()", + "sagemaker.deserializers.record_deserializer()", + "sagemaker.deserializers.RecordDeserializer()", ), ("_CsvSerializer()", "serializers.CSVSerializer()"), ("_JsonSerializer()", "serializers.JSONSerializer()"), @@ -265,20 +265,12 @@ def test_import_from_amazon_common_node_should_be_modified(import_statement, exp "import_statement, expected", [ ( - "from sagemaker.amazon.common import numpy_to_record_serializer", - "from sagemaker.amazon.common import RecordSerializer", + "from sagemaker.serializers import numpy_to_record_serializer", + "from sagemaker.serializers import RecordSerializer", ), ( - "from sagemaker.amazon.common import record_deserializer", - "from sagemaker.amazon.common import RecordDeserializer", - ), - ( - "from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer", - "from sagemaker.amazon.common import RecordSerializer, RecordDeserializer", - ), - ( - "from sagemaker.amazon.common import write_spmatrix_to_sparse_tensor, numpy_to_record_serializer", - "from sagemaker.amazon.common import write_spmatrix_to_sparse_tensor, RecordSerializer", + "from sagemaker.deserializers import record_deserializer", + "from sagemaker.deserializers import RecordDeserializer", ), ], ) diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py b/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py index 57f4a54f78..00bd3ca090 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py @@ -908,6 +908,7 @@ def test_remote_decorator_fields_consistency(get_execution_role, session): "max_wait_time_in_seconds", "custom_file_filter", "use_torchrun", + "use_mpirun", "nproc_per_node", } diff --git a/tests/unit/sagemaker/image_uris/expected_uris.py b/tests/unit/sagemaker/image_uris/expected_uris.py index 01e4d4991f..eb198454fc 100644 --- a/tests/unit/sagemaker/image_uris/expected_uris.py +++ b/tests/unit/sagemaker/image_uris/expected_uris.py @@ -107,3 +107,12 @@ def base_python_uri(repo, account, region=REGION): domain = ALTERNATE_DOMAINS.get(region, DOMAIN) tag = "1.0" return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag) + + +def sagemaker_distribution_uri(repo, account, tag, processor, region=REGION): + domain = ALTERNATE_DOMAINS.get(region, DOMAIN) + if processor == "cpu": + tag = f"{tag}-cpu" + else: + tag = f"{tag}-gpu" + return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag) diff --git a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py index 28525a390c..0d96417e9f 100644 --- a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py +++ b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py @@ -46,6 +46,8 @@ "2.0.2": "2.3.0-tgi2.0.2-gpu-py310-cu121-ubuntu22.04", "2.2.0": "2.3.0-tgi2.2.0-gpu-py310-cu121-ubuntu22.04-v2.0", "2.3.1": "2.4.0-tgi2.3.1-gpu-py311-cu124-ubuntu22.04", + "2.4.0": "2.4.0-tgi2.4.0-gpu-py311-cu124-ubuntu22.04-v2.2", + "3.0.1": "2.4.0-tgi3.0.1-gpu-py311-cu124-ubuntu22.04-v2.1", }, "inf2": { "0.0.16": "1.13.1-optimum0.0.16-neuronx-py310-ubuntu22.04", @@ -58,6 +60,7 @@ "0.0.23": "2.1.2-optimum0.0.23-neuronx-py310-ubuntu22.04", "0.0.24": "2.1.2-optimum0.0.24-neuronx-py310-ubuntu22.04", "0.0.25": "2.1.2-optimum0.0.25-neuronx-py310-ubuntu22.04", + "0.0.27": "2.1.2-optimum0.0.27-neuronx-py310-ubuntu22.04", }, } diff --git a/tests/unit/sagemaker/image_uris/test_sagemaker_distribution.py b/tests/unit/sagemaker/image_uris/test_sagemaker_distribution.py new file mode 100644 index 0000000000..d339a50b2e --- /dev/null +++ b/tests/unit/sagemaker/image_uris/test_sagemaker_distribution.py @@ -0,0 +1,47 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +from sagemaker import image_uris +from tests.unit.sagemaker.image_uris import expected_uris + +INSTANCE_TYPES = {"cpu": "ml.c4.xlarge", "gpu": "ml.p2.xlarge"} + + +def _test_ecr_uri(account, region, version, tag, instance_type, processor): + actual_uri = image_uris.retrieve( + "sagemaker-distribution", region=region, instance_type=instance_type, version=version + ) + expected_uri = expected_uris.sagemaker_distribution_uri( + "sagemaker-distribution-prod", account, tag, processor, region + ) + return expected_uri == actual_uri + + +@pytest.mark.parametrize("load_config", ["sagemaker-distribution.json"], indirect=True) +def test_sagemaker_distribution_ecr_uri(load_config): + VERSIONS = load_config["versions"] + processors = load_config["processors"] + for version in VERSIONS: + SAGEMAKER_DISTRIBUTION_ACCOUNTS = load_config["versions"][version]["registries"] + for region in SAGEMAKER_DISTRIBUTION_ACCOUNTS.keys(): + for processor in processors: + assert _test_ecr_uri( + account=SAGEMAKER_DISTRIBUTION_ACCOUNTS[region], + region=region, + version=version, + tag="3.0.0", + instance_type=INSTANCE_TYPES[processor], + processor=processor, + ) diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 59f38bd189..4021599120 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -3059,7 +3059,7 @@ "g4": { "regional_properties": {"image_uri": "$gpu_image_uri"}, "properties": { - "artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" + "training_artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" }, }, "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, @@ -3135,7 +3135,7 @@ }, "p9": { "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": {"artifact_key": "do/re/mi"}, + "properties": {"training_artifact_key": "do/re/mi"}, }, "m2": { "regional_properties": {"image_uri": "$cpu_image_uri"}, @@ -3214,13 +3214,13 @@ "ml.p9.12xlarge": { "properties": { "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, - "artifact_key": "you/not/entertained", + "training_artifact_key": "you/not/entertained", } }, "g6": { "properties": { "environment_variables": {"BLAH": "4"}, - "artifact_key": "path/to/training/artifact.tar.gz", + "training_artifact_key": "path/to/training/artifact.tar.gz", "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/", } }, @@ -5046,7 +5046,7 @@ "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, "m5": { "regional_properties": {"image_uri": "$cpu_ecr_uri_1"}, - "properties": {"artifact_key": "hello-world-1"}, + "properties": {"training_artifact_key": "hello-world-1"}, }, "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, @@ -17234,13 +17234,13 @@ "g4dn": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/g4dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/g4dn/v1.0.0/", # noqa: E501 }, }, "g5": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/g5/v1.0.0/", # noqa: E501 }, }, "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, @@ -17249,13 +17249,13 @@ "p3dn": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/p3dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/p3dn/v1.0.0/", # noqa: E501 }, }, "p4d": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/p4d/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/p4d/v1.0.0/", # noqa: E501 }, }, "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 1fd2a47aca..4a64b413f4 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -392,23 +392,6 @@ def test_gated_model_s3_uri( mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session - with pytest.raises(ValueError) as e: - JumpStartEstimator( - model_id=model_id, - environment={ - "accept_eula": "false", - "what am i": "doing", - "SageMakerGatedModelS3Uri": "none of your business", - }, - ) - assert str(e.value) == ( - "Need to define ‘accept_eula'='true' within Environment. " - "Model 'meta-textgeneration-llama-2-7b-f' requires accepting end-user " - "license agreement (EULA). See " - "https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com/fmhMetadata/eula/llamaEula.txt" - " for terms of use." - ) - mock_estimator_init.reset_mock() estimator = JumpStartEstimator(model_id=model_id, environment={"accept_eula": "true"}) @@ -510,6 +493,151 @@ def test_gated_model_s3_uri( ], ) + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_gated_model_s3_uri_with_eula_in_fit( + self, + mock_estimator_deploy: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_estimator_init: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session_estimator: mock.Mock, + mock_session_model: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, + mock_timestamp: mock.Mock, + ): + mock_estimator_deploy.return_value = default_predictor + + mock_timestamp.return_value = "8675309" + + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS + + model_id, _ = "js-gated-artifact-trainable-model", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + mock_session_estimator.return_value = sagemaker_session + mock_session_model.return_value = sagemaker_session + + mock_estimator_init.reset_mock() + + estimator = JumpStartEstimator(model_id=model_id) + + mock_estimator_init.assert_called_once_with( + instance_type="ml.g5.12xlarge", + instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-" + "pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "meta/transfer_learning/textgeneration/v1.0.6/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={ + "int8_quantization": "False", + "enable_fsdp": "True", + "epoch": "1", + "learning_rate": "0.0001", + "lora_r": "8", + "lora_alpha": "32", + "lora_dropout": "0.05", + "instruction_tuned": "False", + "chat_dataset": "True", + "add_input_output_demarcation_key": "True", + "per_device_train_batch_size": "1", + "per_device_eval_batch_size": "1", + "max_train_samples": "-1", + "max_val_samples": "-1", + "seed": "10", + "max_input_length": "-1", + "validation_split_ratio": "0.2", + "train_data_split_seed": "0", + "preprocessing_num_workers": "None", + }, + metric_definitions=[ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "eval_epoch_loss=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:eval-ppl", + "Regex": "eval_ppl=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "train_epoch_loss=([0-9\\.]+)", + }, + ], + role=execution_role, + sagemaker_session=sagemaker_session, + max_run=360000, + enable_network_isolation=True, + encrypt_inter_container_traffic=True, + environment={ + "SageMakerGatedModelS3Uri": "s3://sagemaker-repository-pdx/" + "model-data-model-package_llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + }, + tags=[ + { + "Key": "sagemaker-sdk:jumpstart-model-id", + "Value": "js-gated-artifact-trainable-model", + }, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "2.0.4"}, + ], + ) + + channels = { + "training": f"s3://{get_jumpstart_content_bucket(region)}/" + f"some-training-dataset-doesn't-matter", + } + + estimator.fit(channels, accept_eula=True) + + mock_estimator_fit.assert_called_once_with( + inputs=channels, + wait=True, + job_name="meta-textgeneration-llama-2-7b-f-8675309", + ) + + assert hasattr(estimator, "model_access_config") + assert hasattr(estimator, "hub_access_config") + + assert estimator.model_access_config == {"AcceptEula": True} + + estimator.deploy() + + mock_estimator_deploy.assert_called_once_with( + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + predictor_cls=Predictor, + endpoint_name="meta-textgeneration-llama-2-7b-f-8675309", + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118", + wait=True, + model_data_download_timeout=3600, + container_startup_health_check_timeout=3600, + role=execution_role, + enable_network_isolation=True, + model_name="meta-textgeneration-llama-2-7b-f-8675309", + use_compiled_model=False, + tags=[ + { + "Key": "sagemaker-sdk:jumpstart-model-id", + "Value": "js-gated-artifact-trainable-model", + }, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "2.0.4"}, + ], + ) + @mock.patch( "sagemaker.jumpstart.artifacts.environment_variables.get_jumpstart_gated_content_bucket" ) @@ -1218,7 +1346,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): and reach out to JumpStart team.""" init_args_to_skip: Set[str] = set(["kwargs"]) - fit_args_to_skip: Set[str] = set() + fit_args_to_skip: Set[str] = set(["accept_eula"]) deploy_args_to_skip: Set[str] = set(["kwargs"]) parent_class_init = Estimator.__init__ @@ -1243,8 +1371,8 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): js_class_fit = JumpStartEstimator.fit js_class_fit_args = set(signature(js_class_fit).parameters.keys()) - assert js_class_fit_args - parent_class_fit_args == set() - assert parent_class_fit_args - js_class_fit_args == fit_args_to_skip + assert js_class_fit_args - parent_class_fit_args == fit_args_to_skip + assert parent_class_fit_args - js_class_fit_args == set() model_class_init = Model.__init__ model_class_init_args = set(signature(model_class_init).parameters.keys()) diff --git a/tests/unit/sagemaker/jumpstart/hub/test_hub.py b/tests/unit/sagemaker/jumpstart/hub/test_hub.py index 8522b33bc3..06f5473322 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_hub.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_hub.py @@ -192,6 +192,39 @@ def test_describe_model_success(mock_describe_hub_content_response, sagemaker_se ) +@patch("sagemaker.jumpstart.hub.interfaces.DescribeHubContentResponse.from_json") +def test_describe_model_one_thrown_error(mock_describe_hub_content_response, sagemaker_session): + mock_describe_hub_content_response.return_value = Mock() + mock_list_hub_content_versions = sagemaker_session.list_hub_content_versions + mock_list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0"}, + {"HubContentVersion": "2.0"}, + {"HubContentVersion": "3.0"}, + ] + } + mock_describe_hub_content = sagemaker_session.describe_hub_content + mock_describe_hub_content.side_effect = [ + Exception("Some exception"), + {"HubContentName": "test-model", "HubContentVersion": "3.0"}, + ] + + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + + with patch("sagemaker.jumpstart.hub.utils.get_hub_model_version") as mock_get_hub_model_version: + mock_get_hub_model_version.return_value = "3.0" + + hub.describe_model("test-model") + + mock_describe_hub_content.asssert_called_times(2) + mock_describe_hub_content.assert_called_with( + hub_name=HUB_NAME, + hub_content_name="test-model", + hub_content_version="3.0", + hub_content_type="Model", + ) + + def test_create_hub_content_reference(sagemaker_session): hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) model_name = "mock-model-one-huggingface" diff --git a/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py index 11798bc854..ebd90d98d2 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py @@ -923,15 +923,13 @@ def test_hub_content_document_from_json_obj(): "g4dn": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/g4dn/v1.0.0/train-" - "huggingface-llm-gemma-2b-instruct.tar.gz", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/g4dn/v1.0.0/", # noqa: E501 }, }, "g5": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/g5/v1.0.0/train-" - "huggingface-llm-gemma-2b-instruct.tar.gz", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/g5/v1.0.0/", # noqa: E501 }, }, "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, @@ -940,15 +938,13 @@ def test_hub_content_document_from_json_obj(): "p3dn": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/p3dn/v1.0.0/train-" - "huggingface-llm-gemma-2b-instruct.tar.gz", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/p3dn/v1.0.0/", # noqa: E501 }, }, "p4d": { "properties": { "image_uri": "$gpu_ecr_uri_1", - "gated_model_key_env_var_value": "huggingface-training/p4d/v1.0.0/train-" - "huggingface-llm-gemma-2b-instruct.tar.gz", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/p4d/v1.0.0/", # noqa: E501 }, }, "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py index 6dbb1340f4..a0b824fc9b 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -14,7 +14,10 @@ from unittest.mock import patch, Mock from sagemaker.jumpstart.types import HubArnExtractedInfo -from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.constants import ( + JUMPSTART_DEFAULT_REGION_NAME, + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) from sagemaker.jumpstart.hub import parser_utils, utils @@ -80,6 +83,17 @@ def test_construct_hub_arn_from_name(): ) +def test_construct_hub_arn_from_name_with_session_none(): + hub_name = "my-cool-hub" + account_id = DEFAULT_JUMPSTART_SAGEMAKER_SESSION.account_id() + boto_region_name = DEFAULT_JUMPSTART_SAGEMAKER_SESSION.boto_region_name + + assert ( + utils.construct_hub_arn_from_name(hub_name=hub_name, session=None) + == f"arn:aws:sagemaker:{boto_region_name}:{account_id}:hub/{hub_name}" + ) + + def test_construct_hub_model_arn_from_inputs(): model_name, version = "pytorch-ic-imagenet-v2", "1.0.2" hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py index e687a1c4ac..75aa93a920 100644 --- a/tests/unit/sagemaker/jumpstart/test_artifacts.py +++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py @@ -176,7 +176,7 @@ def test_retrieve_training_artifact_key(self): "image_uri": "$alias_ecr_uri_1", }, "properties": { - "artifact_key": "in/the/way", + "training_artifact_key": "in/the/way", }, } }, diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index da20debc6a..b7edc124d3 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -22,10 +22,14 @@ from mock.mock import MagicMock import pytest from mock import patch +from packaging.version import Version + +from sagemaker.jumpstart import utils from sagemaker.jumpstart.cache import ( JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JumpStartModelsCache, ) from sagemaker.jumpstart.constants import ( @@ -33,6 +37,7 @@ ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, ) from sagemaker.jumpstart.types import ( + JumpStartCachedContentValue, JumpStartModelHeader, JumpStartModelSpecs, JumpStartVersionedModelId, @@ -53,6 +58,25 @@ from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region") +@patch( + "sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket" +) +@patch("boto3.client") +def test_jumpstart_cache_init(mock_boto3_client): + cache = JumpStartModelsCache() + assert cache._region == "dummy-region" + assert cache.s3_bucket_name == "dummy-bucket" + assert cache._manifest_file_s3_key == JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY + assert cache._proprietary_manifest_s3_key == JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY + assert cache._sagemaker_session == DEFAULT_JUMPSTART_SAGEMAKER_SESSION + mock_boto3_client.assert_called_once_with("s3", region_name="dummy-region") + + # Some callers override the session to None, should still be set to default + cache = JumpStartModelsCache(sagemaker_session=None) + assert cache._sagemaker_session == DEFAULT_JUMPSTART_SAGEMAKER_SESSION + + @patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) @patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") def test_jumpstart_cache_get_header(): @@ -1119,3 +1143,124 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( ), ] ) + + +@patch.object(JumpStartModelsCache, "_retrieval_function") +def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights( + retrieval_function: Mock, +): + sm_version = Version(utils.get_sagemaker_version()) + new_sm_version = Version(str(sm_version.major + 1) + ".0.0") + print(str(new_sm_version)) + versions = ["1.0.0", "2.9.1", "2.16.0"] + manifest = [ + { + "model_id": "test-model", + "version": version, + "min_version": "2.49.0", + "spec_key": "spec_key", + } + for version in versions + ] + + manifest.append( + { + "model_id": "test-model", + "version": "3.0.0", + "min_version": str(new_sm_version), + "spec_key": "spec_key", + } + ) + + manifest_dict = {} + for header in manifest: + header_obj = JumpStartModelHeader(header) + manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = ( + header_obj + ) + retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict) + key = JumpStartVersionedModelId("test-model", "*") + + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None) + + assert_key = JumpStartVersionedModelId("test-model", "2.16.0") + + assert result == assert_key + + +@patch.object(JumpStartModelsCache, "_retrieval_function") +def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights( + retrieval_function: Mock, +): + sm_version = Version(utils.get_sagemaker_version()) + new_sm_version = Version(str(sm_version.major + 1) + ".0.0") + print(str(new_sm_version)) + versions = ["1.0.0", "2.9.1", "2.16.0"] + manifest = [ + { + "model_id": "test-model", + "version": version, + "min_version": "2.49.0", + "spec_key": "spec_key", + } + for version in versions + ] + + manifest.append( + { + "model_id": "test-model", + "version": "3.0.0", + "min_version": str(new_sm_version), + "spec_key": "spec_key", + } + ) + + manifest_dict = {} + for header in manifest: + header_obj = JumpStartModelHeader(header) + manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = ( + header_obj + ) + retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict) + key = JumpStartVersionedModelId("test-model", "*") + + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + result = cache._get_proprietary_manifest_key_from_model_id(key=key, value=None) + + assert_key = JumpStartVersionedModelId("test-model", "2.16.0") + + assert result == assert_key + + +@patch.object(JumpStartModelsCache, "_retrieval_function") +def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(retrieval_function: Mock): + sm_version = Version(utils.get_sagemaker_version()) + new_sm_version = Version(str(sm_version.major + 1) + ".0.0") + print(str(new_sm_version)) + versions = ["abc", "2.9.1", "2.16.0"] + manifest = [ + { + "model_id": "test-model", + "version": version, + "min_version": "2.49.0", + "spec_key": "spec_key", + } + for version in versions + ] + + manifest_dict = {} + for header in manifest: + header_obj = JumpStartModelHeader(header) + manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = ( + header_obj + ) + retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict) + key = JumpStartVersionedModelId("test-model", "*") + + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None) + + assert_key = JumpStartVersionedModelId("test-model", "abc") + + assert result == assert_key diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 3efa8c8c81..acce8ef4f1 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -117,7 +117,7 @@ "g4": { "regional_properties": {"image_uri": "$gpu_image_uri"}, "properties": { - "artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" + "training_artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" }, }, "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, @@ -193,7 +193,7 @@ }, "p9": { "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": {"artifact_key": "do/re/mi"}, + "properties": {"training_artifact_key": "do/re/mi"}, }, "m2": { "regional_properties": {"image_uri": "$cpu_image_uri"}, @@ -272,13 +272,13 @@ "ml.p9.12xlarge": { "properties": { "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, - "artifact_key": "you/not/entertained", + "training_artifact_key": "you/not/entertained", } }, "g6": { "properties": { "environment_variables": {"BLAH": "4"}, - "artifact_key": "path/to/training/artifact.tar.gz", + "training_artifact_key": "path/to/training/artifact.tar.gz", "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/", } }, @@ -952,27 +952,35 @@ def test_jumpstart_hosting_prepacked_artifact_key_instance_variants(): def test_jumpstart_training_artifact_key_instance_variants(): assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key(instance_type="ml.g6.xlarge") + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( + instance_type="ml.g6.xlarge" + ) == "path/to/training/artifact.tar.gz" ) assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key(instance_type="ml.g4.9xlarge") + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( + instance_type="ml.g4.9xlarge" + ) == "path/to/prepacked/training/artifact/prefix/number2/" ) assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key(instance_type="ml.p9.9xlarge") + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( + instance_type="ml.p9.9xlarge" + ) == "do/re/mi" ) assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key(instance_type="ml.p9.12xlarge") + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( + instance_type="ml.p9.12xlarge" + ) == "you/not/entertained" ) assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key( + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( instance_type="ml.g9dsfsdfs.12xlarge" ) is None diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 7cf8fdc9b6..ea4d64f289 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -2144,6 +2144,22 @@ def test_has_instance_rate_stat(stats, expected): assert utils.has_instance_rate_stat(stats) is expected +def test_get_latest_version(): + assert utils.get_latest_version(["2.9.1", "2.16.0", "1.0.0"]) == "2.16.0" + + +def test_get_latest_version_empty_list_is_none(): + assert utils.get_latest_version([]) is None + + +def test_get_latest_version_none_is_none(): + assert utils.get_latest_version(None) is None + + +def test_get_latest_version_with_invalid_sem_ver(): + assert utils.get_latest_version(["2.9.1", "2.16.0", "1.0.0", "abc"]) == "abc" + + @pytest.mark.parametrize( "data, expected", [(None, []), ([], []), (get_base_deployment_configs_metadata(), get_base_deployment_configs())], diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 6bfb28f684..7b99281b96 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -130,6 +130,7 @@ def test_deploy(name_from_base, prepare_container_def, production_variant, sagem model_data_download_timeout=None, container_startup_health_check_timeout=None, routing_config=None, + inference_ami_version=None, ) sagemaker_session.create_model.assert_called_with( @@ -192,6 +193,7 @@ def test_deploy_accelerator_type( model_data_download_timeout=None, container_startup_health_check_timeout=None, routing_config=None, + inference_ami_version=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( @@ -519,6 +521,7 @@ def test_deploy_serverless_inference(production_variant, create_sagemaker_model, model_data_download_timeout=None, container_startup_health_check_timeout=None, routing_config=None, + inference_ami_version=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( @@ -956,6 +959,7 @@ def test_deploy_customized_volume_size_and_timeout( model_data_download_timeout=model_data_download_timeout_sec, container_startup_health_check_timeout=startup_health_check_timeout_sec, routing_config=None, + inference_ami_version=None, ) sagemaker_session.create_model.assert_called_with( @@ -1006,6 +1010,7 @@ def test_deploy_with_resources(sagemaker_session, name_from_base, production_var model_data_download_timeout=None, container_startup_health_check_timeout=None, routing_config=None, + inference_ami_version=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( name=name_from_base(MODEL_NAME), diff --git a/tests/unit/sagemaker/model/test_framework_model.py b/tests/unit/sagemaker/model/test_framework_model.py index d41dd6f821..432d90bd37 100644 --- a/tests/unit/sagemaker/model/test_framework_model.py +++ b/tests/unit/sagemaker/model/test_framework_model.py @@ -511,6 +511,20 @@ def test_is_repack_with_code_location(repack_model, sagemaker_session): assert not model.is_repack() +@patch("sagemaker.utils.repack_model") +def test_is_repack_with_none_type(repack_model, sagemaker_session): + """Test is_repack() returns a boolean value when source_dir and entry_point are None""" + + model = FrameworkModel( + role=ROLE, + sagemaker_session=sagemaker_session, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + + assert model.is_repack() is False + + @patch("sagemaker.git_utils.git_clone_repo") @patch("sagemaker.model.fw_utils.tar_and_upload_dir") def test_is_repack_with_git_config(tar_and_upload_dir, git_clone_repo, sagemaker_session): diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index 9175613662..3d498dfc59 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -1046,6 +1046,20 @@ def test_is_repack_with_code_location(repack_model, sagemaker_session): assert model.is_repack() +@patch("sagemaker.utils.repack_model") +def test_is_repack_with_none_type(repack_model, sagemaker_session): + """Test is_repack() returns a boolean value when source_dir and entry_point are None""" + + model = Model( + role=ROLE, + sagemaker_session=sagemaker_session, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + + assert model.is_repack() is False + + @patch("sagemaker.git_utils.git_clone_repo") @patch("sagemaker.model.fw_utils.tar_and_upload_dir") def test_is_repack_with_git_config(tar_and_upload_dir, git_clone_repo, sagemaker_session): diff --git a/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py b/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py index 30d6dfdf6c..fe4fa08825 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py @@ -21,12 +21,10 @@ from sagemaker.modules.train.container_drivers.scripts.environment import ( set_env, - log_key_value, log_env_variables, - mask_sensitive_info, HIDDEN_VALUE, ) -from sagemaker.modules.train.container_drivers.utils import safe_serialize, safe_deserialize +from sagemaker.modules.train.container_drivers.common.utils import safe_serialize, safe_deserialize RESOURCE_CONFIG = dict( current_host="algo-1", @@ -75,6 +73,15 @@ }, } +SOURCE_CODE = { + "source_dir": "code", + "entry_script": "train.py", +} + +DISTRIBUTED_CONFIG = { + "process_count_per_node": 2, +} + OUTPUT_FILE = os.path.join(os.path.dirname(__file__), "sm_training.env") # flake8: noqa @@ -89,6 +96,10 @@ export SM_LOG_LEVEL='20' export SM_MASTER_ADDR='algo-1' export SM_MASTER_PORT='7777' +export SM_SOURCE_DIR='/opt/ml/input/data/code' +export SM_ENTRY_SCRIPT='train.py' +export SM_DISTRIBUTED_DRIVER_DIR='/opt/ml/input/data/sm_drivers/distributed_drivers' +export SM_DISTRIBUTED_CONFIG='{"process_count_per_node": 2}' export SM_CHANNEL_TRAIN='/opt/ml/input/data/train' export SM_CHANNEL_VALIDATION='/opt/ml/input/data/validation' export SM_CHANNELS='["train", "validation"]' @@ -112,6 +123,14 @@ """ +@patch( + "sagemaker.modules.train.container_drivers.scripts.environment.read_source_code_json", + return_value=SOURCE_CODE, +) +@patch( + "sagemaker.modules.train.container_drivers.scripts.environment.read_distributed_json", + return_value=DISTRIBUTED_CONFIG, +) @patch("sagemaker.modules.train.container_drivers.scripts.environment.num_cpus", return_value=8) @patch("sagemaker.modules.train.container_drivers.scripts.environment.num_gpus", return_value=0) @patch("sagemaker.modules.train.container_drivers.scripts.environment.num_neurons", return_value=0) @@ -124,7 +143,13 @@ side_effect=safe_deserialize, ) def test_set_env( - mock_safe_deserialize, mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons + mock_safe_deserialize, + mock_safe_serialize, + mock_num_neurons, + mock_num_gpus, + mock_num_cpus, + mock_read_distributed_json, + mock_read_source_code_json, ): with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): set_env( @@ -137,6 +162,8 @@ def test_set_env( mock_num_cpus.assert_called_once() mock_num_gpus.assert_called_once() mock_num_neurons.assert_called_once() + mock_read_distributed_json.assert_called_once() + mock_read_source_code_json.assert_called_once() with open(OUTPUT_FILE, "r") as f: env_file = f.read().strip() diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py index a1a84da1ab..bf51db8285 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py @@ -15,13 +15,14 @@ import os import sys +import json from unittest.mock import patch, MagicMock sys.modules["utils"] = MagicMock() sys.modules["mpi_utils"] = MagicMock() -from sagemaker.modules.train.container_drivers import mpi_driver # noqa: E402 +from sagemaker.modules.train.container_drivers.distributed_drivers import mpi_driver # noqa: E402 DUMMY_MPI_COMMAND = [ @@ -40,12 +41,7 @@ "script.py", ] -DUMMY_SOURCE_CODE = { - "source_code": "source_code", - "entry_script": "script.py", -} DUMMY_DISTRIBUTED = { - "_type": "mpi", "process_count_per_node": 2, "mpi_additional_options": [ "--verbose", @@ -62,17 +58,28 @@ "SM_HOSTS": '["algo-1", "algo-2"]', "SM_MASTER_ADDR": "algo-1", "SM_HOST_COUNT": "2", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "/opt/ml/input/data/code/script.py", }, ) -@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_distributed_json") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_source_code_json") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.write_env_vars_to_file") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.start_sshd_daemon") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_master_node") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_worker_node") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.hyperparameters_to_cli_args") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.get_mpirun_command") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.execute_commands") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file" +) +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.start_sshd_daemon") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command" +) +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.execute_commands") def test_mpi_driver_worker( mock_execute_commands, mock_get_mpirun_command, @@ -81,12 +88,8 @@ def test_mpi_driver_worker( mock_bootstrap_master_node, mock_start_sshd_daemon, mock_write_env_vars_to_file, - mock_read_source_code_json, - mock_read_distributed_json, ): mock_hyperparameters_to_cli_args.return_value = [] - mock_read_source_code_json.return_value = DUMMY_SOURCE_CODE - mock_read_distributed_json.return_value = DUMMY_DISTRIBUTED mpi_driver.main() @@ -106,19 +109,32 @@ def test_mpi_driver_worker( "SM_HOSTS": '["algo-1", "algo-2"]', "SM_MASTER_ADDR": "algo-1", "SM_HOST_COUNT": "2", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "script.py", }, ) -@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_distributed_json") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_source_code_json") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.write_env_vars_to_file") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.start_sshd_daemon") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_master_node") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_worker_node") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.get_process_count") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.hyperparameters_to_cli_args") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.get_mpirun_command") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.execute_commands") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.write_status_file_to_workers") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file" +) +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.start_sshd_daemon") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node" +) +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_process_count") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command" +) +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.execute_commands") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_status_file_to_workers" +) def test_mpi_driver_master( mock_write_status_file_to_workers, mock_execute_commands, @@ -129,12 +145,8 @@ def test_mpi_driver_master( mock_bootstrap_master_node, mock_start_sshd_daemon, mock_write_env_vars_to_file, - mock_read_source_code_config_json, - mock_read_distributed_json, ): mock_hyperparameters_to_cli_args.return_value = [] - mock_read_source_code_config_json.return_value = DUMMY_SOURCE_CODE - mock_read_distributed_json.return_value = DUMMY_DISTRIBUTED mock_get_mpirun_command.return_value = DUMMY_MPI_COMMAND mock_get_process_count.return_value = 2 mock_execute_commands.return_value = (0, "") diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py new file mode 100644 index 0000000000..35208d708a --- /dev/null +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py @@ -0,0 +1,113 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""MPI Utils Unit Tests.""" +from __future__ import absolute_import + +import subprocess +from unittest.mock import Mock, patch + +import paramiko +import pytest + +# Mock the utils module before importing mpi_utils +mock_utils = Mock() +mock_utils.logger = Mock() +mock_utils.SM_EFA_NCCL_INSTANCES = [] +mock_utils.SM_EFA_RDMA_INSTANCES = [] +mock_utils.get_python_executable = Mock(return_value="/usr/bin/python") + +with patch.dict("sys.modules", {"utils": mock_utils}): + from sagemaker.modules.train.container_drivers.distributed_drivers.mpi_utils import ( + CustomHostKeyPolicy, + _can_connect, + write_status_file_to_workers, + ) + +TEST_HOST = "algo-1" +TEST_WORKER = "algo-2" +TEST_STATUS_FILE = "/tmp/test-status" + + +def test_custom_host_key_policy_valid_hostname(): + """Test CustomHostKeyPolicy accepts algo- prefixed hostnames.""" + policy = CustomHostKeyPolicy() + mock_client = Mock() + mock_key = Mock() + mock_key.get_name.return_value = "ssh-rsa" + + policy.missing_host_key(mock_client, "algo-1", mock_key) + + mock_client.get_host_keys.assert_called_once() + mock_client.get_host_keys().add.assert_called_once_with("algo-1", "ssh-rsa", mock_key) + + +def test_custom_host_key_policy_invalid_hostname(): + """Test CustomHostKeyPolicy rejects non-algo prefixed hostnames.""" + policy = CustomHostKeyPolicy() + mock_client = Mock() + mock_key = Mock() + + with pytest.raises(paramiko.SSHException) as exc_info: + policy.missing_host_key(mock_client, "invalid-1", mock_key) + + assert "Unknown host key for invalid-1" in str(exc_info.value) + mock_client.get_host_keys.assert_not_called() + + +@patch("paramiko.SSHClient") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_utils.logger") +def test_can_connect_success(mock_logger, mock_ssh_client): + """Test successful SSH connection.""" + mock_client = Mock() + mock_ssh_client.return_value.__enter__.return_value = mock_client + mock_client.connect.return_value = None # Successful connection + + result = _can_connect(TEST_HOST) + + assert result is True + mock_client.load_system_host_keys.assert_called_once() + mock_client.set_missing_host_key_policy.assert_called_once() + mock_client.connect.assert_called_once_with(TEST_HOST, port=22) + + +@patch("paramiko.SSHClient") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_utils.logger") +def test_can_connect_failure(mock_logger, mock_ssh_client): + """Test SSH connection failure.""" + mock_client = Mock() + mock_ssh_client.return_value.__enter__.return_value = mock_client + mock_client.connect.side_effect = paramiko.SSHException("Connection failed") + + result = _can_connect(TEST_HOST) + + assert result is False + mock_client.load_system_host_keys.assert_called_once() + mock_client.set_missing_host_key_policy.assert_called_once() + mock_client.connect.assert_called_once_with(TEST_HOST, port=22) + + +@patch("subprocess.run") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_utils.logger") +def test_write_status_file_to_workers_failure(mock_logger, mock_run): + """Test failed status file writing to workers with retry timeout.""" + mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") + + with pytest.raises(TimeoutError) as exc_info: + write_status_file_to_workers([TEST_WORKER], TEST_STATUS_FILE) + + assert f"Timed out waiting for {TEST_WORKER}" in str(exc_info.value) + assert mock_run.call_count > 1 # Verifies that retries occurred + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py b/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py index 4cff07a0c0..2568346158 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py @@ -15,38 +15,38 @@ import os import sys +import json from unittest.mock import patch, MagicMock sys.modules["utils"] = MagicMock() -from sagemaker.modules.train.container_drivers import torchrun_driver # noqa: E402 - -DUMMY_SOURCE_CODE = { - "source_code": "source_code", - "entry_script": "script.py", -} +from sagemaker.modules.train.container_drivers.distributed_drivers import ( # noqa: E402 + torchrun_driver, +) -DUMMY_distributed = {"_type": "torchrun", "process_count_per_node": 2} +DUMMY_DISTRIBUTED = {"process_count_per_node": 2} @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_python_executable", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_python_executable", return_value="python3", ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0) + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.pytorch_version", + return_value=(2, 0), ) def test_get_base_pytorch_command_torchrun(mock_pytorch_version, mock_get_python_executable): assert torchrun_driver.get_base_pytorch_command() == ["torchrun"] @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_python_executable", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_python_executable", return_value="python3", ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(1, 8) + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.pytorch_version", + return_value=(1, 8), ) def test_get_base_pytorch_command_torch_distributed_launch( mock_pytorch_version, mock_get_python_executable @@ -62,38 +62,29 @@ def test_get_base_pytorch_command_torch_distributed_launch( "SM_CURRENT_INSTANCE_TYPE": "ml.p4d.24xlarge", "SM_NETWORK_INTERFACE_NAME": "eth0", "SM_HOST_COUNT": "1", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "script.py", }, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.USER_CODE_PATH", - "/opt/ml/input/data/code", -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_process_count", return_value=2 + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_process_count", + return_value=2, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0) + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.pytorch_version", + return_value=(2, 0), ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_base_pytorch_command", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_base_pytorch_command", return_value=["torchrun"], ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_source_code_json", - return_value=DUMMY_SOURCE_CODE, -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_distributed_json", - return_value=DUMMY_distributed, -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.hyperparameters_to_cli_args", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.hyperparameters_to_cli_args", return_value=[], ) def test_create_commands_single_node( mock_hyperparameters_to_cli_args, - mock_read_distributed_json, - mock_read_source_code_json, mock_get_base_pytorch_command, mock_pytorch_version, mock_get_process_count, @@ -102,7 +93,7 @@ def test_create_commands_single_node( "torchrun", "--nnodes=1", "--nproc_per_node=2", - "/opt/ml/input/data/code/script.py", + "script.py", ] command = torchrun_driver.create_commands() @@ -118,38 +109,29 @@ def test_create_commands_single_node( "SM_MASTER_ADDR": "algo-1", "SM_MASTER_PORT": "7777", "SM_CURRENT_HOST_RANK": "0", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "script.py", }, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.USER_CODE_PATH", - "/opt/ml/input/data/code", -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_process_count", return_value=2 + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_process_count", + return_value=2, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0) + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.pytorch_version", + return_value=(2, 0), ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_base_pytorch_command", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_base_pytorch_command", return_value=["torchrun"], ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_source_code_json", - return_value=DUMMY_SOURCE_CODE, -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_distributed_json", - return_value=DUMMY_distributed, -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.hyperparameters_to_cli_args", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.hyperparameters_to_cli_args", return_value=[], ) def test_create_commands_multi_node( mock_hyperparameters_to_cli_args, - mock_read_distributed_json, - mock_read_source_code_json, mock_get_base_pytorch_command, mock_pytorch_version, mock_get_process_count, @@ -161,7 +143,7 @@ def test_create_commands_multi_node( "--master_addr=algo-1", "--master_port=7777", "--node_rank=0", - "/opt/ml/input/data/code/script.py", + "script.py", ] command = torchrun_driver.create_commands() diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py b/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py index aba97996b0..beff06e8d8 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py @@ -12,11 +12,13 @@ # language governing permissions and limitations under the License. """Container Utils Unit Tests.""" from __future__ import absolute_import +import os -from sagemaker.modules.train.container_drivers.utils import ( +from sagemaker.modules.train.container_drivers.common.utils import ( safe_deserialize, safe_serialize, hyperparameters_to_cli_args, + get_process_count, ) SM_HPS = { @@ -119,3 +121,18 @@ def test_safe_serialize_empty_data(): assert safe_serialize("") == "" assert safe_serialize([]) == "[]" assert safe_serialize({}) == "{}" + + +def test_get_process_count(): + assert get_process_count() == 1 + assert get_process_count(2) == 2 + os.environ["SM_NUM_GPUS"] = "4" + assert get_process_count() == 4 + os.environ["SM_NUM_GPUS"] = "0" + os.environ["SM_NUM_NEURONS"] = "8" + assert get_process_count() == 8 + os.environ["SM_NUM_NEURONS"] = "0" + assert get_process_count() == 1 + del os.environ["SM_NUM_GPUS"] + del os.environ["SM_NUM_NEURONS"] + assert get_process_count() == 1 diff --git a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py index 66eafab4f0..f5f7ceb083 100644 --- a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py +++ b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py @@ -26,6 +26,7 @@ _load_recipes_cfg, _configure_gpu_args, _configure_trainium_args, + _get_trainining_recipe_gpu_model_name_and_script, ) from sagemaker.modules.utils import _run_clone_command_silent from sagemaker.modules.configs import Compute @@ -178,3 +179,37 @@ def test_get_args_from_recipe_compute( assert mock_gpu_args.call_count == 0 assert mock_trainium_args.call_count == 0 assert args is None + + @pytest.mark.parametrize( + "test_case", + [ + { + "model_type": "llama_v3", + "script": "llama_pretrain.py", + "model_base_name": "llama_v3", + }, + { + "model_type": "mistral", + "script": "mistral_pretrain.py", + "model_base_name": "mistral", + }, + { + "model_type": "deepseek_llamav3", + "script": "deepseek_pretrain.py", + "model_base_name": "deepseek", + }, + { + "model_type": "deepseek_qwenv2", + "script": "deepseek_pretrain.py", + "model_base_name": "deepseek", + }, + ], + ) + def test_get_trainining_recipe_gpu_model_name_and_script(test_case): + model_type = test_case["model_type"] + script = test_case["script"] + model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script( + model_type, script + ) + assert model_base_name == test_case["model_base_name"] + assert script == test_case["script"] diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 093da20ab8..770420c354 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -17,8 +17,10 @@ import tempfile import json import os +import yaml import pytest -from unittest.mock import patch, MagicMock, ANY +from pydantic import ValidationError +from unittest.mock import patch, MagicMock, ANY, mock_open from sagemaker import image_uris from sagemaker_core.main.resources import TrainingJob @@ -65,7 +67,7 @@ ) from sagemaker.modules.distributed import Torchrun, SMP, MPI from sagemaker.modules.train.sm_recipes.utils import _load_recipes_cfg -from sagemaker.modules.templates import EXEUCTE_TORCHRUN_DRIVER, EXECUTE_MPI_DRIVER +from sagemaker.modules.templates import EXEUCTE_DISTRIBUTED_DRIVER from tests.unit import DATA_DIR DEFAULT_BASE_NAME = "dummy-image-job" @@ -410,7 +412,9 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_ { "source_code": DEFAULT_SOURCE_CODE, "distributed": Torchrun(), - "expected_template": EXEUCTE_TORCHRUN_DRIVER, + "expected_template": EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name="Torchrun", driver_script="torchrun_driver.py" + ), "expected_hyperparameters": {}, }, { @@ -423,7 +427,9 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_ tensor_parallel_degree=5, ) ), - "expected_template": EXEUCTE_TORCHRUN_DRIVER, + "expected_template": EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name="Torchrun", driver_script="torchrun_driver.py" + ), "expected_hyperparameters": { "mp_parameters": json.dumps( { @@ -438,9 +444,11 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_ { "source_code": DEFAULT_SOURCE_CODE, "distributed": MPI( - custom_mpi_options=["-x", "VAR1", "-x", "VAR2"], + mpi_additional_options=["-x", "VAR1", "-x", "VAR2"], + ), + "expected_template": EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name="MPI", driver_script="mpi_driver.py" ), - "expected_template": EXECUTE_MPI_DRIVER, "expected_hyperparameters": {}, }, ], @@ -497,21 +505,15 @@ def test_train_with_distributed_config( assert os.path.exists(expected_runner_json_path) with open(expected_runner_json_path, "r") as f: runner_json_content = f.read() - assert test_case["distributed"].model_dump(exclude_none=True) == ( - json.loads(runner_json_content) - ) + assert test_case["distributed"].model_dump() == (json.loads(runner_json_content)) assert os.path.exists(expected_source_code_json_path) with open(expected_source_code_json_path, "r") as f: source_code_json_content = f.read() - assert test_case["source_code"].model_dump(exclude_none=True) == ( - json.loads(source_code_json_content) - ) + assert test_case["source_code"].model_dump() == (json.loads(source_code_json_content)) assert os.path.exists(expected_source_code_json_path) with open(expected_source_code_json_path, "r") as f: source_code_json_content = f.read() - assert test_case["source_code"].model_dump(exclude_none=True) == ( - json.loads(source_code_json_content) - ) + assert test_case["source_code"].model_dump() == (json.loads(source_code_json_content)) finally: shutil.rmtree(tmp_dir.name) assert not os.path.exists(tmp_dir.name) @@ -1059,3 +1061,126 @@ def mock_upload_data(path, bucket, key_prefix): hyper_parameters=hyperparameters, environment=environment, ) + + +def test_safe_configs(): + # Test extra fails + with pytest.raises(ValueError): + SourceCode(entry_point="train.py") + # Test invalid type fails + with pytest.raises(ValueError): + SourceCode(entry_script=1) + + +@patch("sagemaker.modules.train.model_trainer.TemporaryDirectory") +def test_destructor_cleanup(mock_tmp_dir, modules_session): + + with pytest.raises(ValidationError): + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute="test", + ) + mock_tmp_dir.cleanup.assert_not_called() + + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + ) + model_trainer._temp_recipe_train_dir = mock_tmp_dir + mock_tmp_dir.assert_not_called() + del model_trainer + mock_tmp_dir.cleanup.assert_called_once() + + +@patch("os.path.exists") +def test_hyperparameters_valid_json(mock_exists, modules_session): + mock_exists.return_value = True + expected_hyperparameters = {"param1": "value1", "param2": 2} + mock_file_open = mock_open(read_data=json.dumps(expected_hyperparameters)) + + with patch("builtins.open", mock_file_open): + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.json", + ) + assert model_trainer.hyperparameters == expected_hyperparameters + mock_file_open.assert_called_once_with("hyperparameters.json", "r") + mock_exists.assert_called_once_with("hyperparameters.json") + + +@patch("os.path.exists") +def test_hyperparameters_valid_yaml(mock_exists, modules_session): + mock_exists.return_value = True + expected_hyperparameters = {"param1": "value1", "param2": 2} + mock_file_open = mock_open(read_data=yaml.dump(expected_hyperparameters)) + + with patch("builtins.open", mock_file_open): + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.yaml", + ) + assert model_trainer.hyperparameters == expected_hyperparameters + mock_file_open.assert_called_once_with("hyperparameters.yaml", "r") + mock_exists.assert_called_once_with("hyperparameters.yaml") + + +def test_hyperparameters_not_exist(modules_session): + with pytest.raises(ValueError): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="nonexistent.json", + ) + + +@patch("os.path.exists") +def test_hyperparameters_invalid(mock_exists, modules_session): + mock_exists.return_value = True + + # YAML contents must be a valid mapping + mock_file_open = mock_open(read_data="- item1\n- item2") + with patch("builtins.open", mock_file_open): + with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.yaml", + ) + + # YAML contents must be a valid mapping + mock_file_open = mock_open(read_data="invalid") + with patch("builtins.open", mock_file_open): + with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.yaml", + ) + + # Must be valid YAML + mock_file_open = mock_open(read_data="* invalid") + with patch("builtins.open", mock_file_open): + with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.yaml", + ) diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py b/tests/unit/sagemaker/monitor/test_model_monitoring.py index d31b9f8527..b338885491 100644 --- a/tests/unit/sagemaker/monitor/test_model_monitoring.py +++ b/tests/unit/sagemaker/monitor/test_model_monitoring.py @@ -73,6 +73,7 @@ LINFINITY_METHOD = "LInfinity" CRON_DAILY = CronExpressionGenerator.daily() +CRON_NOW = CronExpressionGenerator.now() BASELINING_JOB_NAME = "baselining-job" BASELINE_DATASET_PATH = "/my/local/path/baseline.csv" PREPROCESSOR_PATH = "/my/local/path/preprocessor.py" @@ -1136,6 +1137,36 @@ def _test_data_quality_monitor_update_schedule(data_quality_monitor, sagemaker_s sagemaker_session.sagemaker_client.delete_data_quality_job_definition.assert_not_called() sagemaker_session.sagemaker_client.create_data_quality_job_definition.assert_not_called() + # update schedule + sagemaker_session.describe_monitoring_schedule = MagicMock() + sagemaker_session.sagemaker_client.describe_data_quality_job_definition = MagicMock() + sagemaker_session.sagemaker_client.create_data_quality_job_definition = MagicMock() + + # Test updating monitoring schedule with schedule_cron_expression set to NOW + sagemaker_session.sagemaker_client.update_monitoring_schedule = Mock() + data_quality_monitor.update_monitoring_schedule( + data_analysis_start_time="-PT24H", + data_analysis_end_time="-PT0H", + schedule_cron_expression=CRON_NOW, + ) + + sagemaker_session.sagemaker_client.update_monitoring_schedule.assert_called_once_with( + MonitoringScheduleName=data_quality_monitor.monitoring_schedule_name, + MonitoringScheduleConfig={ + "MonitoringJobDefinitionName": data_quality_monitor.job_definition_name, + "MonitoringType": DefaultModelMonitor.monitoring_type(), + "ScheduleConfig": { + "ScheduleExpression": CRON_NOW, + "DataAnalysisStartTime": "-PT24H", + "DataAnalysisEndTime": "-PT0H", + }, + }, + ) + + # A new data quality job definition should be created + sagemaker_session.sagemaker_client.describe_data_quality_job_definition.assert_called_once() + sagemaker_session.sagemaker_client.create_data_quality_job_definition.assert_called_once() + # update one property of job definition time.sleep( 0.001 diff --git a/tests/unit/sagemaker/remote_function/runtime_environment/test_mpi_utils.py b/tests/unit/sagemaker/remote_function/runtime_environment/test_mpi_utils.py new file mode 100644 index 0000000000..aa983141ae --- /dev/null +++ b/tests/unit/sagemaker/remote_function/runtime_environment/test_mpi_utils.py @@ -0,0 +1,125 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""MPI Utils Unit Tests.""" +from __future__ import absolute_import + +import os +from mock import patch + +import sagemaker.remote_function.runtime_environment.mpi_utils_remote as mpi_utils_remote # noqa: E402 + + +@patch.dict( + os.environ, + { + "SM_MASTER_ADDR": "algo-1", + "SM_CURRENT_HOST": "algo-1", + "SM_HOSTS": '["algo-1", "algo-2"]', + }, +) +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") +def test_mpi_utils_main_job_start( + mock_start_sshd_daemon, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, +): + + mpi_utils_remote.main() + + mock_start_sshd_daemon.assert_called_once() + mock_bootstrap_worker_node.assert_not_called() + mock_bootstrap_master_node.assert_called_once() + + +@patch.dict( + os.environ, + { + "SM_MASTER_ADDR": "algo-1", + "SM_CURRENT_HOST": "algo-2", + "SM_HOSTS": '["algo-1", "algo-2"]', + }, +) +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") +def test_mpi_utils_worker_job_start( + mock_start_sshd_daemon, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, +): + + mpi_utils_remote.main() + + mock_start_sshd_daemon.assert_called_once() + mock_bootstrap_worker_node.assert_called_once() + mock_bootstrap_master_node.assert_not_called() + + +@patch.dict( + os.environ, + { + "SM_MASTER_ADDR": "algo-1", + "SM_CURRENT_HOST": "algo-1", + "SM_HOSTS": '["algo-1", "algo-2"]', + }, +) +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") +@patch( + "sagemaker.remote_function.runtime_environment.mpi_utils_remote.write_status_file_to_workers" +) +def test_mpi_utils_main_job_end( + mock_write_status_file_to_workers, + mock_start_sshd_daemon, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, +): + + mpi_utils_remote.main(["--job_ended", "1"]) + + mock_start_sshd_daemon.assert_not_called() + mock_bootstrap_worker_node.assert_not_called() + mock_bootstrap_master_node.assert_not_called() + mock_write_status_file_to_workers.assert_called_once() + + +@patch.dict( + os.environ, + { + "SM_MASTER_ADDR": "algo-1", + "SM_CURRENT_HOST": "algo-2", + "SM_HOSTS": '["algo-1", "algo-2"]', + }, +) +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") +@patch( + "sagemaker.remote_function.runtime_environment.mpi_utils_remote.write_status_file_to_workers" +) +def test_mpi_utils_worker_job_end( + mock_write_status_file_to_workers, + mock_start_sshd_daemon, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, +): + + mpi_utils_remote.main(["--job_ended", "1"]) + + mock_start_sshd_daemon.assert_not_called() + mock_bootstrap_worker_node.assert_not_called() + mock_bootstrap_master_node.assert_not_called() + mock_write_status_file_to_workers.assert_not_called() diff --git a/tests/unit/sagemaker/remote_function/test_client.py b/tests/unit/sagemaker/remote_function/test_client.py index 536bfdfca7..6c2a373dbc 100644 --- a/tests/unit/sagemaker/remote_function/test_client.py +++ b/tests/unit/sagemaker/remote_function/test_client.py @@ -1505,6 +1505,7 @@ def test_consistency_between_remote_and_step_decorator(): "s3_root_uri", "sagemaker_session", "use_torchrun", + "use_mpirun", "nproc_per_node", ] diff --git a/tests/unit/sagemaker/remote_function/test_job.py b/tests/unit/sagemaker/remote_function/test_job.py index c7d35b6481..671f091d02 100644 --- a/tests/unit/sagemaker/remote_function/test_job.py +++ b/tests/unit/sagemaker/remote_function/test_job.py @@ -96,8 +96,6 @@ export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.t3.xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.t3.xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}' export SM_NPROC_PER_NODE='4' export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.t3.xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 4, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 4, "num_gpus": 0, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.t3.xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.t3.xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}}' -export NCCL_SOCKET_IFNAME='eth0' -export NCCL_PROTO='simple' """ # flake8: noqa @@ -154,6 +152,99 @@ export NCCL_PROTO='simple' """ +# flake8: noqa +EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS_MPIRUN = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.g5.12xlarge' +export SM_HOSTS='["algo-1"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='1' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='48' +export SM_NUM_GPUS='4' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='4' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 4, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 48, "num_gpus": 4, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}}' +export MASTER_ADDR='algo-1' +export MASTER_PORT='7777' +export SM_HOSTS_LIST='algo-1:4' +export SM_FI_PROVIDER='' +export SM_NCCL_PROTO='' +export SM_FI_EFA_USE_DEVICE_RDMA='' +""" + +# flake8: noqa +EXPECTED_ENV_MULTI_NODE_MULTI_GPUS_MPIRUN = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.g5.2xlarge' +export SM_HOSTS='["algo-1", "algo-2", "algo-3", "algo-4"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='4' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='8' +export SM_NUM_GPUS='1' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.2xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.2xlarge", "hosts": ["algo-4", "algo-2", "algo-1", "algo-3"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='1' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.2xlarge", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "host_count": 4, "nproc_per_node": 1, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 8, "num_gpus": 1, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.2xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.2xlarge", "hosts": ["algo-4", "algo-2", "algo-1", "algo-3"]}], "network_interface_name": "eth0"}}' +export MASTER_ADDR='algo-1' +export MASTER_PORT='7777' +export SM_HOSTS_LIST='algo-1:1,algo-2:1,algo-3:1,algo-4:1' +export SM_FI_PROVIDER='' +export SM_NCCL_PROTO='' +export SM_FI_EFA_USE_DEVICE_RDMA='' +""" + +# flake8: noqa +EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS_MPIRUN_WITH_NPROC_PER_NODE = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.g5.12xlarge' +export SM_HOSTS='["algo-1"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='1' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='48' +export SM_NUM_GPUS='4' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='2' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 2, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 48, "num_gpus": 4, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}}' +export MASTER_ADDR='algo-1' +export MASTER_PORT='7777' +export SM_HOSTS_LIST='algo-1:2' +export SM_FI_PROVIDER='' +export SM_NCCL_PROTO='' +export SM_FI_EFA_USE_DEVICE_RDMA='' +""" + DESCRIBE_TRAINING_JOB_RESPONSE = { "TrainingJobArn": TRAINING_JOB_ARN, "TrainingJobStatus": "{}", @@ -478,7 +569,7 @@ def test_start( s3_kms_key=None, sagemaker_session=session(), use_torchrun=False, - nproc_per_node=None, + use_mpirun=False, ) mock_dependency_upload.assert_called_once_with( @@ -761,7 +852,7 @@ def test_start_with_complete_job_settings( s3_kms_key=job_settings.s3_kms_key, sagemaker_session=session(), use_torchrun=False, - nproc_per_node=None, + use_mpirun=False, ) mock_user_workspace_upload.assert_called_once_with( @@ -933,7 +1024,7 @@ def test_get_train_args_under_pipeline_context( s3_kms_key=job_settings.s3_kms_key, sagemaker_session=session(), use_torchrun=False, - nproc_per_node=None, + use_mpirun=False, ) mock_user_workspace_upload.assert_called_once_with( @@ -1109,7 +1200,7 @@ def test_start_with_spark( s3_kms_key=None, sagemaker_session=session(), use_torchrun=False, - nproc_per_node=None, + use_mpirun=False, ) session().sagemaker_client.create_training_job.assert_called_once_with( @@ -1268,7 +1359,7 @@ def test_prepare_and_upload_runtime_scripts(session, mock_copy, mock_s3_upload): assert s3_path == mock_s3_upload.return_value - assert mock_copy.call_count == 2 + assert mock_copy.call_count == 3 mock_s3_upload.assert_called_once() @@ -1288,7 +1379,7 @@ def test_prepare_and_upload_runtime_scripts_under_pipeline_context( ) # Bootstrap scripts are uploaded on the first call assert s3_path == mock_s3_upload.return_value - assert mock_copy.call_count == 2 + assert mock_copy.call_count == 3 mock_s3_upload.assert_called_once() mock_copy.reset_mock() @@ -1725,7 +1816,7 @@ def test_start_with_torchrun_single_node( instance_type="ml.g5.12xlarge", encrypt_inter_container_traffic=True, use_torchrun=True, - nproc_per_node=None, + use_mpirun=False, ) job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) @@ -1751,7 +1842,7 @@ def test_start_with_torchrun_single_node( s3_kms_key=None, sagemaker_session=session(), use_torchrun=True, - nproc_per_node=None, + use_mpirun=False, ) mock_dependency_upload.assert_called_once_with( @@ -1809,6 +1900,8 @@ def test_start_with_torchrun_single_node( mock_sagemaker_pysdk_version, "--dependency_settings", '{"dependency_file": null}', + "--distribution", + "torchrun", "--run_in_context", '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', ], @@ -1853,7 +1946,7 @@ def test_start_with_torchrun_multi_node( instance_type="ml.g5.2xlarge", encrypt_inter_container_traffic=True, use_torchrun=True, - nproc_per_node=None, + use_mpirun=False, ) job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) @@ -1879,7 +1972,7 @@ def test_start_with_torchrun_multi_node( s3_kms_key=None, sagemaker_session=session(), use_torchrun=True, - nproc_per_node=None, + use_mpirun=False, ) mock_dependency_upload.assert_called_once_with( @@ -1939,6 +2032,8 @@ def test_start_with_torchrun_multi_node( mock_sagemaker_pysdk_version, "--dependency_settings", '{"dependency_file": null}', + "--distribution", + "torchrun", "--run_in_context", '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', ], @@ -1969,7 +2064,7 @@ def test_start_with_torchrun_multi_node( return_value=0, ) @patch( - "sagemaker.modules.train.container_drivers.scripts.environment.safe_serialize", + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", side_effect=safe_serialize, ) def test_set_env_single_node_cpu( @@ -1991,6 +2086,7 @@ def test_set_env_single_node_cpu( ], network_interface_name="eth0", ), + distribution=None, output_file=OUTPUT_FILE, ) @@ -2021,7 +2117,7 @@ def test_set_env_single_node_cpu( return_value=0, ) @patch( - "sagemaker.modules.train.container_drivers.scripts.environment.safe_serialize", + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", side_effect=safe_serialize, ) def test_set_env_single_node_multi_gpu( @@ -2043,6 +2139,7 @@ def test_set_env_single_node_multi_gpu( ], network_interface_name="eth0", ), + distribution="torchrun", output_file=OUTPUT_FILE, ) @@ -2073,7 +2170,7 @@ def test_set_env_single_node_multi_gpu( return_value=0, ) @patch( - "sagemaker.modules.train.container_drivers.scripts.environment.safe_serialize", + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", side_effect=safe_serialize, ) def test_set_env_multi_node_multi_gpu( @@ -2095,6 +2192,7 @@ def test_set_env_multi_node_multi_gpu( ], network_interface_name="eth0", ), + distribution="torchrun", output_file=OUTPUT_FILE, ) @@ -2112,6 +2210,432 @@ def test_set_env_multi_node_multi_gpu( assert not os.path.exists(OUTPUT_FILE) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=48, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=4, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_single_node_multi_gpu_mpirun( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.12xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.12xlarge", + hosts=["algo-1"], + ) + ], + network_interface_name="eth0", + ), + distribution="mpirun", + output_file=OUTPUT_FILE, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(OUTPUT_FILE, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS_MPIRUN) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + os.remove(OUTPUT_FILE) + assert not os.path.exists(OUTPUT_FILE) + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=8, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=1, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_multi_node_multi_gpu_mpirun( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1", "algo-2", "algo-3", "algo-4"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.2xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.2xlarge", + hosts=["algo-4", "algo-2", "algo-1", "algo-3"], + ) + ], + network_interface_name="eth0", + ), + distribution="mpirun", + output_file=OUTPUT_FILE, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(OUTPUT_FILE, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV_MULTI_NODE_MULTI_GPUS_MPIRUN) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + os.remove(OUTPUT_FILE) + assert not os.path.exists(OUTPUT_FILE) + + +@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) +@patch("secrets.token_hex", return_value=HMAC_KEY) +@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") +@patch( + "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" +) +@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_start_with_torchrun_single_node_with_nproc_per_node( + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, +): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + include_local_workdir=True, + instance_type="ml.g5.12xlarge", + encrypt_inter_container_traffic=True, + use_torchrun=True, + use_mpirun=False, + nproc_per_node=2, + ) + + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + assert job.job_name.startswith("job-function") + + mock_stored_function.assert_called_once_with( + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, + ) + + mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) + + local_dependencies_path = mock_runtime_manager().snapshot() + mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() + + mock_script_upload.assert_called_once_with( + spark_config=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + use_torchrun=True, + use_mpirun=False, + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path=local_dependencies_path, + include_local_workdir=True, + pre_execution_commands=None, + pre_execution_script_local_path=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + custom_file_filter=None, + ) + + session().sagemaker_client.create_training_job.assert_called_once_with( + TrainingJobName=job.job_name, + RoleArn=ROLE_ARN, + StoppingCondition={"MaxRuntimeInSeconds": 86400}, + RetryStrategy={"MaximumRetryAttempts": 1}, + InputDataConfig=[ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": mock_script_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + dict( + ChannelName=REMOTE_FUNCTION_WORKSPACE, + DataSource={ + "S3DataSource": { + "S3Uri": mock_dependency_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + ], + OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}"}, + AlgorithmSpecification=dict( + TrainingImage=IMAGE, + TrainingInputMode="File", + ContainerEntrypoint=[ + "/bin/bash", + "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh", + ], + ContainerArguments=[ + "--s3_base_uri", + f"{S3_URI}/{job.job_name}", + "--region", + TEST_REGION, + "--client_python_version", + mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, + "--dependency_settings", + '{"dependency_file": null}', + "--distribution", + "torchrun", + "--user_nproc_per_node", + "2", + "--run_in_context", + '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', + ], + ), + ResourceConfig=dict( + VolumeSizeInGB=30, + InstanceCount=1, + InstanceType="ml.g5.12xlarge", + KeepAlivePeriodInSeconds=0, + ), + EnableNetworkIsolation=False, + EnableInterContainerTrafficEncryption=True, + EnableManagedSpotTraining=False, + Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + ) + + +@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) +@patch("secrets.token_hex", return_value=HMAC_KEY) +@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") +@patch( + "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" +) +@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_start_with_mpirun_single_node_with_nproc_per_node( + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, +): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + include_local_workdir=True, + instance_type="ml.g5.12xlarge", + encrypt_inter_container_traffic=True, + use_torchrun=False, + use_mpirun=True, + nproc_per_node=2, + ) + + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + assert job.job_name.startswith("job-function") + + mock_stored_function.assert_called_once_with( + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, + ) + + mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) + + local_dependencies_path = mock_runtime_manager().snapshot() + mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() + + mock_script_upload.assert_called_once_with( + spark_config=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + use_torchrun=False, + use_mpirun=True, + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path=local_dependencies_path, + include_local_workdir=True, + pre_execution_commands=None, + pre_execution_script_local_path=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + custom_file_filter=None, + ) + + session().sagemaker_client.create_training_job.assert_called_once_with( + TrainingJobName=job.job_name, + RoleArn=ROLE_ARN, + StoppingCondition={"MaxRuntimeInSeconds": 86400}, + RetryStrategy={"MaximumRetryAttempts": 1}, + InputDataConfig=[ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": mock_script_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + dict( + ChannelName=REMOTE_FUNCTION_WORKSPACE, + DataSource={ + "S3DataSource": { + "S3Uri": mock_dependency_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + ], + OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}"}, + AlgorithmSpecification=dict( + TrainingImage=IMAGE, + TrainingInputMode="File", + ContainerEntrypoint=[ + "/bin/bash", + "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh", + ], + ContainerArguments=[ + "--s3_base_uri", + f"{S3_URI}/{job.job_name}", + "--region", + TEST_REGION, + "--client_python_version", + mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, + "--dependency_settings", + '{"dependency_file": null}', + "--distribution", + "mpirun", + "--user_nproc_per_node", + "2", + "--run_in_context", + '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', + ], + ), + ResourceConfig=dict( + VolumeSizeInGB=30, + InstanceCount=1, + InstanceType="ml.g5.12xlarge", + KeepAlivePeriodInSeconds=0, + ), + EnableNetworkIsolation=False, + EnableInterContainerTrafficEncryption=True, + EnableManagedSpotTraining=False, + Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + ) + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=48, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=4, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_single_node_multi_gpu_mpirun_with_nproc_per_node( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.12xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.12xlarge", + hosts=["algo-1"], + ) + ], + network_interface_name="eth0", + ), + distribution="mpirun", + user_nproc_per_node=2, + output_file=OUTPUT_FILE, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(OUTPUT_FILE, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines( + EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS_MPIRUN_WITH_NPROC_PER_NODE + ) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + os.remove(OUTPUT_FILE) + assert not os.path.exists(OUTPUT_FILE) + + def _remove_extra_lines(string): """Removes extra blank lines from a string.""" return "\n".join([line for line in string.splitlines() if line.strip()]) diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index b6bd69e304..415d7eab5b 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -75,7 +75,7 @@ "-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" ) mock_djl_image_uri = ( - "123456789712.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.24.0-neuronx-sdk2.14.1" + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124" ) mock_model_data = { @@ -1166,6 +1166,9 @@ def test_optimize_quantize_for_jumpstart( mock_pysdk_model.image_uri = mock_tgi_image_uri mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0] + mock_pysdk_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } sample_input = { "inputs": "The diamondback terrapin or simply terrapin is a species " @@ -1201,6 +1204,10 @@ def test_optimize_quantize_for_jumpstart( ) self.assertIsNotNone(out_put) + self.assertEqual( + out_put["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124", + ) @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) @@ -1287,6 +1294,9 @@ def test_optimize_quantize_and_compile_for_jumpstart( mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0] mock_pysdk_model.config_name = "config_name" mock_pysdk_model._metadata_configs = {"config_name": mock_metadata_config} + mock_pysdk_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } sample_input = { "inputs": "The diamondback terrapin or simply terrapin is a species " @@ -1319,6 +1329,8 @@ def test_optimize_quantize_and_compile_for_jumpstart( ) self.assertIsNotNone(out_put) + self.assertIsNone(out_put["OptimizationConfigs"][1]["ModelCompilationConfig"].get("Image")) + self.assertIsNone(out_put["OptimizationConfigs"][0]["ModelQuantizationConfig"].get("Image")) @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) @@ -1633,13 +1645,17 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations( mock_serve_settings, mock_telemetry, ): - mock_sagemaker_session = Mock() + mock_sagemaker_session = MagicMock() + mock_sagemaker_session.sagemaker_client.create_optimization_job = MagicMock() mock_sagemaker_session.wait_for_optimization_job.side_effect = ( lambda *args: mock_optimization_job_response ) mock_lmi_js_model = MagicMock() mock_lmi_js_model.image_uri = mock_djl_image_uri + mock_lmi_js_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } mock_lmi_js_model.env = { "SAGEMAKER_PROGRAM": "inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -1671,6 +1687,13 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations( output_path="s3://bucket/code/", ) + assert ( + mock_sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][ + "OptimizationConfigs" + ][0]["ModelQuantizationConfig"]["Image"] + == "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124" + ) + assert mock_lmi_js_model.set_deployment_config.call_args_list[0].kwargs == { "instance_type": "ml.g5.24xlarge", "config_name": "lmi", @@ -1711,13 +1734,17 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations_no_over mock_serve_settings, mock_telemetry, ): - mock_sagemaker_session = Mock() + mock_sagemaker_session = MagicMock() + mock_sagemaker_session.sagemaker_client.create_optimization_job = MagicMock() mock_sagemaker_session.wait_for_optimization_job.side_effect = ( lambda *args: mock_optimization_job_response ) mock_lmi_js_model = MagicMock() mock_lmi_js_model.image_uri = mock_djl_image_uri + mock_lmi_js_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi27.0.0-cu124" + } mock_lmi_js_model.env = { "SAGEMAKER_PROGRAM": "inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -1748,6 +1775,13 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations_no_over output_path="s3://bucket/code/", ) + assert ( + mock_sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][ + "OptimizationConfigs" + ][0]["ModelQuantizationConfig"]["Image"] + == "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi27.0.0-cu124" + ) + assert mock_lmi_js_model.set_deployment_config.call_args_list[0].kwargs == { "instance_type": "ml.g5.24xlarge", "config_name": "lmi", @@ -1763,3 +1797,163 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations_no_over "OPTION_TENSOR_PARALLEL_DEGREE": "8", "OPTION_QUANTIZE": "fp8", # should be added to the env } + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel") + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_fine_tuned_model", + return_value=False, + ) + def test_optimize_on_js_model_test_image_defaulting_scenarios( + self, + mock_is_fine_tuned, + mock_is_jumpstart_model, + mock_js_model, + mock_is_gated_model, + mock_serve_settings, + mock_telemetry, + ): + + mock_lmi_js_model = MagicMock() + mock_lmi_js_model.image_uri = mock_djl_image_uri + mock_lmi_js_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-1-70b-instruct", + schema_builder=SchemaBuilder("test", "test"), + sagemaker_session=MagicMock(), + ) + model_builder.pysdk_model = mock_lmi_js_model + + # assert lmi version is upgraded to hardcoded default + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } + } + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124", + ) + + # assert lmi version is left as is + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi21.0.0-cu124" + } + } + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi21.0.0-cu124", + ) + + # assert lmi version is upgraded to the highest provided version + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelShardingConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } + }, + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124" + } + }, + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelShardingConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124", + ) + self.assertEqual( + optimization_args["OptimizationConfigs"][1]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124", + ) + + # assert lmi version is upgraded to the highest provided version and sets empty image config + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124" + } + }, + {"ModelShardingConfig": {}}, + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124", + ) + self.assertEqual( + optimization_args["OptimizationConfigs"][1]["ModelShardingConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124", + ) + + # assert lmi version is left as is on minor version bump + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.1.0-cu124" + } + } + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.1.0-cu124", + ) + + # assert lmi version is left as is on patch version bump + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.0.1-cu124" + } + } + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.0.1-cu124", + ) diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 1e20bf1cf3..6ef07b0f46 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -74,6 +74,7 @@ ModelServer.MMS, ModelServer.TGI, ModelServer.TEI, + ModelServer.SMD, } mock_session = MagicMock() @@ -2890,6 +2891,85 @@ def test_optimize_for_hf_without_custom_s3_path( }, ) + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_jumpstart") + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + def test_build_multiple_inference_component_modelbuilders( + self, + mock_pre_trained_model, + mock_is_jumpstart_model_id, + mock_build_for_js, + mock_serve_settings, + ): + mock_setting_object = mock_serve_settings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + builder1 = ModelBuilder( + model="gpt_llm_burt", inference_component_name="ic1", resource_requirements=Mock() + ) + builder2 = ModelBuilder( + model="gpt_llm_burt", inference_component_name="ic2", resource_requirements=Mock() + ) + + builder3 = ModelBuilder( + model="gpt_llm_burt", inference_component_name="ic3", resource_requirements=Mock() + ) + + chain_builder = ModelBuilder( + modelbuilder_list=[builder1, builder2, builder3], + ) + chain_builder.build(sagemaker_session=mock_session) + assert mock_build_for_js.call_count == 3 + + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_jumpstart") + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._does_ic_exist", + return_value=True, + ) + @patch( + "sagemaker.session.Session.update_inference_component", + return_value=MagicMock(), + ) + def test_deploy_existing_inference_component_calls_update_inference_component( + self, + mock_update_inference_component, + mock_ic_exists, + mock_pre_trained_model, + mock_is_jumpstart_model_id, + mock_build_for_js, + mock_serve_settings, + ): + mock_setting_object = mock_serve_settings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + builder1 = ModelBuilder( + model="gpt_llm_burt", inference_component_name="ic1", resource_requirements=Mock() + ) + + chain_builder = ModelBuilder( + modelbuilder_list=[builder1], + ).build() + chain_builder.deploy() + assert mock_update_inference_component.call_count == 1 + def test_deploy_invalid_inputs(self): model_builder = ModelBuilder( model="meta-llama/Meta-Llama-3-8B-Instruct", @@ -2902,7 +2982,7 @@ def test_deploy_invalid_inputs(self): try: model_builder.deploy(**inputs) except ValueError as e: - assert "Model Needs to be built before deploying" in str(e) + assert "Model needs to be built before deploying" in str(e) @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") def test_display_benchmark_metrics_non_string_model(self, mock_is_jumpstart): @@ -3733,6 +3813,9 @@ def test_optimize_sharding_with_override_for_js( pysdk_model.env = {"key": "val"} pysdk_model._enable_network_isolation = True pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None + pysdk_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } mock_build_for_jumpstart.side_effect = lambda **kwargs: pysdk_model mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( @@ -3803,8 +3886,9 @@ def test_optimize_sharding_with_override_for_js( OptimizationConfigs=[ { "ModelShardingConfig": { - "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} - } + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124", + "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"}, + }, } ], OutputConfig={ diff --git a/tests/unit/sagemaker/telemetry/test_telemetry_logging.py b/tests/unit/sagemaker/telemetry/test_telemetry_logging.py index 9107256b5b..bd8db82a16 100644 --- a/tests/unit/sagemaker/telemetry/test_telemetry_logging.py +++ b/tests/unit/sagemaker/telemetry/test_telemetry_logging.py @@ -300,3 +300,39 @@ def test_get_default_sagemaker_session_with_no_region(self): assert "Must setup local AWS configuration with a region supported by SageMaker." in str( context.exception ) + + @patch("sagemaker.telemetry.telemetry_logging._get_accountId") + @patch("sagemaker.telemetry.telemetry_logging._get_region_or_default") + def test_send_telemetry_request_valid_region(self, mock_get_region, mock_get_accountId): + """Test to verify telemetry request is sent when region is valid""" + mock_get_accountId.return_value = "testAccountId" + mock_session = MagicMock() + + # Test with valid region + mock_get_region.return_value = "us-east-1" + with patch( + "sagemaker.telemetry.telemetry_logging._requests_helper" + ) as mock_requests_helper: + _send_telemetry_request(1, [1, 2], mock_session) + # Assert telemetry request was sent + mock_requests_helper.assert_called_once_with( + "https://sm-pysdk-t-us-east-1.s3.us-east-1.amazonaws.com/telemetry?" + "x-accountId=testAccountId&x-status=1&x-feature=1,2", + 2, + ) + + @patch("sagemaker.telemetry.telemetry_logging._get_accountId") + @patch("sagemaker.telemetry.telemetry_logging._get_region_or_default") + def test_send_telemetry_request_invalid_region(self, mock_get_region, mock_get_accountId): + """Test to verify telemetry request is not sent when region is invalid""" + mock_get_accountId.return_value = "testAccountId" + mock_session = MagicMock() + + # Test with invalid region + mock_get_region.return_value = "invalid-region" + with patch( + "sagemaker.telemetry.telemetry_logging._requests_helper" + ) as mock_requests_helper: + _send_telemetry_request(1, [1, 2], mock_session) + # Assert telemetry request was not sent + mock_requests_helper.assert_not_called() diff --git a/tests/unit/test_common.py b/tests/unit/test_common.py index 8fe7383fe4..9fe49ad448 100644 --- a/tests/unit/test_common.py +++ b/tests/unit/test_common.py @@ -16,12 +16,12 @@ import tempfile import pytest import itertools +from sagemaker.deserializers import RecordDeserializer +from sagemaker.serializers import RecordSerializer from scipy.sparse import coo_matrix from sagemaker.amazon.common import ( - RecordDeserializer, write_numpy_to_dense_tensor, read_recordio, - RecordSerializer, write_spmatrix_to_sparse_tensor, ) from sagemaker.amazon.record_pb2 import Record diff --git a/tests/unit/test_inputs.py b/tests/unit/test_inputs.py index 7d9c2b2c2f..133c31eb75 100644 --- a/tests/unit/test_inputs.py +++ b/tests/unit/test_inputs.py @@ -41,6 +41,8 @@ def test_training_input_all_arguments(): record_wrapping = "RecordIO" s3_data_type = "Manifestfile" input_mode = "Pipe" + hub_access_config = {"HubContentArn": "some-hub-content-arn"} + model_access_config = {"AcceptEula": True} result = TrainingInput( s3_data=prefix, distribution=distribution, @@ -49,6 +51,8 @@ def test_training_input_all_arguments(): content_type=content_type, record_wrapping=record_wrapping, s3_data_type=s3_data_type, + hub_access_config=hub_access_config, + model_access_config=model_access_config, ) expected = { "DataSource": { @@ -56,6 +60,8 @@ def test_training_input_all_arguments(): "S3DataDistributionType": distribution, "S3DataType": s3_data_type, "S3Uri": prefix, + "ModelAccessConfig": model_access_config, + "HubAccessConfig": hub_access_config, } }, "CompressionType": compression, @@ -76,6 +82,8 @@ def test_training_input_all_arguments_heterogeneous_cluster(): s3_data_type = "Manifestfile" instance_groups = ["data-server"] input_mode = "Pipe" + hub_access_config = {"HubContentArn": "some-hub-content-arn"} + model_access_config = {"AcceptEula": True} result = TrainingInput( s3_data=prefix, distribution=distribution, @@ -85,6 +93,8 @@ def test_training_input_all_arguments_heterogeneous_cluster(): record_wrapping=record_wrapping, s3_data_type=s3_data_type, instance_groups=instance_groups, + hub_access_config=hub_access_config, + model_access_config=model_access_config, ) expected = { @@ -94,6 +104,8 @@ def test_training_input_all_arguments_heterogeneous_cluster(): "S3DataType": s3_data_type, "S3Uri": prefix, "InstanceGroupNames": instance_groups, + "ModelAccessConfig": model_access_config, + "HubAccessConfig": hub_access_config, } }, "CompressionType": compression, diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index c93a381c11..dc21f50b68 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -206,6 +206,32 @@ def test_load_config_with_model_channel_no_inputs(estimator): assert config["stop_condition"]["MaxRuntimeInSeconds"] == MAX_RUNTIME +def test_load_config_with_access_configs(estimator): + estimator.model_uri = MODEL_URI + estimator.model_channel_name = MODEL_CHANNEL_NAME + estimator.model_access_config = {"AcceptEula": True} + estimator.hub_access_config = {"HubContentArn": "dummy_arn"} + + config = _Job._load_config(inputs=None, estimator=estimator) + assert config["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] == MODEL_URI + assert config["input_config"][0]["ChannelName"] == MODEL_CHANNEL_NAME + assert config["role"] == ROLE + assert config["output_config"]["S3OutputPath"] == S3_OUTPUT_PATH + assert "KmsKeyId" not in config["output_config"] + assert config["resource_config"]["InstanceCount"] == INSTANCE_COUNT + assert config["resource_config"]["InstanceType"] == INSTANCE_TYPE + assert config["resource_config"]["VolumeSizeInGB"] == VOLUME_SIZE + assert config["stop_condition"]["MaxRuntimeInSeconds"] == MAX_RUNTIME + assert ( + config["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] + == estimator.model_access_config + ) + assert ( + config["input_config"][0]["DataSource"]["S3DataSource"]["HubAccessConfig"] + == estimator.hub_access_config + ) + + def test_load_config_with_code_channel(framework): inputs = TrainingInput(BUCKET_NAME) @@ -347,20 +373,43 @@ def test_format_record_set_list_input(): @pytest.mark.parametrize( - "channel_uri, channel_name, content_type, input_mode", + "channel_uri, channel_name, content_type, input_mode, model_access_config, hub_access_config", [ - [MODEL_URI, MODEL_CHANNEL_NAME, "application/x-sagemaker-model", "File"], - [CODE_URI, CODE_CHANNEL_NAME, None, None], + [ + MODEL_URI, + MODEL_CHANNEL_NAME, + "application/x-sagemaker-model", + "File", + {"AcceptEula": True}, + None, + ], + [CODE_URI, CODE_CHANNEL_NAME, None, None, None, {"HubContentArn": "dummy_arn"}], ], ) -def test_prepare_channel(channel_uri, channel_name, content_type, input_mode): +def test_prepare_channel( + channel_uri, channel_name, content_type, input_mode, model_access_config, hub_access_config +): channel = _Job._prepare_channel( - [], channel_uri, channel_name, content_type=content_type, input_mode=input_mode + [], + channel_uri, + channel_name, + content_type=content_type, + input_mode=input_mode, + model_access_config=model_access_config, + hub_access_config=hub_access_config, ) assert channel["DataSource"]["S3DataSource"]["S3Uri"] == channel_uri assert channel["DataSource"]["S3DataSource"]["S3DataDistributionType"] == "FullyReplicated" assert channel["DataSource"]["S3DataSource"]["S3DataType"] == "S3Prefix" + if hub_access_config: + assert channel["DataSource"]["S3DataSource"]["HubAccessConfig"] == hub_access_config + else: + assert "HubAccessConfig" not in channel["DataSource"]["S3DataSource"] + if model_access_config: + assert channel["DataSource"]["S3DataSource"]["ModelAccessConfig"] == model_access_config + else: + assert "ModelAccessConfig" not in channel["DataSource"]["S3DataSource"] assert channel["ChannelName"] == channel_name assert "CompressionType" not in channel assert "RecordWrapperType" not in channel @@ -546,6 +595,23 @@ def test_format_string_uri_input_string(): assert s3_uri_input.config["DataSource"]["S3DataSource"]["S3Uri"] == inputs +def test_format_string_uri_input_string_with_access_configs(): + inputs = BUCKET_NAME + model_access_config = {"AcceptEula": True} + hub_access_config = {"HubContentArn": "dummy_arn"} + + s3_uri_input = _Job._format_string_uri_input( + inputs, model_access_config=model_access_config, hub_access_config=hub_access_config + ) + + assert s3_uri_input.config["DataSource"]["S3DataSource"]["S3Uri"] == inputs + assert s3_uri_input.config["DataSource"]["S3DataSource"]["HubAccessConfig"] == hub_access_config + assert ( + s3_uri_input.config["DataSource"]["S3DataSource"]["ModelAccessConfig"] + == model_access_config + ) + + def test_format_string_uri_file_system_input(): file_system_id = "fs-fd85e556" file_system_type = "EFS" @@ -585,6 +651,26 @@ def test_format_string_uri_input(): ) +def test_format_string_uri_input_with_access_configs(): + inputs = TrainingInput(BUCKET_NAME) + model_access_config = {"AcceptEula": True} + hub_access_config = {"HubContentArn": "dummy_arn"} + + s3_uri_input = _Job._format_string_uri_input( + inputs, model_access_config=model_access_config, hub_access_config=hub_access_config + ) + + assert ( + s3_uri_input.config["DataSource"]["S3DataSource"]["S3Uri"] + == inputs.config["DataSource"]["S3DataSource"]["S3Uri"] + ) + assert s3_uri_input.config["DataSource"]["S3DataSource"]["HubAccessConfig"] == hub_access_config + assert ( + s3_uri_input.config["DataSource"]["S3DataSource"]["ModelAccessConfig"] + == model_access_config + ) + + def test_format_string_uri_input_exception(): inputs = 1 diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 6076d44e90..34d3c6784b 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -23,7 +23,10 @@ from sagemaker import image_uris from sagemaker.pytorch import defaults from sagemaker.pytorch import PyTorch, PyTorchPredictor, PyTorchModel -from sagemaker.pytorch.estimator import _get_training_recipe_image_uri +from sagemaker.pytorch.estimator import ( + _get_training_recipe_image_uri, + _get_training_recipe_gpu_script, +) from sagemaker.instance_group import InstanceGroup from sagemaker.session_settings import SessionSettings @@ -1049,6 +1052,52 @@ def test_training_recipe_for_trainium(sagemaker_session): assert pytorch.distribution == expected_distribution +@pytest.mark.parametrize( + "test_case", + [ + { + "script": "llama_pretrain.py", + "recipe": { + "model": { + "model_type": "llama_v3", + }, + }, + }, + { + "script": "mistral_pretrain.py", + "recipe": { + "model": { + "model_type": "mistral", + }, + }, + }, + { + "script": "deepseek_pretrain.py", + "recipe": { + "model": { + "model_type": "deepseek_llamav3", + }, + }, + }, + { + "script": "deepseek_pretrain.py", + "recipe": { + "model": { + "model_type": "deepseek_qwenv2", + }, + }, + }, + ], +) +@patch("shutil.copyfile") +def test_get_training_recipe_gpu_script(mock_copyfile, test_case): + script = test_case["script"] + recipe = test_case["recipe"] + mock_copyfile.return_value = None + + assert _get_training_recipe_gpu_script("code_dir", recipe, "source_dir") == script + + def test_training_recipe_for_trainium_custom_source_dir(sagemaker_session): container_log_level = '"logging.INFO"' diff --git a/tests/unit/test_s3.py b/tests/unit/test_s3.py index a226954986..b54552cacb 100644 --- a/tests/unit/test_s3.py +++ b/tests/unit/test_s3.py @@ -17,6 +17,7 @@ from mock import Mock from sagemaker import s3 +from sagemaker.s3_utils import is_s3_url BUCKET_NAME = "mybucket" REGION = "us-west-2" @@ -132,6 +133,34 @@ def test_parse_s3_url_fail(): assert "Expecting 's3' scheme" in str(error) +@pytest.mark.parametrize( + "input_url", + [ + ("s3://bucket/code_location"), + ("s3://bucket/code_location/sub_location"), + ("s3://bucket/code_location/sub_location/"), + ("s3://bucket/"), + ("s3://bucket"), + ], +) +def test_is_s3_url_true(input_url): + assert is_s3_url(input_url) is True + + +@pytest.mark.parametrize( + "input_url", + [ + ("bucket/code_location"), + ("bucket/code_location/sub_location"), + ("sub_location/"), + ("s3/bucket/"), + ("t3://bucket"), + ], +) +def test_is_s3_url_false(input_url): + assert is_s3_url(input_url) is False + + @pytest.mark.parametrize( "expected_output, input_args", [ diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index d2d2c3bcfb..f873e9b14c 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -5006,6 +5006,7 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_session) domain = "COMPUTER_VISION" task = "IMAGE_CLASSIFICATION" sample_payload_url = "s3://test-bucket/model" + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -5094,6 +5095,8 @@ def test_create_model_package_from_containers_with_source_uri_and_inference_spec skip_model_validation = "All" source_uri = "dummy-source-uri" + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} + created_versioned_mp_arn = ( "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1" ) @@ -5149,6 +5152,7 @@ def test_create_model_package_from_containers_with_source_uri_for_unversioned_mp approval_status = ("Approved",) skip_model_validation = "All" source_uri = "dummy-source-uri" + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} with pytest.raises( ValueError, @@ -5221,6 +5225,8 @@ def test_create_model_package_from_containers_with_source_uri_set_to_mp(sagemake return_value={"ModelPackageArn": created_versioned_mp_arn} ) + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} + sagemaker_session.create_model_package_from_containers( model_package_group_name=model_package_group_name, containers=containers, @@ -5443,6 +5449,7 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s approval_status = ("Approved",) description = "description" customer_metadata_properties = {"key1": "value1"} + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -5510,6 +5517,7 @@ def test_create_model_package_from_containers_with_one_instance_types( approval_status = ("Approved",) description = "description" customer_metadata_properties = {"key1": "value1"} + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -7183,3 +7191,65 @@ def test_delete_hub_content_reference(sagemaker_session): } sagemaker_session.sagemaker_client.delete_hub_content_reference.assert_called_with(**request) + + +def test_create_model_package_from_containers_to_create_mpg_if_not_present_without_search( + sagemaker_session, +): + sagemaker_session.sagemaker_client.search.side_effect = Exception() + sagemaker_session.sagemaker_client.search.return_value = {} + sagemaker_session.sagemaker_client.list_model_package_groups.side_effect = [ + { + "ModelPackageGroupSummaryList": [{"ModelPackageGroupName": "mock-mpg"}], + "NextToken": "NextToken", + }, + {"ModelPackageGroupSummaryList": [{"ModelPackageGroupName": "mock-mpg-test"}]}, + ] + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", + model_package_group_name="arn:aws:sagemaker:us-east-1:215995503607:model-package-group/mock-mpg", + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.sagemaker_client.list_model_package_groups.side_effect = [ + {"ModelPackageGroupSummaryList": []} + ] + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with( + ModelPackageGroupName="mock-mpg" + ) + + +def test_create_model_package_from_containers_to_create_mpg_if_not_present(sagemaker_session): + # with search api + sagemaker_session.sagemaker_client.search.return_value = { + "Results": [ + { + "ModelPackageGroup": { + "ModelPackageGroupName": "mock-mpg", + "ModelPackageGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:model-package-group/mock-mpg", + } + } + ] + } + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", + model_package_group_name="arn:aws:sagemaker:us-east-1:215995503607:model-package-group/mock-mpg", + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with( + ModelPackageGroupName="mock-mpg" + )