diff --git a/.gitignore b/.gitignore
index ad6e488dbd..fc07847fba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -32,6 +32,9 @@ env/
 .python-version
 *.html
 **/_repack_script_launcher.sh
+src/sagemaker/modules/train/container_drivers/sm_train.sh
+src/sagemaker/modules/train/container_drivers/sourcecode.json
+src/sagemaker/modules/train/container_drivers/distributed.json
 tests/data/**/_repack_model.py
 tests/data/experiment/sagemaker-dev-1.0.tar.gz
 src/sagemaker/serve/tmp_workspace
\ No newline at end of file
diff --git a/.pydocstylerc b/.pydocstylerc
index a5083c0d63..9ed879a760 100644
--- a/.pydocstylerc
+++ b/.pydocstylerc
@@ -2,3 +2,4 @@
 inherit = false
 ignore = D104,D107,D202,D203,D213,D214,D400,D401,D404,D406,D407,D411,D413,D414,D415,D417
 match = (?!record_pb2).*\.py
+match-dir = (?!.*test).*
\ No newline at end of file
diff --git a/MANIFEST.in b/MANIFEST.in
index ab053b00aa..28f1569c35 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,8 +1,10 @@
 recursive-include src/sagemaker *.py
 
 include src/sagemaker/image_uri_config/*.json
+include src/sagemaker/pytorch/training_recipes.json
 include src/sagemaker/serve/schema/*.json
 include src/sagemaker/serve/requirements.txt
+include src/sagemaker/modules/train/sm_recipes/training_recipes.json
 recursive-include requirements *
 
 include VERSION
diff --git a/doc/api/training/index.rst b/doc/api/training/index.rst
index 5f85359d20..0f61cd1931 100644
--- a/doc/api/training/index.rst
+++ b/doc/api/training/index.rst
@@ -5,6 +5,7 @@ Training APIs
 .. toctree::
     :maxdepth: 4
 
+    model_trainer
     algorithm
     analytics
     automl
diff --git a/doc/api/training/model_trainer.rst b/doc/api/training/model_trainer.rst
new file mode 100644
index 0000000000..5b0781f810
--- /dev/null
+++ b/doc/api/training/model_trainer.rst
@@ -0,0 +1,17 @@
+ModelTrainer
+------------
+
+.. autoclass:: sagemaker.modules.train.model_trainer.ModelTrainer
+    :members:
+
+Configs
+~~~~~~~
+
+.. automodule:: sagemaker.modules.configs
+    :members:
+
+Distributed
+~~~~~~~~~~~
+
+.. automodule:: sagemaker.modules.distributed
+    :members:
diff --git a/doc/frameworks/pytorch/using_pytorch.rst b/doc/frameworks/pytorch/using_pytorch.rst
index 73e2887440..d415f38c27 100644
--- a/doc/frameworks/pytorch/using_pytorch.rst
+++ b/doc/frameworks/pytorch/using_pytorch.rst
@@ -21,12 +21,9 @@ To train a PyTorch model by using the SageMaker Python SDK:
 .. |create pytorch estimator| replace:: Create a ``sagemaker.pytorch.PyTorch`` Estimator
 .. _create pytorch estimator: #create-an-estimator
 
-.. |call fit| replace:: Call the estimator's ``fit`` method
-.. _call fit: #call-the-fit-method
-
-1. `Prepare a training script <#prepare-a-pytorch-training-script>`_
+1. `Prepare a training script <#prepare-a-pytorch-training-script>`_ OR `Choose an Amazon SageMaker HyperPod recipe`_
 2. |create pytorch estimator|_
-3. |call fit|_
+3. `Call the estimator's fit method or ModelTrainer's train method`_
 
 Prepare a PyTorch Training Script
 =================================
@@ -175,6 +172,16 @@ see `AWS Deep Learning Containers `__
 
+Choose an Amazon SageMaker HyperPod recipe
+==========================================
+
+Alternatively, instead of using your own training script, you can choose an
+`Amazon SageMaker HyperPod recipe `_ to launch training for a supported model.
+If using a recipe, you do not need to provide your own training script. You only need to determine
+which recipe you want to run. You can modify a recipe as explained in the next section.
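+
+As a quick preview, selecting a recipe is a small change on the estimator. The snippet below is a
+minimal, illustrative sketch (the role and instance settings are placeholders); a full worked
+example follows in the next section:
+
+.. code:: python
+
+    from sagemaker.pytorch import PyTorch
+
+    # Instead of entry_point, name the recipe to run.
+    pytorch_estimator = PyTorch(
+        role=role,
+        instance_type="ml.p5.48xlarge",
+        instance_count=1,
+        training_recipe="hf_llama3_8b_seq8k_gpu_p5x16_pretrain",
+    )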
+
+
+
 Create an Estimator
 ===================
@@ -196,10 +203,121 @@ directories ('train' and 'test').
                            'test': 's3://my-data-bucket/path/to/my/test/data'})
 
 
+Amazon SageMaker HyperPod recipes
+---------------------------------
+Alternatively, if you are using Amazon SageMaker HyperPod recipes, follow these instructions:
+Prerequisites: you need ``git`` installed on your client to access the Amazon SageMaker HyperPod recipes code.
 
-Call the fit Method
-===================
+When using a recipe, you must set the ``training_recipe`` arg in place of providing a training script.
+This can be a recipe from `here `_
+or a local file or a custom URL. Please note that you must override the following using
+``recipe_overrides``:
+
+* directory paths for the local container in the recipe as appropriate for the Python SDK
+* the output S3 URIs
+* the Hugging Face access token
+* any other recipe fields you wish to edit
+
+The code snippet below shows an example.
+Please refer to `SageMaker docs `_
+for more details about the expected local paths in the container and the Amazon SageMaker
+HyperPod recipes tutorial for more examples.
+You can override the fields by either setting ``recipe_overrides`` or
+providing a modified ``training_recipe`` through a local file or a custom URL.
+When using a recipe, any provided ``entry_point`` will be ignored.
+
+SageMaker will automatically set up the distribution args.
+It will also determine the image to use for your model and device type,
+but you can override this with the ``image_uri`` arg.
+
+You can also override the number of nodes in the recipe with the ``instance_count`` arg to the estimator.
+``source_dir`` defaults to the current working directory unless specified.
+A local copy of the training scripts and the recipe will be saved in the ``source_dir``.
+You can specify any additional packages you want to install for training in an optional ``requirements.txt`` in the ``source_dir``.
+
+Note that for Llama 3.2 multi-modal models, you need to upgrade the transformers library by providing a ``requirements.txt`` in the ``source_dir`` with ``transformers==4.45.2``.
+Please refer to the Amazon SageMaker HyperPod recipes documentation for more details.
+
+
+Here is an example usage for the recipe ``hf_llama3_8b_seq8k_gpu_p5x16_pretrain``.
+
+
+.. code:: python
+
+    recipe_overrides = {
+        "run": {
+            "results_dir": "/opt/ml/model",
+        },
+        "exp_manager": {
+            "exp_dir": "",
+            "explicit_log_dir": "/opt/ml/output/tensorboard",
+            "checkpoint_dir": "/opt/ml/checkpoints",
+        },
+        "model": {
+            "data": {
+                "train_dir": "/opt/ml/input/data/train",
+                "val_dir": "/opt/ml/input/data/val",
+            },
+        },
+    }
+    pytorch_estimator = PyTorch(
+        output_path=output_path,
+        base_job_name="llama-recipe",
+        role=role,
+        instance_type="ml.p5.48xlarge",
+        training_recipe="hf_llama3_8b_seq8k_gpu_p5x16_pretrain",
+        recipe_overrides=recipe_overrides,
+        sagemaker_session=sagemaker_session,
+        tensorboard_output_config=tensorboard_output_config,
+    )
+    pytorch_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
+                           'test': 's3://my-data-bucket/path/to/my/test/data'})
+
+    # Or alternatively with ModelTrainer
+    from sagemaker.modules.train import ModelTrainer
+    from sagemaker.modules.configs import Compute, InputData
+
+    recipe_overrides = {
+        "run": {
+            "results_dir": "/opt/ml/model",
+        },
+        "exp_manager": {
+            "exp_dir": "",
+            "explicit_log_dir": "/opt/ml/output/tensorboard",
+            "checkpoint_dir": "/opt/ml/checkpoints",
+        },
+        "model": {
+            "data": {
+                "train_dir": "/opt/ml/input/data/train",
+                "val_dir": "/opt/ml/input/data/val",
+            },
+        },
+    }
+
+    model_trainer = ModelTrainer.from_recipe(
+        output_path=output_path,
+        base_job_name="llama-recipe",
+        training_recipe="training/llama/hf_llama3_8b_seq8k_gpu_p5x16_pretrain",
+        recipe_overrides=recipe_overrides,
+        compute=Compute(instance_type="ml.p5.48xlarge"),
+        sagemaker_session=sagemaker_session
+    ).with_tensorboard_output_config(
+        tensorboard_output_config=tensorboard_output_config
+    )
+
+    train_input = InputData(
+        channel_name="train",
+        data_source="s3://my-data-bucket/path/to/my/training/data"
+    )
+
+    test_input = InputData(
+        channel_name="test",
+        data_source="s3://my-data-bucket/path/to/my/test/data"
+    )
+
+    model_trainer.train(input_data_config=[train_input, test_input])
+
+
+Call the estimator's fit method or ModelTrainer's train method
+==============================================================
 
 You start your training script by calling ``fit`` on a ``PyTorch`` Estimator. ``fit`` takes
 both required and optional arguments.
diff --git a/doc/overview.rst b/doc/overview.rst
index 319560b5ff..a1dc5c6918 100644
--- a/doc/overview.rst
+++ b/doc/overview.rst
@@ -4,6 +4,7 @@ Using the SageMaker Python SDK
 SageMaker Python SDK provides several high-level abstractions for working with Amazon SageMaker. These are:
 
+- **ModelTrainer**: New interface encapsulating training on SageMaker.
 - **Estimators**: Encapsulate training on SageMaker.
 - **Models**: Encapsulate built ML models.
 - **Predictors**: Provide real-time inference and transformation using Python data-types against a SageMaker endpoint.
@@ -24,8 +25,8 @@ Train a Model with the SageMaker Python SDK
 To train a model by using the SageMaker Python SDK, you:
 
 1. Prepare a training script
-2. Create an estimator
-3. Call the ``fit`` method of the estimator
+2. Create a ModelTrainer or Estimator
+3. Call the ``train`` method of the ModelTrainer or the ``fit`` method of the Estimator
 
 After you train a model, you can save it, and then serve the model as an endpoint to get real-time inferences or get inferences for an entire dataset by using batch transform.
 
@@ -85,6 +86,46 @@ If you want to use, for example, boolean hyperparameters, you need to specify ``
 For more on training environment variables, please visit `SageMaker Containers `_.
 
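+For example, inside the training container each hyperparameter surfaces as an
+``SM_HP_<NAME>`` environment variable. The sketch below is illustrative; hyperparameter
+values always arrive as strings, so cast them in your script:
+
+.. code:: python
+
+    import os
+
+    # A hyperparameter named "epochs" is exposed as SM_HP_EPOCHS (string-typed).
+    epochs = int(os.environ.get("SM_HP_EPOCHS", "10"))
+
+    # Standard training-toolkit variables are also available, e.g. the model
+    # artifact directory:
+    model_dir = os.environ["SM_MODEL_DIR"]
+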
+Using ModelTrainer
+==================
+
+To use the ModelTrainer class, you need to provide a few essential parameters, such as the
+training image URI and the source code configuration. The class allows you to spin up a
+SageMaker training job with minimal parameters, particularly by specifying the source code
+and training image.
+
+For more information about class definitions, see `ModelTrainer `_.
+
+Example: Launching a Training Job with a Custom Script
+
+.. code:: python
+
+    from sagemaker.modules.train import ModelTrainer
+    from sagemaker.modules.configs import SourceCode, InputData
+
+    # Image URI for the training job
+    pytorch_image = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"
+
+    # Define the script to be run
+    source_code = SourceCode(
+        source_dir="basic-script-mode",
+        requirements="requirements.txt",
+        entry_script="custom_script.py",
+    )
+
+    # Define the ModelTrainer
+    model_trainer = ModelTrainer(
+        training_image=pytorch_image,
+        source_code=source_code,
+        base_job_name="script-mode",
+    )
+
+    # Pass the input data
+    input_data = InputData(
+        channel_name="train",
+        data_source=training_input_path,  # S3 path where training data is stored
+    )
+
+    # Start the training job
+    model_trainer.train(input_data_config=[input_data], wait=False)
+
 Using Estimators
 ================
 
diff --git a/doc/requirements.txt b/doc/requirements.txt
index 8193dfa22a..9bef9392a8 100644
--- a/doc/requirements.txt
+++ b/doc/requirements.txt
@@ -5,3 +5,4 @@ packaging==20.9
 jinja2==3.1.4
 schema==0.7.5
 accelerate>=0.24.1,<=0.27.0
+graphene<4.0
diff --git a/pyproject.toml b/pyproject.toml
index 8b9e9fa92f..4657f41737 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,10 +35,12 @@ dependencies = [
     "boto3>=1.34.142,<2.0",
     "cloudpickle==2.2.1",
     "docker",
+    "fastapi",
    "google-pasta",
     "importlib-metadata>=1.4.0,<7.0",
     "jsonschema",
     "numpy>=1.9.0,<2.0",
+    "omegaconf>=2.2,<2.3",
     "packaging>=20.0",
     "pandas",
     "pathos",
@@ -53,6 +55,7 @@ dependencies = [
     "tblib>=1.7.0,<4",
     "tqdm",
     "urllib3>=1.26.8,<3.0.0",
+    "uvicorn"
 ]
 
 [project.scripts]
diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt
index 1592576a47..9664a63e1d 100644
--- a/requirements/extras/test_requirements.txt
+++ b/requirements/extras/test_requirements.txt
@@ -49,3 +49,4 @@ uvicorn>=0.30.1
 fastapi==0.115.4
 nest-asyncio
 sagemaker-mlflow>=0.1.0
+deepdiff>=8.0.0
diff --git a/src/sagemaker/__init__.py b/src/sagemaker/__init__.py
index a1769b5a4c..71ea51c60f 100644
--- a/src/sagemaker/__init__.py
+++ b/src/sagemaker/__init__.py
@@ -74,5 +74,6 @@
 )
 from sagemaker.debugger import ProfilerConfig, Profiler  # noqa: F401
 
+from sagemaker.partner_app.auth_provider import PartnerAppAuthProvider  # noqa: F401
 
 __version__ = importlib_metadata.version("sagemaker")
diff --git a/src/sagemaker/batch_inference/__init__.py b/src/sagemaker/batch_inference/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/sagemaker/batch_inference/batch_transform_inference_config.py b/src/sagemaker/batch_inference/batch_transform_inference_config.py
new file mode 100644
index 0000000000..3d3618d7fb
--- /dev/null
+++ b/src/sagemaker/batch_inference/batch_transform_inference_config.py
@@ -0,0 +1,27 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Config Classes for taking in parameters for Batch Inference""" + +from __future__ import absolute_import +from pydantic import BaseModel + + +class BatchTransformInferenceConfig(BaseModel): + """Config class for Batch Transform Inference + + * Can be used to deploy from ModelBuilder + """ + + instance_count: int + instance_type: str + output_path: str diff --git a/src/sagemaker/config/config_schema.py b/src/sagemaker/config/config_schema.py index 35c4859930..34a98c0b8e 100644 --- a/src/sagemaker/config/config_schema.py +++ b/src/sagemaker/config/config_schema.py @@ -116,6 +116,7 @@ REGION_NAME = "region_name" TELEMETRY_OPT_OUT = "TelemetryOptOut" NOTEBOOK_JOB = "NotebookJob" +MODEL_TRAINER = "ModelTrainer" def _simple_path(*args: str): @@ -142,6 +143,7 @@ def _simple_path(*args: str): ) TRAINING_JOB_ROLE_ARN_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, ROLE_ARN) TRAINING_JOB_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, VPC_CONFIG) +TRAINING_JOB_TAGS_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, TAGS) TRAINING_JOB_SECURITY_GROUP_IDS_PATH = _simple_path( TRAINING_JOB_VPC_CONFIG_PATH, SECURITY_GROUP_IDS ) @@ -656,6 +658,25 @@ def _simple_path(*args: str): "minItems": 1, "maxItems": 15, }, + "role": { + TYPE: "string", + "pattern": r"^arn:aws[a-z\-]*:iam::\d{12}:role/?[a-zA-Z_0-9+=,.@\-_/]+$", + "minLength": 20, + "maxLength": 2048, + }, + "baseJobName": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "sourceCode": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "distributed": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "compute": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "networking": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "stoppingCondition": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "trainingImage": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "trainingImageConfig": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "algorithmName": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "outputDataConfig": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "trainingInputMode": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "environment": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "hyperparameters": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, }, PROPERTIES: { SCHEMA_VERSION: { @@ -709,6 +730,7 @@ def _simple_path(*args: str): }, }, }, + MODEL_TRAINER: {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, ESTIMATOR: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index a58d701337..ea51a86101 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -106,7 +106,8 @@ from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline - +from sagemaker.telemetry.telemetry_logging import _telemetry_emitter +from sagemaker.telemetry.constants import Feature logger = logging.getLogger(__name__) @@ -1297,6 +1298,7 @@ def latest_job_profiler_artifacts_path(self): ) return None + @_telemetry_emitter(feature=Feature.ESTIMATOR, func_name="estimator.fit") @runnable_by_pipeline def fit( self, diff --git 
a/src/sagemaker/image_uri_config/hyperpod-recipes-neuron.json b/src/sagemaker/image_uri_config/hyperpod-recipes-neuron.json new file mode 100644 index 0000000000..cd5a69bfe2 --- /dev/null +++ b/src/sagemaker/image_uri_config/hyperpod-recipes-neuron.json @@ -0,0 +1,52 @@ +{ + "training": { + "processors": [ + "neuronx" + ], + "version_aliases": { + "2.1.2": "2.1.2" + }, + "versions": { + "2.1.2": { + "py_versions": [ + "py310" + ], + "repository": "pytorch-training-neuronx", + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572", + "ca-central-1": "763104351884" + }, + "container_version": { + "neuronx": "ubuntu20.04" + }, + "sdk_versions": [ + "sdk2.20.2" + ] + } + } + } +} diff --git a/src/sagemaker/image_uri_config/pytorch-smp.json b/src/sagemaker/image_uri_config/pytorch-smp.json index 91d8ab9184..449726927a 100644 --- a/src/sagemaker/image_uri_config/pytorch-smp.json +++ b/src/sagemaker/image_uri_config/pytorch-smp.json @@ -9,7 +9,7 @@ "2.2": "2.3.1", "2.2.0": "2.3.1", "2.3.1": "2.5.0", - "2.4.1": "2.6.1" + "2.4.1": "2.7.0" }, "versions": { "2.0.1": { @@ -162,7 +162,7 @@ }, "repository": "smdistributed-modelparallel" }, - "2.6.1": { + "2.7.0": { "py_versions": [ "py311" ], diff --git a/src/sagemaker/image_uri_config/xgboost.json b/src/sagemaker/image_uri_config/xgboost.json index 88d621af49..e1a312e61c 100644 --- a/src/sagemaker/image_uri_config/xgboost.json +++ b/src/sagemaker/image_uri_config/xgboost.json @@ -1,7 +1,7 @@ { "inference": { "version_aliases": { - "latest": "1" + "latest": "1.7-1" }, "versions": { "1": { diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index dd7012b2f2..3ca7d2ed2e 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -192,7 +192,7 @@ def retrieve( config = _config_for_framework_and_scope(_framework, final_image_scope, accelerator_type) original_version = version - version = _validate_version_and_set_if_needed(version, config, framework) + version = _validate_version_and_set_if_needed(version, config, framework, image_scope) version_config = config["versions"][_version_for_config(version, config)] if framework == HUGGING_FACE_FRAMEWORK: @@ -224,7 +224,7 @@ def retrieve( container_version = version_config["container_version"][processor] # Append sdk version in case of trainium instances - if repo in ["pytorch-training-neuron"]: + if repo in ["pytorch-training-neuron", "pytorch-training-neuronx"]: if not sdk_version: sdk_version = _get_latest_versions(version_config["sdk_versions"]) container_version = sdk_version + "-" + container_version @@ -463,6 +463,23 @@ def _get_latest_versions(list_of_versions): return sorted(list_of_versions, reverse=True)[0] +def _get_latest_version(framework, version, image_scope): + """Get the latest version from 
the input framework""" + if version: + return version + try: + framework_config = config_for_framework(framework) + except FileNotFoundError: + raise ValueError("Invalid framework {}".format(framework)) + + if not framework_config: + raise ValueError("Invalid framework {}".format(framework)) + + if not version: + version = _fetch_latest_version_from_config(framework_config, image_scope) + return version + + def _validate_accelerator_type(accelerator_type): """Raises a ``ValueError`` if ``accelerator_type`` is invalid.""" if not accelerator_type.startswith("ml.eia") and accelerator_type != "local_sagemaker_notebook": @@ -472,32 +489,16 @@ def _validate_accelerator_type(accelerator_type): ) -def _validate_version_and_set_if_needed(version, config, framework): +def _validate_version_and_set_if_needed(version, config, framework, image_scope): """Checks if the framework/algorithm version is one of the supported versions.""" + if not config: + config = config_for_framework(framework) available_versions = list(config["versions"].keys()) aliased_versions = list(config.get("version_aliases", {}).keys()) - if len(available_versions) == 1 and version not in aliased_versions: - log_message = "Defaulting to the only supported framework/algorithm version: {}.".format( - available_versions[0] - ) - if version and version != available_versions[0]: - logger.warning("%s Ignoring framework/algorithm version: %s.", log_message, version) - elif not version: - logger.info(log_message) - return available_versions[0] - - if version is None and framework in [ - DATA_WRANGLER_FRAMEWORK, - HUGGING_FACE_LLM_FRAMEWORK, - HUGGING_FACE_TEI_GPU_FRAMEWORK, - HUGGING_FACE_TEI_CPU_FRAMEWORK, - HUGGING_FACE_LLM_NEURONX_FRAMEWORK, - STABILITYAI_FRAMEWORK, - ]: - version = _get_latest_versions(available_versions) - + if not version: + version = _get_latest_version(framework, version, image_scope) _validate_arg(version, available_versions + aliased_versions, "{} version".format(framework)) return version @@ -746,3 +747,55 @@ def get_base_python_image_uri(region, py_version="310") -> str: repo_and_tag = repo + ":" + version return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo_and_tag) + + +def _fetch_latest_version_from_config( # pylint: disable=R0911 + framework_config: dict, image_scope: Optional[str] = None +) -> Optional[str]: + """Helper function to fetch the latest version as a string from a framework's config + + Args: + framework_config (dict): A framework config dict. 
+ image_scope (str): Scope of the image, eg: training, inference + Returns: + Version string if latest version found else None + """ + if image_scope in framework_config: + if image_scope_config := framework_config[image_scope]: + if "version_aliases" in image_scope_config: + if "latest" in image_scope_config["version_aliases"]: + return image_scope_config["version_aliases"]["latest"] + top_version = None + bottom_version = None + + if "versions" in framework_config: + versions = list(framework_config["versions"].keys()) + if len(versions) == 1: + return versions[0] + top_version = versions[0] + bottom_version = versions[-1] + if top_version == "latest" or bottom_version == "latest": + return None + elif ( + image_scope is not None + and image_scope in framework_config + and "versions" in framework_config[image_scope] + ): + versions = list(framework_config[image_scope]["versions"].keys()) + top_version = versions[0] + bottom_version = versions[-1] + elif "processing" in framework_config and "versions" in framework_config["processing"]: + versions = list(framework_config["processing"]["versions"].keys()) + top_version = versions[0] + bottom_version = versions[-1] + if top_version and bottom_version: + if top_version.endswith(".x") or bottom_version.endswith(".x"): + top_number = int(top_version[:-2]) + bottom_number = int(bottom_version[:-2]) + max_version = max(top_number, bottom_number) + return f"{max_version}.x" + if Version(top_version) >= Version(bottom_version): + return top_version + return bottom_version + + return None diff --git a/src/sagemaker/modules/__init__.py b/src/sagemaker/modules/__init__.py new file mode 100644 index 0000000000..d7f209f00c --- /dev/null +++ b/src/sagemaker/modules/__init__.py @@ -0,0 +1,19 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""SageMaker modules directory.""" +from __future__ import absolute_import + +from sagemaker_core.main.utils import logger as sagemaker_core_logger +from sagemaker_core.helper.session_helper import Session, get_execution_role # noqa: F401 + +logger = sagemaker_core_logger diff --git a/src/sagemaker/modules/configs.py b/src/sagemaker/modules/configs.py new file mode 100644 index 0000000000..ec0df519f5 --- /dev/null +++ b/src/sagemaker/modules/configs.py @@ -0,0 +1,219 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module provides the configuration classes used in ``sagemaker.modules``. + +Some of these classes are re-exported from ``sagemaker_core.shapes``. 
For convenience,
+users can import these classes directly from ``sagemaker.modules.configs``.
+
+For more documentation on ``sagemaker_core.shapes``, see:
+    - https://sagemaker-core.readthedocs.io/en/stable/#sagemaker-core-shapes
+"""
+
+from __future__ import absolute_import
+
+from typing import Optional, Union
+from pydantic import BaseModel, model_validator
+
+import sagemaker_core.shapes as shapes
+
+# TODO: Can we add custom logic to some of these to set better defaults?
+from sagemaker_core.shapes import (
+    StoppingCondition,
+    RetryStrategy,
+    OutputDataConfig,
+    Channel,
+    ShuffleConfig,
+    DataSource,
+    S3DataSource,
+    FileSystemDataSource,
+    TrainingImageConfig,
+    TrainingRepositoryAuthConfig,
+    Tag,
+    InfraCheckConfig,
+    RemoteDebugConfig,
+    SessionChainingConfig,
+    InstanceGroup,
+    TensorBoardOutputConfig,
+    CheckpointConfig,
+)
+
+from sagemaker.modules.utils import convert_unassigned_to_none
+
+__all__ = [
+    "SourceCode",
+    "StoppingCondition",
+    "RetryStrategy",
+    "OutputDataConfig",
+    "Channel",
+    "ShuffleConfig",
+    "DataSource",
+    "S3DataSource",
+    "FileSystemDataSource",
+    "TrainingImageConfig",
+    "TrainingRepositoryAuthConfig",
+    "Tag",
+    "InfraCheckConfig",
+    "RemoteDebugConfig",
+    "SessionChainingConfig",
+    "InstanceGroup",
+    "TensorBoardOutputConfig",
+    "CheckpointConfig",
+    "Compute",
+    "Networking",
+    "InputData",
+]
+
+
+class SourceCode(BaseModel):
+    """SourceCode.
+
+    The SourceCode class allows the user to specify the source code location, dependencies,
+    entry script, or commands to be executed in the training job container.
+
+    Parameters:
+        source_dir (Optional[str]):
+            The local directory containing the source code to be used in the training job container.
+        requirements (Optional[str]):
+            The path within ``source_dir`` to a ``requirements.txt`` file. If specified, the listed
+            requirements will be installed in the training job container.
+        entry_script (Optional[str]):
+            The path within ``source_dir`` to the entry script that will be executed in the training
+            job container. If not specified, ``command`` must be provided.
+        command (Optional[str]):
+            The command(s) to execute in the training job container. Example: "python my_script.py".
+            If not specified, ``entry_script`` must be provided.
+    """
+
+    source_dir: Optional[str] = None
+    requirements: Optional[str] = None
+    entry_script: Optional[str] = None
+    command: Optional[str] = None
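+
+
+# Example (illustrative): a minimal SourceCode for a local "src/" directory whose
+# entry point is "train.py"; per the docstring above, at least one of
+# ``entry_script`` or ``command`` must be provided.
+#
+#   source_code = SourceCode(source_dir="src", entry_script="train.py")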
+
+
+class Compute(shapes.ResourceConfig):
+    """Compute.
+
+    The Compute class is a subclass of ``sagemaker_core.shapes.ResourceConfig``
+    and allows the user to specify the compute resources for the training job.
+
+    Parameters:
+        instance_type (Optional[str]):
+            The ML compute instance type. For information about available instance types,
+            see https://aws.amazon.com/sagemaker/pricing/.
+        instance_count (Optional[int]): The number of ML compute instances to use. For distributed
+            training, provide a value greater than 1.
+        volume_size_in_gb (Optional[int]):
+            The size of the ML storage volume that you want to provision. ML storage volumes store
+            model artifacts and incremental states. Training algorithms might also use the ML
+            storage volume for scratch space. Default: 30
+        volume_kms_key_id (Optional[str]):
+            The Amazon Web Services KMS key that SageMaker uses to encrypt data on the storage
+            volume attached to the ML compute instance(s) that run the training job.
+        keep_alive_period_in_seconds (Optional[int]):
+            The duration of time in seconds to retain configured resources in a warm pool for
+            subsequent training jobs.
+        instance_groups (Optional[List[InstanceGroup]]):
+            A list of instance groups for heterogeneous clusters to be used in the training job.
+        enable_managed_spot_training (Optional[bool]):
+            To train models using managed spot training, choose True. Managed spot training
+            provides a fully managed and scalable infrastructure for training machine learning
+            models. This option is useful when training jobs can be interrupted and when there
+            is flexibility in when the training job is run.
+    """
+
+    volume_size_in_gb: Optional[int] = 30
+    enable_managed_spot_training: Optional[bool] = None
+
+    @model_validator(mode="after")
+    def _model_validator(self) -> "Compute":
+        """Convert Unassigned values to None."""
+        return convert_unassigned_to_none(self)
+
+    def _to_resource_config(self) -> shapes.ResourceConfig:
+        """Convert to a sagemaker_core.shapes.ResourceConfig object."""
+        compute_config_dict = self.model_dump()
+        resource_config_fields = set(shapes.ResourceConfig.__annotations__.keys())
+        filtered_dict = {
+            k: v for k, v in compute_config_dict.items() if k in resource_config_fields
+        }
+        return shapes.ResourceConfig(**filtered_dict)
+
+
+class Networking(shapes.VpcConfig):
+    """Networking.
+
+    The Networking class is a subclass of ``sagemaker_core.shapes.VpcConfig`` and
+    allows the user to specify the networking configuration for the training job.
+
+    Parameters:
+        security_group_ids (Optional[List[str]]):
+            The VPC security group IDs, in the form sg-xxxxxxxx. Specify the
+            security groups for the VPC that is specified in the Subnets field.
+        subnets (Optional[List[str]]):
+            The ID of the subnets in the VPC to which you want to connect your
+            training job or model.
+        enable_network_isolation (Optional[bool]):
+            Isolates the training container. No inbound or outbound network calls can be made,
+            except for calls between peers within a training cluster for distributed training.
+            If you enable network isolation for training jobs that are configured to use a VPC,
+            SageMaker downloads and uploads customer data and model artifacts through the
+            specified VPC, but the training container does not have network access.
+        enable_inter_container_traffic_encryption (Optional[bool]):
+            To encrypt all communications between ML compute instances in distributed training,
+            choose True. Encryption provides greater security for distributed training, but
+            training might take longer. How long it takes depends on the amount of
+            communication between compute instances, especially if you use a deep learning
+            algorithm in distributed training.
+    """
+
+    enable_network_isolation: Optional[bool] = None
+    enable_inter_container_traffic_encryption: Optional[bool] = None
+
+    @model_validator(mode="after")
+    def _model_validator(self) -> "Networking":
+        """Convert Unassigned values to None."""
+        return convert_unassigned_to_none(self)
+
+    def _to_vpc_config(self) -> shapes.VpcConfig:
+        """Convert to a sagemaker_core.shapes.VpcConfig object."""
+        compute_config_dict = self.model_dump()
+        resource_config_fields = set(shapes.VpcConfig.__annotations__.keys())
+        filtered_dict = {
+            k: v for k, v in compute_config_dict.items() if k in resource_config_fields
+        }
+        return shapes.VpcConfig(**filtered_dict)
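+
+
+# Example (illustrative; the subnet and security group IDs are placeholders):
+#
+#   compute = Compute(instance_type="ml.m5.xlarge", instance_count=2)
+#   networking = Networking(
+#       subnets=["subnet-0123456789abcdef0"],
+#       security_group_ids=["sg-0123456789abcdef0"],
+#   )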
+
+
+class InputData(BaseModel):
+    """InputData.
+
+    This config allows the user to specify an input data source for the training job.
+
+    It will be found at ``/opt/ml/input/data/<channel_name>`` within the training container.
+    For convenience, it can be referenced inside the training container like:
+
+    .. code:: python
+
+        import os
+        input_data_dir = os.environ['SM_CHANNEL_<channel_name>']
+
+    Parameters:
+        channel_name (str):
+            The name of the input data source channel.
+        data_source (Union[str, S3DataSource, FileSystemDataSource]):
+            The data source for the channel. Can be an S3 URI string, local file path string,
+            S3DataSource object, or FileSystemDataSource object.
+    """
+
+    channel_name: str = None
+    data_source: Union[str, FileSystemDataSource, S3DataSource] = None
diff --git a/src/sagemaker/modules/constants.py b/src/sagemaker/modules/constants.py
new file mode 100644
index 0000000000..e64d85367d
--- /dev/null
+++ b/src/sagemaker/modules/constants.py
@@ -0,0 +1,37 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Constants module."""
+from __future__ import absolute_import
+import os
+
+DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge"
+
+SM_CODE = "code"
+SM_CODE_CONTAINER_PATH = "/opt/ml/input/data/code"
+
+SM_DRIVERS = "sm_drivers"
+SM_DRIVERS_CONTAINER_PATH = "/opt/ml/input/data/sm_drivers"
+SM_DRIVERS_LOCAL_PATH = os.path.join(
+    os.path.dirname(os.path.abspath(__file__)), "train/container_drivers"
+)
+
+SOURCE_CODE_JSON = "sourcecode.json"
+DISTRIBUTED_JSON = "distributed.json"
+TRAIN_SCRIPT = "sm_train.sh"
+
+DEFAULT_CONTAINER_ENTRYPOINT = ["/bin/bash"]
+DEFAULT_CONTAINER_ARGUMENTS = [
+    "-c",
+    f"chmod +x {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT} "
+    + f"&& {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT}",
+]
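+
+# For reference, the entrypoint and arguments above combine into the following
+# container command (illustrative expansion of the f-strings):
+#
+#   /bin/bash -c "chmod +x /opt/ml/input/data/sm_drivers/sm_train.sh \
+#       && /opt/ml/input/data/sm_drivers/sm_train.sh"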
diff --git a/src/sagemaker/modules/distributed.py b/src/sagemaker/modules/distributed.py
new file mode 100644
index 0000000000..6cdc136dcf
--- /dev/null
+++ b/src/sagemaker/modules/distributed.py
@@ -0,0 +1,124 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Distributed module."""
+from __future__ import absolute_import
+
+from typing import Optional, Dict, Any, List
+from pydantic import BaseModel, PrivateAttr
+from sagemaker.modules.utils import safe_serialize
+
+
+class SMP(BaseModel):
+    """SMP.
+
+    This class is used for configuring the SageMaker Model Parallelism v2 parameters.
+    For more information on the model parallelism parameters, see:
+    https://docs.aws.amazon.com/sagemaker/latest/dg/distributed-model-parallel-v2-reference.html#distributed-model-parallel-v2-reference-init-config
+
+    Parameters:
+        hybrid_shard_degree (Optional[int]):
+            Specifies a sharded parallelism degree for the model.
+        sm_activation_offloading (Optional[bool]):
+            Specifies whether to enable the SMP activation offloading implementation.
+        activation_loading_horizon (Optional[int]):
+            An integer specifying the activation offloading horizon type for FSDP. This is the
+            maximum number of checkpointed or offloaded layers whose inputs can be in the GPU
+            memory simultaneously.
+        fsdp_cache_flush_warnings (Optional[bool]):
+            Detects and warns if cache flushes happen in the PyTorch memory manager, because they
+            can degrade computational performance.
+        allow_empty_shards (Optional[bool]):
+            Whether to allow empty shards when sharding tensors if the tensor is not divisible. This is
+            an experimental fix for a crash during checkpointing in certain scenarios. Disabling this
+            falls back to the original PyTorch behavior.
+        tensor_parallel_degree (Optional[int]):
+            Specifies a tensor parallelism degree. The value must be between 1 and world_size.
+        context_parallel_degree (Optional[int]):
+            Specifies the context parallelism degree. The value must be between 1 and world_size,
+            and must be <= hybrid_shard_degree.
+        expert_parallel_degree (Optional[int]):
+            Specifies an expert parallelism degree. The value must be between 1 and world_size.
+        random_seed (Optional[int]):
+            A seed number for the random operations in distributed modules by SMP tensor
+            parallelism or expert parallelism.
+    """
+
+    hybrid_shard_degree: Optional[int] = None
+    sm_activation_offloading: Optional[bool] = None
+    activation_loading_horizon: Optional[int] = None
+    fsdp_cache_flush_warnings: Optional[bool] = None
+    allow_empty_shards: Optional[bool] = None
+    tensor_parallel_degree: Optional[int] = None
+    context_parallel_degree: Optional[int] = None
+    expert_parallel_degree: Optional[int] = None
+    random_seed: Optional[int] = None
+
+    def _to_mp_hyperparameters(self) -> Dict[str, Any]:
+        """Converts to the hyperparameters format for the SageMaker Model Parallelism v2."""
+        mp_parameters = self.model_dump(exclude_none=True)
+        hyperparameters = {
+            "mp_parameters": safe_serialize(mp_parameters),
+        }
+        return hyperparameters
+
+
+class DistributedConfig(BaseModel):
+    """Base class for distributed training configurations."""
+
+    _type: str = PrivateAttr()
+
+    def model_dump(self, *args, **kwargs):
+        """Dump the model to a dictionary."""
+        result = super().model_dump(*args, **kwargs)
+        result["_type"] = self._type
+        return result
+
+
+class Torchrun(DistributedConfig):
+    """Torchrun.
+
+    The Torchrun class configures a job that uses ``torchrun`` or
+    ``torch.distributed.launch`` in the backend to launch distributed training.
+
+    Parameters:
+        process_count_per_node (Optional[int]):
+            The number of processes to run on each node in the training job.
+            Will default to the number of GPUs available in the container.
+        smp (Optional[SMP]):
+            The SageMaker Model Parallelism v2 parameters.
+    """
+
+    _type: str = PrivateAttr(default="torchrun")
+
+    process_count_per_node: Optional[int] = None
+    smp: Optional["SMP"] = None
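+
+
+# Example (illustrative): an 8-process-per-node torchrun launch with SMP tensor
+# parallelism. How the config is attached to a job depends on the consumer of
+# these classes (e.g. the ModelTrainer's ``distributed`` parameter).
+#
+#   torchrun = Torchrun(
+#       process_count_per_node=8,
+#       smp=SMP(tensor_parallel_degree=2),
+#   )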
+ """ + + _type: str = PrivateAttr(default="mpi") + + process_count_per_node: Optional[int] = None + mpi_additional_options: Optional[List[str]] = None diff --git a/src/sagemaker/modules/local_core/local_container.py b/src/sagemaker/modules/local_core/local_container.py new file mode 100644 index 0000000000..5424f4f865 --- /dev/null +++ b/src/sagemaker/modules/local_core/local_container.py @@ -0,0 +1,590 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""LocalContainer class module.""" +from __future__ import absolute_import + +import base64 +import os +import re +import shutil +import subprocess +from tempfile import TemporaryDirectory +from typing import Any, Dict, List, Optional +from pydantic import BaseModel, ConfigDict + +from sagemaker.local.image import ( + _Volume, + _aws_credentials, + _check_output, + _pull_image, + _stream_output, + _write_json_file, +) +from sagemaker.local.utils import check_for_studio, recursive_copy +from sagemaker.model import DIR_PARAM_NAME +from sagemaker.modules import logger, Session +from sagemaker.modules.configs import Channel +from sagemaker.utils import ECR_URI_PATTERN, create_tar_file, _module_import_error, download_folder +from sagemaker_core.main.utils import Unassigned +from sagemaker_core.shapes import DataSource + +from six.moves.urllib.parse import urlparse + +STUDIO_HOST_NAME = "sagemaker-local" +DOCKER_COMPOSE_FILENAME = "docker-compose.yaml" +DOCKER_COMPOSE_HTTP_TIMEOUT_ENV = "COMPOSE_HTTP_TIMEOUT" +DOCKER_COMPOSE_HTTP_TIMEOUT = "120" + +REGION_ENV_NAME = "AWS_REGION" +TRAINING_JOB_NAME_ENV_NAME = "TRAINING_JOB_NAME" +S3_ENDPOINT_URL_ENV_NAME = "S3_ENDPOINT_URL" +S3_ENDPOINT_URL_ENV_NAME = "S3_ENDPOINT_URL" +SM_STUDIO_LOCAL_MODE = "SM_STUDIO_LOCAL_MODE" + + +class _LocalContainer(BaseModel): + """A local training job class for local mode model trainer. + + Attributes: + training_job_name (str): + The name of the training job. + instance_type (str): + The instance type. + instance_count (int): + The number of instances. + image (str): + The image name for training. + container_root (str): + The directory path for the local container root. + input_from_s3 (bool): + If the input is from s3. + is_studio (bool): + If the container is running on SageMaker studio instance. + hosts (Optional[List[str]]): + The list of host names. + input_data_config: Optional[List[Channel]] + The input data channels for the training job. + Takes a list of Channel objects or a dictionary of channel names to DataSourceType. + DataSourceType can be an S3 URI string, local file path string, + S3DataSource object, or FileSystemDataSource object. + environment (Optional[Dict[str, str]]): + The environment variables for the training job. + hyper_parameters (Optional[Dict[str, Any]]): + The hyperparameters for the training job. + sagemaker_session (Optional[Session]): + The SageMaker session. + For local mode training, SageMaker session will only be used when input is from S3 or + image needs to be pulled from ECR. 
+        container_entrypoint (Optional[List[str]]):
+            The command to be executed in the container.
+        container_arguments (Optional[List[str]]):
+            The arguments of the container commands.
+    """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
+
+    training_job_name: str
+    instance_type: str
+    instance_count: int
+    image: str
+    container_root: str
+    input_from_s3: Optional[bool] = False
+    is_studio: Optional[bool] = False
+    hosts: Optional[List[str]] = []
+    input_data_config: Optional[List[Channel]]
+    environment: Optional[Dict[str, str]]
+    hyper_parameters: Optional[Dict[str, str]]
+    sagemaker_session: Optional[Session] = None
+    container_entrypoint: Optional[List[str]]
+    container_arguments: Optional[List[str]]
+
+    def model_post_init(self, __context: Any):
+        """Post init method to perform custom validation and set default values."""
+        self.hosts = [f"algo-{i}" for i in range(1, self.instance_count + 1)]
+        if self.environment is None:
+            self.environment = {}
+        if self.hyper_parameters is None:
+            self.hyper_parameters = {}
+
+        for channel in self.input_data_config:
+            if channel.data_source and channel.data_source.s3_data_source != Unassigned():
+                self.input_from_s3 = True
+                data_distribution = channel.data_source.s3_data_source.s3_data_distribution_type
+                if self.sagemaker_session is None:
+                    # In local mode, only initiate a session when necessary
+                    self.sagemaker_session = Session()
+            elif (
+                channel.data_source and channel.data_source.file_system_data_source != Unassigned()
+            ):
+                self.input_from_s3 = False
+                data_distribution = channel.data_source.file_system_data_source.file_system_type
+            else:
+                raise ValueError(
+                    "Need channel.data_source to have s3_data_source or file_system_data_source"
+                )
+
+            supported_distributions = ["FullyReplicated", "EFS"]
+            if data_distribution and data_distribution not in supported_distributions:
+                raise RuntimeError(
+                    "Invalid Data Distribution: '{}'. Local mode currently supports FullyReplicated "
+                    "Distribution for S3 data source and EFS Distribution for local data source.".format(
+                        data_distribution,
+                    )
+                )
+        self.is_studio = check_for_studio()
+
+    def train(
+        self,
+        wait: bool,
+    ) -> str:
+        """Run a training job locally using docker-compose.
+
+        Args:
+            wait (bool):
+                Whether to wait for the training output before exiting.
+        """
+        # create output/data folder since sagemaker-containers 2.0 expects it
+        os.makedirs(os.path.join(self.container_root, "output", "data"), exist_ok=True)
+        # A shared directory for all the containers. It is only mounted if the training script is
+        # local.
+        os.makedirs(os.path.join(self.container_root, "shared"), exist_ok=True)
+
+        data_dir = os.path.join(self.container_root, "input", "data")
+        os.makedirs(data_dir, exist_ok=True)
+        volumes = self._prepare_training_volumes(
+            data_dir, self.input_data_config, self.hyper_parameters
+        )
+        # If local, the source directory needs to be updated to the mounted /opt/ml/code path
+        if DIR_PARAM_NAME in self.hyper_parameters:
+            src_dir = self.hyper_parameters[DIR_PARAM_NAME]
+            parsed_uri = urlparse(src_dir)
+            if parsed_uri.scheme == "file":
+                self.hyper_parameters[DIR_PARAM_NAME] = "/opt/ml/code"
+
+        for host in self.hosts:
+            # Create the configuration files
+            self._create_config_file_directories(host)
+            self._write_config_files(host, self.input_data_config, self.hyper_parameters)
+
+        self.environment[TRAINING_JOB_NAME_ENV_NAME] = self.training_job_name
+        if self.input_from_s3:
+            self.environment[S3_ENDPOINT_URL_ENV_NAME] = (
+                self.sagemaker_session.s3_resource.meta.client._endpoint.host
+            )
+
+        if self._ecr_login_if_needed():
+            _pull_image(self.image)
+
+        if self.sagemaker_session:
+            self.environment[REGION_ENV_NAME] = self.sagemaker_session.boto_region_name
+
+        compose_data = self._generate_compose_file(self.environment, volumes)
+        compose_command = self._generate_compose_command(wait)
+        process = subprocess.Popen(
+            compose_command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
+        )
+
+        try:
+            _stream_output(process)
+        finally:
+            artifacts = self.retrieve_artifacts(compose_data)
+
+        # Print our Job Complete line
+        logger.info("Local training job completed, output artifacts saved to %s", artifacts)
+        return artifacts
+
+    def retrieve_artifacts(
+        self,
+        compose_data: dict,
+    ):
+        """Get the model artifacts from all the container nodes.
+
+        Used after training completes to gather the data from all the
+        individual containers. Like the official SageMaker Training Service, it
+        will overwrite duplicate files if multiple containers have the same file
+        names.
+
+        Args:
+            compose_data (dict): Docker-Compose configuration in dictionary
+                format.
+
+        Returns: Local path to the collected model artifacts.
+ """ + # We need a directory to store the artfiacts from all the nodes + # and another one to contained the compressed final artifacts + artifacts = os.path.join(self.container_root, "artifacts") + compressed_artifacts = os.path.join(self.container_root, "compressed_artifacts") + os.makedirs(artifacts, exist_ok=True) + + model_artifacts = os.path.join(artifacts, "model") + output_artifacts = os.path.join(artifacts, "output") + + artifact_dirs = [model_artifacts, output_artifacts, compressed_artifacts] + for d in artifact_dirs: + os.makedirs(d, exist_ok=True) + + # Gather the artifacts from all nodes into artifacts/model and artifacts/output + for host in self.hosts: + volumes = compose_data["services"][str(host)]["volumes"] + volumes = [v[:-2] if v.endswith(":z") else v for v in volumes] + for volume in volumes: + if re.search(r"^[A-Za-z]:", volume): + unit, host_dir, container_dir = volume.split(":") + host_dir = unit + ":" + host_dir + else: + host_dir, container_dir = volume.split(":") + if container_dir == "/opt/ml/model": + recursive_copy(host_dir, model_artifacts) + elif container_dir == "/opt/ml/output": + recursive_copy(host_dir, output_artifacts) + + # Tar Artifacts -> model.tar.gz and output.tar.gz + model_files = [os.path.join(model_artifacts, name) for name in os.listdir(model_artifacts)] + output_files = [ + os.path.join(output_artifacts, name) for name in os.listdir(output_artifacts) + ] + create_tar_file(model_files, os.path.join(compressed_artifacts, "model.tar.gz")) + create_tar_file(output_files, os.path.join(compressed_artifacts, "output.tar.gz")) + + output_data = "file://%s" % compressed_artifacts + + return os.path.join(output_data, "model.tar.gz") + + def _create_config_file_directories(self, host: str): + """Creates the directories for the config files. + + Args: + host (str): The name of the current host. + """ + for d in ["input", "input/config", "output", "model"]: + os.makedirs(os.path.join(self.container_root, host, d), exist_ok=True) + + def _write_config_files( + self, + host: str, + input_data_config: Optional[List[Channel]], + hyper_parameters: Optional[Dict[str, str]], + ): + """Write the config files for the training containers. + + This method writes the hyper_parameters, resources and input data + configuration files. + + Returns: None + + Args: + host (str): The name of the current host. + input_data_config (List[Channel]): Training input channels to be used for + training. + hyper_parameters (Dict[str, str]): Hyperparameters for training. + """ + config_path = os.path.join(self.container_root, host, "input", "config") + # Only support single container now + resource_config = { + "current_host": host, + "hosts": self.hosts, + "network_interface_name": "ethwe", + "current_instance_type": self.instance_type, + } + + json_input_data_config = {} + for channel in input_data_config: + channel_name = channel.channel_name + json_input_data_config[channel_name] = {"TrainingInputMode": "File"} + if channel.content_type != Unassigned(): + json_input_data_config[channel_name]["ContentType"] = channel.content_type + + _write_json_file(os.path.join(config_path, "hyperparameters.json"), hyper_parameters) + _write_json_file(os.path.join(config_path, "resourceconfig.json"), resource_config) + _write_json_file(os.path.join(config_path, "inputdataconfig.json"), json_input_data_config) + + def _generate_compose_file(self, environment: Dict[str, str], volumes: List[str]) -> dict: + """Writes a config file describing a training/hosting environment. 
+
+        This method generates a docker compose configuration file with an
+        entry for each container that will be created (based on self.hosts). It
+        calls
+        :meth:`~sagemaker.local_session.SageMakerContainer._create_docker_host` to
+        generate the config for each individual container.
+
+        Args:
+            environment (Dict[str, str]): a dictionary with environment variables to be
+                passed on to the containers.
+            volumes (List[str]): a list of volumes that will be mapped to
+                the containers
+
+        Returns: (dict) A dictionary representation of the configuration that was written.
+        """
+
+        if os.environ.get(DOCKER_COMPOSE_HTTP_TIMEOUT_ENV) is None:
+            os.environ[DOCKER_COMPOSE_HTTP_TIMEOUT_ENV] = DOCKER_COMPOSE_HTTP_TIMEOUT
+
+        services = {
+            host: self._create_docker_host(host, environment, volumes) for host in self.hosts
+        }
+
+        if self.is_studio:
+            content = {
+                "services": services,
+            }
+        else:
+            content = {
+                "services": services,
+                "networks": {"sagemaker-local": {"name": "sagemaker-local"}},
+            }
+
+        docker_compose_path = os.path.join(self.container_root, DOCKER_COMPOSE_FILENAME)
+
+        try:
+            import yaml
+        except ImportError as e:
+            logger.error(_module_import_error("yaml", "Local mode", "local"))
+            raise e
+
+        yaml_content = yaml.dump(content, default_flow_style=False)
+        with open(docker_compose_path, "w") as f:
+            f.write(yaml_content)
+
+        return content
+
+    def _create_docker_host(
+        self,
+        host: str,
+        environment: Dict[str, str],
+        volumes: List[str],
+    ) -> Dict:
+        """Creates the docker host configuration.
+
+        Args:
+            host (str): The host address
+            environment (Dict[str, str]): a dictionary with environment variables to be
+                passed on to the containers.
+            volumes (List[str]): List of volumes that will be mapped to the containers
+        """
+        environment = ["{}={}".format(k, v) for k, v in environment.items()]
+        aws_creds = None
+        if self.sagemaker_session:
+            # In local mode, only get AWS credentials when necessary
+            aws_creds = _aws_credentials(self.sagemaker_session.boto_session)
+        if aws_creds is not None:
+            environment.extend(aws_creds)
+
+        if self.is_studio:
+            environment.extend([f"{SM_STUDIO_LOCAL_MODE}=True"])
+
+        # Add volumes for the input and output of each host
+        host_volumes = volumes.copy()
+        subdirs = ["output", "output/data", "input"]
+        for subdir in subdirs:
+            host_dir = os.path.join(self.container_root, host, subdir)
+            container_dir = "/opt/ml/{}".format(subdir)
+            volume = _Volume(host_dir, container_dir)
+            host_volumes.append(volume.map)
+
+        host_config = {
+            "image": self.image,
+            "volumes": host_volumes,
+            "environment": environment,
+        }
+
+        if self.container_entrypoint:
+            host_config["entrypoint"] = self.container_entrypoint
+        if self.container_arguments:
+            host_config["entrypoint"] = host_config["entrypoint"] + self.container_arguments
+
+        if self.is_studio:
+            host_config["network_mode"] = "sagemaker"
+        else:
+            host_config["networks"] = {"sagemaker-local": {"aliases": [host]}}
+
+        # for GPU support pass in nvidia as the runtime, this is equivalent
+        # to setting --runtime=nvidia in the docker commandline.
+        if self.instance_type == "local_gpu":
+            host_config["deploy"] = {
+                "resources": {
+                    "reservations": {"devices": [{"count": "all", "capabilities": ["gpu"]}]}
+                }
+            }
+
+        return host_config
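+
+    # For reference, a generated compose service entry resembles the following
+    # (illustrative; the image and paths are placeholders):
+    #
+    #   algo-1:
+    #     image: <training-image>
+    #     volumes: ["<host-dir>:/opt/ml/input", ...]
+    #     environment: ["TRAINING_JOB_NAME=...", ...]
+    #     networks:
+    #       sagemaker-local:
+    #         aliases: [algo-1]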
+ """ + _compose_cmd_prefix = self._get_compose_cmd_prefix() + + command = _compose_cmd_prefix + [ + "-f", + os.path.join(self.container_root, DOCKER_COMPOSE_FILENAME), + "up", + "--build", + "--abort-on-container-exit" if wait else "--detach", + ] + + logger.info("docker command: %s", " ".join(command)) + return command + + def _ecr_login_if_needed(self): + """Log into ECR, if needed. + + Only ECR images that not have been pulled locally need login. + """ + sagemaker_pattern = re.compile(ECR_URI_PATTERN) + sagemaker_match = sagemaker_pattern.match(self.image) + if not sagemaker_match: + return False + + # Do we already have the image locally? + if _check_output("docker images -q %s" % self.image).strip(): + return False + + if not self.sagemaker_session: + # In local mode only initiate session when neccessary + self.sagemaker_session = Session() + + ecr = self.sagemaker_session.boto_session.client("ecr") + auth = ecr.get_authorization_token(registryIds=[self.image.split(".")[0]]) + authorization_data = auth["authorizationData"][0] + + raw_token = base64.b64decode(authorization_data["authorizationToken"]) + token = raw_token.decode("utf-8").strip("AWS:") + ecr_url = auth["authorizationData"][0]["proxyEndpoint"] + + # Log in to ecr, but use communicate to not print creds to the console + cmd = f"docker login {ecr_url} -u AWS --password-stdin".split() + proc = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + ) + + proc.communicate(input=token.encode()) + + return True + + def _prepare_training_volumes( + self, + data_dir: str, + input_data_config: Optional[List[Channel]], + hyper_parameters: Optional[Dict[str, str]], + ) -> List[str]: + """Prepares the training volumes based on input and output data configs. + + Args: + data_dir (str): The directory of input data. + input_data_config (Optional[List[Channel]]): Training input channels to be used for + training. + hyper_parameters (Optional[Dict[str, str]]): Hyperparameters for training. + """ + volumes = [] + model_dir = os.path.join(self.container_root, "model") + volumes.append(_Volume(model_dir, "/opt/ml/model").map) + + # Mount the metadata directory if present. + # Only expected to be present on SM notebook instances. + # This is used by some DeepEngine libraries + metadata_dir = "/opt/ml/metadata" + if os.path.isdir(metadata_dir): + volumes.append(_Volume(metadata_dir, metadata_dir).map) + + # Set up the channels for the containers. For local data we will + # mount the local directory to the container. For S3 Data we will download the S3 data + # first. + for channel in input_data_config: + channel_name = channel.channel_name + channel_dir = os.path.join(data_dir, channel_name) + os.makedirs(channel_dir, exist_ok=True) + + data_source_local_path = self._get_data_source_local_path(channel.data_source) + volumes.append(_Volume(data_source_local_path, channel=channel_name).map) + + # If there is a training script directory and it is a local directory, + # mount it to the container. + if DIR_PARAM_NAME in hyper_parameters: + training_dir = hyper_parameters[DIR_PARAM_NAME] + parsed_uri = urlparse(training_dir) + if parsed_uri.scheme == "file": + host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path) + volumes.append(_Volume(host_dir, "/opt/ml/code").map) + shared_dir = os.path.join(self.container_root, "shared") + volumes.append(_Volume(shared_dir, "/opt/ml/shared").map) + + return volumes + + def _get_data_source_local_path(self, data_source: DataSource): + """Return a local data path of :class:`sagemaker.local.data.DataSource`. 
+    def _get_data_source_local_path(self, data_source: DataSource): + """Return a local data path of :class:`sagemaker.local.data.DataSource`. + + If the data source is from S3, the data will be downloaded to a temporary + local path. + If the data source is a local file, the absolute path will be returned. + + Args: + data_source (DataSource): a data source backed by a local file or S3 + + Returns: + str: The local path of the data. + """ + if data_source.s3_data_source != Unassigned(): + uri = data_source.s3_data_source.s3_uri + parsed_uri = urlparse(uri) + local_dir = TemporaryDirectory(prefix=os.path.join(self.container_root + "/")).name + download_folder(parsed_uri.netloc, parsed_uri.path, local_dir, self.sagemaker_session) + return local_dir + else: + return os.path.abspath(data_source.file_system_data_source.directory_path) + + def _get_compose_cmd_prefix(self) -> List[str]: + """Gets the Docker Compose command. + + The method first looks for the 'docker compose' v2 executable and, if that is + not found, falls back to the 'docker-compose' executable. + + Returns: + List[str]: Docker Compose executable split into list. + + Raises: + ImportError: If no Docker Compose executable was found. + """ + compose_cmd_prefix = [] + + output = None + try: + output = subprocess.check_output( + ["docker", "compose", "version"], + stderr=subprocess.DEVNULL, + encoding="UTF-8", + ) + except subprocess.CalledProcessError: + logger.info( + "'Docker Compose' is not installed. " + "Proceeding to check for 'docker-compose' CLI." + ) + + if output and "v2" in output.strip(): + logger.info("'Docker Compose' found using Docker CLI.") + compose_cmd_prefix.extend(["docker", "compose"]) + return compose_cmd_prefix + + if shutil.which("docker-compose") is not None: + logger.info("'Docker Compose' found using Docker Compose CLI.") + compose_cmd_prefix.extend(["docker-compose"]) + return compose_cmd_prefix + + raise ImportError( + "Docker Compose is not installed. " + "Local Mode features will not work without docker compose. " + "For more information on how to install 'docker compose', please see " + "https://docs.docker.com/compose/install/" + ) diff --git a/src/sagemaker/modules/templates.py b/src/sagemaker/modules/templates.py new file mode 100644 index 0000000000..fba60dda47 --- /dev/null +++ b/src/sagemaker/modules/templates.py @@ -0,0 +1,88 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Templates module.""" +from __future__ import absolute_import + +EXECUTE_BASE_COMMANDS = """ +CMD="{base_command}" +echo "Executing command: $CMD" +eval $CMD +""" + +EXECUTE_BASIC_SCRIPT_DRIVER = """ +echo "Running Basic Script driver" +$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/basic_script_driver.py +""" + +EXEUCTE_TORCHRUN_DRIVER = """ +echo "Running Torchrun driver" +$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/torchrun_driver.py +""" + +EXECUTE_MPI_DRIVER = """ +echo "Running MPI driver" +$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/mpi_driver.py +""" + +TRAIN_SCRIPT_TEMPLATE = """ +#!/bin/bash +set -e +echo "Starting training script" + +handle_error() {{ + EXIT_STATUS=$? + echo "An error occurred with exit code $EXIT_STATUS" + if [ !
-s /opt/ml/output/failure ]; then + echo "Training Execution failed. For more details, see CloudWatch logs at 'aws/sagemaker/TrainingJobs'. +TrainingJob - $TRAINING_JOB_NAME" >> /opt/ml/output/failure + fi + exit $EXIT_STATUS +}} + +check_python() {{ + SM_PYTHON_CMD=$(command -v python3 || command -v python) + SM_PIP_CMD=$(command -v pip3 || command -v pip) + + # Check if Python is found + if [[ -z "$SM_PYTHON_CMD" || -z "$SM_PIP_CMD" ]]; then + echo "Error: The Python executable was not found in the system path." + return 1 + fi + + return 0 +}} + +trap 'handle_error' ERR + +check_python + +$SM_PYTHON_CMD --version + +echo "/opt/ml/input/config/resourceconfig.json:" +cat /opt/ml/input/config/resourceconfig.json +echo + +echo "/opt/ml/input/config/inputdataconfig.json:" +cat /opt/ml/input/config/inputdataconfig.json +echo + +echo "Setting up environment variables" +$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/scripts/environment.py +source /opt/ml/input/sm_training.env + +{working_dir} +{install_requirements} +{execute_driver} + +echo "Training Container Execution Completed" +""" diff --git a/src/sagemaker/modules/train/__init__.py b/src/sagemaker/modules/train/__init__.py new file mode 100644 index 0000000000..51fa17fe04 --- /dev/null +++ b/src/sagemaker/modules/train/__init__.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Sagemaker modules train directory.""" +from __future__ import absolute_import + +from sagemaker.modules.train.model_trainer import ModelTrainer # noqa: F401 diff --git a/src/sagemaker/modules/train/container_drivers/__init__.py b/src/sagemaker/modules/train/container_drivers/__init__.py new file mode 100644 index 0000000000..18557a2eb5 --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Sagemaker modules container_drivers directory.""" +from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/basic_script_driver.py b/src/sagemaker/modules/train/container_drivers/basic_script_driver.py new file mode 100644 index 0000000000..cb0278bc9f --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/basic_script_driver.py @@ -0,0 +1,79 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module is the entry point for the Basic Script Driver.""" +from __future__ import absolute_import + +import sys +import shlex + +from typing import List + +from utils import ( + logger, + get_python_executable, + read_source_code_json, + read_hyperparameters_json, + execute_commands, + write_failure_file, + hyperparameters_to_cli_args, +) + + +def create_commands() -> List[str]: + """Create the commands to execute.""" + source_code = read_source_code_json() + hyperparameters = read_hyperparameters_json() + python_executable = get_python_executable() + + entry_script = source_code["entry_script"] + args = hyperparameters_to_cli_args(hyperparameters) + if entry_script.endswith(".py"): + commands = [python_executable, entry_script] + commands += args + elif entry_script.endswith(".sh"): + args_str = " ".join(shlex.quote(arg) for arg in args) + commands = [ + "/bin/sh", + "-c", + f"chmod +x {entry_script} && ./{entry_script} {args_str}", + ] + else: + raise ValueError( + f"Unsupported entry script type: {entry_script}. Only .py and .sh are supported." + ) + return commands + + +def main(): + """Main function for the Basic Script Driver. + + This function is the entry point for the Basic Script Driver. + + Execution Lifecycle: + 1. Read the source code and hyperparameters JSON files. + 2. Set hyperparameters as command line arguments. + 3. Create the commands to execute. + 4. Execute the commands. + """ + + cmd = create_commands() + + logger.info(f"Executing command: {' '.join(cmd)}") + exit_code, traceback = execute_commands(cmd) + if exit_code != 0: + write_failure_file(traceback) + sys.exit(exit_code) + + +if __name__ == "__main__": + main() diff --git a/src/sagemaker/modules/train/container_drivers/mpi_driver.py b/src/sagemaker/modules/train/container_drivers/mpi_driver.py new file mode 100644 index 0000000000..dceb748cc0 --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/mpi_driver.py @@ -0,0 +1,106 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module is the entry point for the MPI driver script.""" +from __future__ import absolute_import + +import os +import sys +import json + +from utils import ( + logger, + read_source_code_json, + read_distributed_json, + read_hyperparameters_json, + hyperparameters_to_cli_args, + get_process_count, + execute_commands, + write_failure_file, + USER_CODE_PATH, +) +from mpi_utils import ( + start_sshd_daemon, + bootstrap_master_node, + bootstrap_worker_node, + get_mpirun_command, + write_status_file_to_workers, + write_env_vars_to_file, +) + + +def main(): + """Main function for the MPI driver script. 
+ + The MPI Driver is responsible for setting up the MPI environment, + generating the correct MPI command, and launching the MPI job. + + Execution Lifecycle: + 1. Setup General Environment Variables at /etc/environment + 2. Start SSHD Daemon + 3. Bootstrap Worker Nodes + a. Wait to establish connection with Master Node + b. Wait for Master Node to write status file + 4. Bootstrap Master Node + a. Wait to establish connection with Worker Nodes + b. Generate MPI Command + c. Execute MPI Command with user script provided in `entry_script` + d. Write status file to Worker Nodes + 5. Exit + + """ + source_code = read_source_code_json() + distribution = read_distributed_json() + hyperparameters = read_hyperparameters_json() + + sm_current_host = os.environ["SM_CURRENT_HOST"] + sm_hosts = json.loads(os.environ["SM_HOSTS"]) + sm_master_addr = os.environ["SM_MASTER_ADDR"] + + write_env_vars_to_file() + start_sshd_daemon() + + # Compute the worker host list up front; it is needed on both branches and + # again when writing the status file after the job completes. + worker_hosts = [host for host in sm_hosts if host != sm_master_addr] + + if sm_current_host != sm_master_addr: + bootstrap_worker_node(sm_master_addr) + else: + bootstrap_master_node(worker_hosts) + + host_list = json.loads(os.environ["SM_HOSTS"]) + host_count = int(os.environ["SM_HOST_COUNT"]) + process_count = get_process_count(distribution) + + if process_count > 1: + host_list = ["{}:{}".format(host, process_count) for host in host_list] + + mpi_command = get_mpirun_command( + host_count=host_count, + host_list=host_list, + num_processes=process_count, + additional_options=distribution.get("mpi_additional_options", []), + entry_script_path=os.path.join(USER_CODE_PATH, source_code["entry_script"]), + ) + + args = hyperparameters_to_cli_args(hyperparameters) + mpi_command += args + + logger.info(f"Executing command: {' '.join(mpi_command)}") + exit_code, error_traceback = execute_commands(mpi_command) + write_status_file_to_workers(worker_hosts) + + if exit_code != 0: + write_failure_file(error_traceback) + sys.exit(exit_code) + + +if __name__ == "__main__": + main()
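To make the driver's inputs concrete, a hypothetical ``distributed.json`` payload might look like the sketch below; the key names come from the code above, while the values are invented:

.. code:: python

    # Hypothetical contents of /opt/ml/input/data/sm_drivers/distributed.json.
    # get_process_count() reads "process_count_per_node"; get_mpirun_command()
    # receives "mpi_additional_options" verbatim as extra mpirun flags.
    distribution = {
        "process_count_per_node": 8,
        "mpi_additional_options": ["-x", "NCCL_DEBUG=INFO"],
    }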
+"""This module provides mpi related utility functions for the container drivers.""" +from __future__ import absolute_import + +import os +import time +import subprocess + +from typing import List + +from utils import logger, SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable + +FINISHED_STATUS_FILE = "/tmp/done.algo-1" +READY_FILE = "/tmp/ready.%s" +DEFAULT_SSH_PORT = 22 + + +def _write_file_to_host(host: str, status_file: str) -> bool: + """Write the a file to the provided host.""" + try: + logger.info(f"Writing {status_file} to {host}") + subprocess.run( + ["ssh", host, "touch", f"{status_file}"], + capture_output=True, + text=True, + check=True, + ) + logger.info("Finished writing status file") + return True + except subprocess.CalledProcessError: + logger.info(f"Cannot connect to {host}") + return False + + +def write_status_file_to_workers(worker_hosts: List[str], status_file: str = FINISHED_STATUS_FILE): + """Write the status file to all worker nodes.""" + for worker in worker_hosts: + retry = 0 + while not _write_file_to_host(worker, status_file): + time.sleep(5) + retry += 1 + if retry > 5: + raise TimeoutError(f"Timed out waiting for {worker} to be reachable.") + logger.info(f"Retrying to write status file to {worker}") + + +def _wait_for_status_file(status_file: str): + """Wait for the status file to be created.""" + logger.info(f"Waiting for status file {status_file}") + while not os.path.exists(status_file): + time.sleep(30) + logger.info(f"Found status file {status_file}") + + +def start_sshd_daemon(): + """Start the SSH daemon on the current node.""" + sshd_executable = "/usr/sbin/sshd" + + if not os.path.exists(sshd_executable): + raise RuntimeError("SSH daemon not found.") + + # Start the sshd in daemon mode (-D) + subprocess.Popen([sshd_executable, "-D"]) + logger.info("Started SSH daemon.") + + +def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool: + """Check if the connection to the provided host and port is possible.""" + try: + import paramiko + + logger.debug("Testing connection to host %s", host) + client = paramiko.SSHClient() + client.load_system_host_keys() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + client.connect(host, port=port) + client.close() + logger.info("Can connect to host %s", host) + return True + except Exception as e: # pylint: disable=W0703 + logger.info("Cannot connect to host %s", host) + logger.debug(f"Connection failed with exception: {e}") + return False + + +def _wait_for_workers(worker_hosts: List[str], port: int = DEFAULT_SSH_PORT, timeout: int = 300): + """Master node waits until it can connect to all worker nodes.""" + start_time = time.time() + if not worker_hosts: + logger.info("No worker nodes to connect to.") + return + + while True: + logger.info("Master is attempting to connect to all workers...") + all_workers_connected = all( + _can_connect(worker, port) and os.path.exists(READY_FILE % worker) + for worker in worker_hosts + ) + + if all_workers_connected: + logger.info("Master can connect to all worker nodes.") + break + if time.time() - start_time > timeout: + raise TimeoutError("Timed out waiting for workers to be reachable.") + + time.sleep(5) # Wait for 5 seconds before trying again + + +def _wait_for_master(master_host: str, port: int = DEFAULT_SSH_PORT, timeout: int = 300): + """Worker nodes wait until they can connect to the master node.""" + start_time = time.time() + while True: + logger.info(f"Worker is attempting to connect to the master node {master_host}...") + if 
+def bootstrap_worker_node(master_host: str, status_file: str = FINISHED_STATUS_FILE): + """Bootstrap the worker nodes.""" + logger.info("Bootstrapping worker node...") + _wait_for_master(master_host) + _write_file_to_host(master_host, READY_FILE % os.environ["SM_CURRENT_HOST"]) + _wait_for_status_file(status_file) + + +def bootstrap_master_node(worker_hosts: List[str]): + """Bootstrap the master node.""" + logger.info("Bootstrapping master node...") + _wait_for_workers(worker_hosts) + + +def validate_smddprun() -> bool: + """Whether smddprun is installed. + + Returns: + bool: True if installed + """ + try: + output = subprocess.run( + ["which", "smddprun"], + capture_output=True, + text=True, + check=True, + ) + return output.stdout != "" + except subprocess.CalledProcessError: + return False + + +def validate_smddpmprun() -> bool: + """Whether smddpmprun is installed. + + Returns: + bool: True if installed + """ + try: + output = subprocess.run( + ["which", "smddpmprun"], + capture_output=True, + text=True, + check=True, + ) + return output.stdout != "" + except subprocess.CalledProcessError: + return False + + +def write_env_vars_to_file(): + """Write environment variables to /etc/environment file.""" + with open("/etc/environment", "a") as f: + for name in os.environ: + f.write("{}={}\n".format(name, os.environ.get(name))) + + +def get_mpirun_command( + host_count: int, + host_list: List[str], + num_processes: int, + additional_options: List[str], + entry_script_path: str, +): + """Build the mpirun command.""" + network_interface_name = os.environ.get("SM_NETWORK_INTERFACE_NAME", "eth0") + + mpirun_command = [ + "mpirun", + "--host", + ",".join(host_list), + "-np", + str(num_processes), + "--allow-run-as-root", + "--tag-output", + "-mca", + "btl_tcp_if_include", + network_interface_name, + "-mca", + "oob_tcp_if_include", + network_interface_name, + "-mca", + "plm_rsh_no_tree_spawn", + "1", + "-mca", + "pml", + "ob1", + "-mca", + "btl", + "^openib", + "-mca", + "orte_abort_on_non_zero_status", + "1", + "-mca", + "btl_vader_single_copy_mechanism", + "none", + "-mca", + "plm_rsh_num_concurrent", + str(host_count), + "-x", + "NCCL_SOCKET_IFNAME=%s" % network_interface_name, + "-x", + "LD_LIBRARY_PATH", + "-x", + "PATH", + ] + + if additional_options: + mpirun_command.extend(additional_options) + + instance_type = os.environ["SM_CURRENT_INSTANCE_TYPE"] + # EFA settings + if instance_type in SM_EFA_NCCL_INSTANCES: + mpirun_command.extend(["-x", "FI_PROVIDER=efa"]) + # Use simple protocol to handle the out-of-order data delivery from EFA + mpirun_command.extend(["-x", "NCCL_PROTO=simple"]) + + if instance_type in SM_EFA_RDMA_INSTANCES: + # Use EFA's RDMA functionality for one-sided and two-sided transfer + mpirun_command.extend(["-x", "FI_EFA_USE_DEVICE_RDMA=1"]) + + for credential in [ + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + ]: + if credential in os.environ: + mpirun_command.extend(["-x", credential]) + + mpirun_command.extend([get_python_executable()]) + mpirun_command.extend(["-m", "mpi4py", entry_script_path]) + return mpirun_command diff --git a/src/sagemaker/modules/train/container_drivers/scripts/__init__.py
b/src/sagemaker/modules/train/container_drivers/scripts/__init__.py new file mode 100644 index 0000000000..1abbce4067 --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/scripts/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Sagemaker modules scripts directory.""" +from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/scripts/environment.py b/src/sagemaker/modules/train/container_drivers/scripts/environment.py new file mode 100644 index 0000000000..ea6abac425 --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/scripts/environment.py @@ -0,0 +1,287 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module is used to define the environment variables for the training job container.""" +from __future__ import absolute_import + +from typing import Dict, Any +import multiprocessing +import subprocess +import json +import os +import sys +import logging + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) + +from utils import safe_serialize, safe_deserialize # noqa: E402 # pylint: disable=C0413 + +# Initialize logger +SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20) +logger = logging.getLogger(__name__) +console_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(console_handler) +logger.setLevel(int(SM_LOG_LEVEL)) + +SM_MODEL_DIR = "/opt/ml/model" + +SM_INPUT_DIR = "/opt/ml/input" +SM_INPUT_DATA_DIR = "/opt/ml/input/data" +SM_INPUT_CONFIG_DIR = "/opt/ml/input/config" + +SM_OUTPUT_DIR = "/opt/ml/output" +SM_OUTPUT_FAILURE = "/opt/ml/output/failure" +SM_OUTPUT_DATA_DIR = "/opt/ml/output/data" + +SM_MASTER_ADDR = "algo-1" +SM_MASTER_PORT = 7777 + +RESOURCE_CONFIG = f"{SM_INPUT_CONFIG_DIR}/resourceconfig.json" +INPUT_DATA_CONFIG = f"{SM_INPUT_CONFIG_DIR}/inputdataconfig.json" +HYPERPARAMETERS_CONFIG = f"{SM_INPUT_CONFIG_DIR}/hyperparameters.json" + +ENV_OUTPUT_FILE = "/opt/ml/input/sm_training.env" + +SENSITIVE_KEYWORDS = ["SECRET", "PASSWORD", "KEY", "TOKEN", "PRIVATE", "CREDS", "CREDENTIALS"] +HIDDEN_VALUE = "******" + + +def num_cpus() -> int: + """Return the number of CPUs available in the current container. + + Returns: + int: Number of CPUs available in the current container. + """ + return multiprocessing.cpu_count() + + +def num_gpus() -> int: + """Return the number of GPUs available in the current container. + + Returns: + int: Number of GPUs available in the current container. 
+ """ + try: + cmd = ["nvidia-smi", "--list-gpus"] + output = subprocess.check_output(cmd).decode("utf-8") + return sum(1 for line in output.splitlines() if line.startswith("GPU ")) + except (OSError, subprocess.CalledProcessError): + logger.info("No GPUs detected (normal if no gpus installed)") + return 0 + + +def num_neurons() -> int: + """Return the number of neuron cores available in the current container. + + Returns: + int: Number of Neuron Cores available in the current container. + """ + try: + cmd = ["neuron-ls", "-j"] + output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8") + j = json.loads(output) + neuron_cores = 0 + for item in j: + neuron_cores += item.get("nc_count", 0) + logger.info("Found %s neurons on this instance", neuron_cores) + return neuron_cores + except OSError: + logger.info("No Neurons detected (normal if no neurons installed)") + return 0 + except subprocess.CalledProcessError as e: + if e.output is not None: + try: + msg = e.output.decode("utf-8").partition("error=")[2] + logger.info( + "No Neurons detected (normal if no neurons installed). \ + If neuron installed then %s", + msg, + ) + except AttributeError: + logger.info("No Neurons detected (normal if no neurons installed)") + else: + logger.info("No Neurons detected (normal if no neurons installed)") + + return 0 + + +def deserialize_hyperparameters(hyperparameters: Dict[str, str]) -> Dict[str, Any]: + """Deserialize hyperparameters from string to their original types. + + Args: + hyperparameters (Dict[str, str]): Hyperparameters as strings. + + Returns: + Dict[str, Any]: Hyperparameters as their original types. + """ + deserialized_hyperparameters = {} + for key, value in hyperparameters.items(): + deserialized_hyperparameters[key] = safe_deserialize(value) + return deserialized_hyperparameters + + +def set_env( + resource_config: Dict[str, Any], + input_data_config: Dict[str, Any], + hyperparameters_config: Dict[str, Any], + output_file: str = ENV_OUTPUT_FILE, +): + """Set environment variables for the training job container. + + Args: + resource_config (Dict[str, Any]): Resource configuration for the training job. + input_data_config (Dict[str, Any]): Input data configuration for the training job. + hyperparameters_config (Dict[str, Any]): Hyperparameters configuration for the training job. + output_file (str): Output file to write the environment variables. 
+ """ + # Constants + env_vars = { + "SM_MODEL_DIR": SM_MODEL_DIR, + "SM_INPUT_DIR": SM_INPUT_DIR, + "SM_INPUT_DATA_DIR": SM_INPUT_DATA_DIR, + "SM_INPUT_CONFIG_DIR": SM_INPUT_CONFIG_DIR, + "SM_OUTPUT_DIR": SM_OUTPUT_DIR, + "SM_OUTPUT_FAILURE": SM_OUTPUT_FAILURE, + "SM_OUTPUT_DATA_DIR": SM_OUTPUT_DATA_DIR, + "SM_LOG_LEVEL": SM_LOG_LEVEL, + "SM_MASTER_ADDR": SM_MASTER_ADDR, + "SM_MASTER_PORT": SM_MASTER_PORT, + } + + # Data Channels + channels = list(input_data_config.keys()) + for channel in channels: + env_vars[f"SM_CHANNEL_{channel.upper()}"] = f"{SM_INPUT_DATA_DIR}/{channel}" + env_vars["SM_CHANNELS"] = channels + + # Hyperparameters + hps = deserialize_hyperparameters(hyperparameters_config) + for key, value in hps.items(): + key_upper = key.replace("-", "_").upper() + env_vars[f"SM_HP_{key_upper}"] = value + env_vars["SM_HPS"] = hps + + # Host Variables + current_host = resource_config["current_host"] + current_instance_type = resource_config["current_instance_type"] + hosts = resource_config["hosts"] + sorted_hosts = sorted(hosts) + + env_vars["SM_CURRENT_HOST"] = current_host + env_vars["SM_CURRENT_INSTANCE_TYPE"] = current_instance_type + env_vars["SM_HOSTS"] = sorted_hosts + env_vars["SM_NETWORK_INTERFACE_NAME"] = resource_config["network_interface_name"] + env_vars["SM_HOST_COUNT"] = len(sorted_hosts) + env_vars["SM_CURRENT_HOST_RANK"] = sorted_hosts.index(current_host) + + env_vars["SM_NUM_CPUS"] = num_cpus() + env_vars["SM_NUM_GPUS"] = num_gpus() + env_vars["SM_NUM_NEURONS"] = num_neurons() + + # Misc. + env_vars["SM_RESOURCE_CONFIG"] = resource_config + env_vars["SM_INPUT_DATA_CONFIG"] = input_data_config + + # All Training Environment Variables + env_vars["SM_TRAINING_ENV"] = { + "channel_input_dirs": { + channel: env_vars[f"SM_CHANNEL_{channel.upper()}"] for channel in channels + }, + "current_host": env_vars["SM_CURRENT_HOST"], + "current_instance_type": env_vars["SM_CURRENT_INSTANCE_TYPE"], + "hosts": env_vars["SM_HOSTS"], + "master_addr": env_vars["SM_MASTER_ADDR"], + "master_port": env_vars["SM_MASTER_PORT"], + "hyperparameters": env_vars["SM_HPS"], + "input_data_config": input_data_config, + "input_config_dir": env_vars["SM_INPUT_CONFIG_DIR"], + "input_data_dir": env_vars["SM_INPUT_DATA_DIR"], + "input_dir": env_vars["SM_INPUT_DIR"], + "job_name": os.environ["TRAINING_JOB_NAME"], + "log_level": env_vars["SM_LOG_LEVEL"], + "model_dir": env_vars["SM_MODEL_DIR"], + "network_interface_name": env_vars["SM_NETWORK_INTERFACE_NAME"], + "num_cpus": env_vars["SM_NUM_CPUS"], + "num_gpus": env_vars["SM_NUM_GPUS"], + "num_neurons": env_vars["SM_NUM_NEURONS"], + "output_data_dir": env_vars["SM_OUTPUT_DATA_DIR"], + "resource_config": env_vars["SM_RESOURCE_CONFIG"], + } + with open(output_file, "w") as f: + for key, value in env_vars.items(): + f.write(f"export {key}='{safe_serialize(value)}'\n") + + logger.info("Environment Variables:") + log_env_variables(env_vars_dict=env_vars) + + +def mask_sensitive_info(data): + """Recursively mask sensitive information in a dictionary.""" + if isinstance(data, dict): + for k, v in data.items(): + if isinstance(v, dict): + data[k] = mask_sensitive_info(v) + elif isinstance(v, str) and any( + keyword.lower() in k.lower() for keyword in SENSITIVE_KEYWORDS + ): + data[k] = HIDDEN_VALUE + return data + + +def log_key_value(key: str, value: str): + """Log a key-value pair, masking sensitive values if necessary.""" + if any(keyword.lower() in key.lower() for keyword in SENSITIVE_KEYWORDS): + logger.info("%s=%s", key, HIDDEN_VALUE) + elif 
isinstance(value, dict): + masked_value = mask_sensitive_info(value) + logger.info("%s=%s", key, json.dumps(masked_value)) + else: + try: + decoded_value = json.loads(value) + if isinstance(decoded_value, dict): + masked_value = mask_sensitive_info(decoded_value) + logger.info("%s=%s", key, json.dumps(masked_value)) + else: + logger.info("%s=%s", key, decoded_value) + except (json.JSONDecodeError, TypeError): + logger.info("%s=%s", key, value) + + +def log_env_variables(env_vars_dict: Dict[str, Any]): + """Log Environment Variables from the environment and an env_vars_dict.""" + for key, value in os.environ.items(): + log_key_value(key, value) + + for key, value in env_vars_dict.items(): + log_key_value(key, value) + + +def main(): + """Main function to set the environment variables for the training job container.""" + with open(RESOURCE_CONFIG, "r") as f: + resource_config = json.load(f) + with open(INPUT_DATA_CONFIG, "r") as f: + input_data_config = json.load(f) + with open(HYPERPARAMETERS_CONFIG, "r") as f: + hyperparameters_config = json.load(f) + + set_env( + resource_config=resource_config, + input_data_config=input_data_config, + hyperparameters_config=hyperparameters_config, + output_file=ENV_OUTPUT_FILE, + ) + + +if __name__ == "__main__": + main() diff --git a/src/sagemaker/modules/train/container_drivers/torchrun_driver.py b/src/sagemaker/modules/train/container_drivers/torchrun_driver.py new file mode 100644 index 0000000000..666479ec84 --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/torchrun_driver.py @@ -0,0 +1,128 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This module is the entry point for the Torchrun driver script.""" +from __future__ import absolute_import + +import os +import sys + +from typing import List, Tuple + +from utils import ( + logger, + read_source_code_json, + read_distributed_json, + read_hyperparameters_json, + hyperparameters_to_cli_args, + get_process_count, + get_python_executable, + execute_commands, + write_failure_file, + USER_CODE_PATH, + SM_EFA_NCCL_INSTANCES, + SM_EFA_RDMA_INSTANCES, +) + + +def pytorch_version() -> Tuple[int, int]: + """Get the PyTorch version as a tuple of integers.""" + import torch + + return tuple(map(int, torch.__version__.split(".")[:2])) + + +def get_base_pytorch_command() -> List[str]: + """Get the base Torch Distributed launcher to execute""" + if pytorch_version() >= (1, 9): + return ["torchrun"] + return [f"{get_python_executable()}", "-m", "torch.distributed.launch"] + + +def setup_env(): + """Setup the environment variables for PyTorch distributed training""" + instance_type = os.environ["SM_CURRENT_INSTANCE_TYPE"] + network_interface_name = os.environ.get("SM_NETWORK_INTERFACE_NAME", "eth0") + if instance_type in SM_EFA_NCCL_INSTANCES: + # Enable EFA use + os.environ["FI_PROVIDER"] = "efa" + if instance_type in SM_EFA_RDMA_INSTANCES: + # Use EFA's RDMA functionality for one-sided and two-sided transfer + os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1" + os.environ["RDMAV_FORK_SAFE"] = "1" + os.environ["NCCL_SOCKET_IFNAME"] = str(network_interface_name) + os.environ["NCCL_PROTO"] = "simple" + + +def create_commands(): + """Create the Torch Distributed command to execute""" + source_code = read_source_code_json() + distribution = read_distributed_json() + hyperparameters = read_hyperparameters_json() + + process_count = get_process_count(distribution) + host_count = int(os.environ["SM_HOST_COUNT"]) + + torch_cmd = [] + if os.environ.get("RUN_NEURON_PARALLEL_COMPILE") == "1": + torch_cmd.append("neuron_parallel_compile") + + torch_cmd.extend(get_base_pytorch_command()) + torch_cmd.extend( + [ + f"--nnodes={host_count}", + f"--nproc_per_node={process_count}", + ] + ) + + # If more than one node is used, add node rank information + if int(host_count) > 1: + torch_cmd.extend( + [ + f"--master_addr={os.environ['SM_MASTER_ADDR']}", + f"--master_port={os.environ['SM_MASTER_PORT']}", + f"--node_rank={os.environ['SM_CURRENT_HOST_RANK']}", + ] + ) + + torch_cmd.extend([os.path.join(USER_CODE_PATH, source_code["entry_script"])]) + + args = hyperparameters_to_cli_args(hyperparameters) + torch_cmd += args + + return torch_cmd + + +def main(): + """Main function to execute the PyTorch distributed training script. + + This function sets some environment variables and executes the PyTorch + distributed training script. + + Execution Lifecycle: + 1. Setup Environment Variables for PyTorch Distributed Training + 2. Create Torch Distributed Command + 3. Execute Torch Distributed Command with user script provided in `entry_script` + 4. 
Exit + + """ + setup_env() + torch_cmd = create_commands() + logger.info(f"Executing command: {' '.join(torch_cmd)}") + exit_code, traceback = execute_commands(torch_cmd) + if exit_code != 0: + write_failure_file(traceback) + sys.exit(exit_code) + + +if __name__ == "__main__": + main() diff --git a/src/sagemaker/modules/train/container_drivers/utils.py b/src/sagemaker/modules/train/container_drivers/utils.py new file mode 100644 index 0000000000..e939a6e0b8 --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/utils.py @@ -0,0 +1,213 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module provides utility functions for the container drivers.""" +from __future__ import absolute_import + +import os +import logging +import sys +import subprocess +import traceback +import json + +from typing import List, Dict, Any, Tuple, IO, Optional + +# Initialize logger +SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20) +logger = logging.getLogger(__name__) +console_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(console_handler) +logger.setLevel(int(SM_LOG_LEVEL)) + +FAILURE_FILE = "/opt/ml/output/failure" +DEFAULT_FAILURE_MESSAGE = """ +Training Execution failed. +For more details, see CloudWatch logs at 'aws/sagemaker/TrainingJobs'. 
+TrainingJob - {training_job_name} +""" + +USER_CODE_PATH = "/opt/ml/input/data/code" +SOURCE_CODE_JSON = "/opt/ml/input/data/sm_drivers/sourcecode.json" +DISTRIBUTED_JSON = "/opt/ml/input/data/sm_drivers/distributed.json" + +HYPERPARAMETERS_JSON = "/opt/ml/input/config/hyperparameters.json" + +SM_EFA_NCCL_INSTANCES = [ + "ml.g4dn.8xlarge", + "ml.g4dn.12xlarge", + "ml.g5.48xlarge", + "ml.p3dn.24xlarge", + "ml.p4d.24xlarge", + "ml.p4de.24xlarge", + "ml.p5.48xlarge", + "ml.trn1.32xlarge", +] + +SM_EFA_RDMA_INSTANCES = [ + "ml.p4d.24xlarge", + "ml.p4de.24xlarge", + "ml.trn1.32xlarge", +] + + +def write_failure_file(message: Optional[str] = None): + """Write a failure file with the message.""" + if message is None: + message = DEFAULT_FAILURE_MESSAGE.format(training_job_name=os.environ["TRAINING_JOB_NAME"]) + if not os.path.exists(FAILURE_FILE): + with open(FAILURE_FILE, "w") as f: + f.write(message) + + +def read_source_code_json(source_code_json: str = SOURCE_CODE_JSON): + """Read the source code config json file.""" + try: + with open(source_code_json, "r") as f: + source_code_dict = json.load(f) or {} + except FileNotFoundError: + source_code_dict = {} + return source_code_dict + + +def read_distributed_json(distributed_json: str = DISTRIBUTED_JSON): + """Read the distributed config json file.""" + try: + with open(distributed_json, "r") as f: + distributed_dict = json.load(f) or {} + except FileNotFoundError: + distributed_dict = {} + return distributed_dict + + +def read_hyperparameters_json(hyperparameters_json: str = HYPERPARAMETERS_JSON): + """Read the hyperparameters config json file.""" + try: + with open(hyperparameters_json, "r") as f: + hyperparameters_dict = json.load(f) or {} + except FileNotFoundError: + hyperparameters_dict = {} + return hyperparameters_dict + + +def get_process_count(distributed_dict: Dict[str, Any]) -> int: + """Get the number of processes to run on each node in the training job.""" + return ( + int(distributed_dict.get("process_count_per_node", 0)) + or int(os.environ.get("SM_NUM_GPUS", 0)) + or int(os.environ.get("SM_NUM_NEURONS", 0)) + or 1 + ) + + +def hyperparameters_to_cli_args(hyperparameters: Dict[str, Any]) -> List[str]: + """Convert the hyperparameters to CLI arguments.""" + cli_args = [] + for key, value in hyperparameters.items(): + value = safe_deserialize(value) + cli_args.extend([f"--{key}", safe_serialize(value)]) + + return cli_args + + +def safe_deserialize(data: Any) -> Any: + """Safely deserialize data from a JSON string. + + This function handles the following cases: + 1. If `data` is not a string, it returns the input as-is. + 2. If `data` is a string and matches common boolean values ("true" or "false"), + it returns the corresponding boolean value (True or False). + 3. If `data` is a JSON-encoded string, it attempts to deserialize it using `json.loads()`. + 4. If `data` is a string but cannot be decoded as JSON, it returns the original string. + + Returns: + Any: The deserialized data, or the original input if it cannot be JSON-decoded. + """ + if not isinstance(data, str): + return data + + lower_data = data.lower() + if lower_data in ["true"]: + return True + if lower_data in ["false"]: + return False + + try: + return json.loads(data) + except json.JSONDecodeError: + return data
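The two helpers above combine as follows; the hyperparameter values are invented for the example:

.. code:: python

    # Illustrative conversion performed by hyperparameters_to_cli_args:
    hyperparameters_to_cli_args({"epochs": "2", "dropout": "0.5", "tags": '["a", "b"]'})
    # -> ["--epochs", "2", "--dropout", "0.5", "--tags", '["a", "b"]']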
+def safe_serialize(data): + """Serialize the data without wrapping strings in quotes. + + This function handles the following cases: + 1. If `data` is a string, it returns the string as-is without wrapping in quotes. + 2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns + the JSON-encoded string using `json.dumps()`. + 3. If `data` cannot be serialized (e.g., a custom object), it returns the string + representation of the data using `str(data)`. + + Args: + data (Any): The data to serialize. + + Returns: + str: The serialized JSON-compatible string or the string representation of the input. + """ + if isinstance(data, str): + return data + try: + return json.dumps(data) + except TypeError: + return str(data) + + +def get_python_executable() -> str: + """Get the python executable path.""" + return sys.executable + + +def log_subprocess_output(pipe: IO[bytes]): + """Log the output from the subprocess.""" + for line in iter(pipe.readline, b""): + logger.info(line.decode("utf-8").strip()) + + +def execute_commands(commands: List[str]) -> Tuple[int, str]: + """Execute the provided commands and return exit code with failure traceback if any.""" + try: + process = subprocess.Popen( + commands, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + with process.stdout: + log_subprocess_output(process.stdout) + exitcode = process.wait() + if exitcode != 0: + raise subprocess.CalledProcessError(exitcode, commands) + return exitcode, "" + except subprocess.CalledProcessError as e: + # Capture the traceback in case of failure + error_traceback = traceback.format_exc() + logger.error(f"Command failed with exit code {e.returncode}. Traceback: {error_traceback}") + return e.returncode, error_traceback + + +def is_worker_node() -> bool: + """Check if the current node is a worker node.""" + return os.environ.get("SM_CURRENT_HOST") != os.environ.get("SM_MASTER_ADDR") + + +def is_master_node() -> bool: + """Check if the current node is the master node.""" + return os.environ.get("SM_CURRENT_HOST") == os.environ.get("SM_MASTER_ADDR") diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py new file mode 100644 index 0000000000..31decfaca9 --- /dev/null +++ b/src/sagemaker/modules/train/model_trainer.py @@ -0,0 +1,1029 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License.
+"""ModelTrainer class module.""" +from __future__ import absolute_import + +from enum import Enum +import os +import json +import shutil +from tempfile import TemporaryDirectory + +from typing import Optional, List, Union, Dict, Any, ClassVar + +from graphene.utils.str_converters import to_camel_case, to_snake_case + +from sagemaker_core.main import resources +from sagemaker_core.resources import TrainingJob +from sagemaker_core.shapes import AlgorithmSpecification + +from pydantic import BaseModel, ConfigDict, PrivateAttr, validate_call + +from sagemaker.config.config_schema import ( + _simple_path, + SAGEMAKER, + MODEL_TRAINER, + MODULES, + PYTHON_SDK, + TRAINING_JOB_ENVIRONMENT_PATH, + TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + TRAINING_JOB_VPC_CONFIG_PATH, + TRAINING_JOB_SUBNETS_PATH, + TRAINING_JOB_SECURITY_GROUP_IDS_PATH, + TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH, + TRAINING_JOB_RESOURCE_CONFIG_PATH, + TRAINING_JOB_ROLE_ARN_PATH, + TRAINING_JOB_TAGS_PATH, +) + +from sagemaker.utils import resolve_value_from_config +from sagemaker.modules import Session, get_execution_role +from sagemaker.modules.configs import ( + Compute, + StoppingCondition, + RetryStrategy, + OutputDataConfig, + SourceCode, + TrainingImageConfig, + Channel, + DataSource, + S3DataSource, + FileSystemDataSource, + Networking, + Tag, + InfraCheckConfig, + RemoteDebugConfig, + SessionChainingConfig, + TensorBoardOutputConfig, + CheckpointConfig, + InputData, +) + +from sagemaker.modules.local_core.local_container import _LocalContainer +from sagemaker.modules.distributed import Torchrun, MPI, DistributedConfig +from sagemaker.modules.utils import ( + _get_repo_name_from_image, + _get_unique_name, + _is_valid_path, + _is_valid_s3_uri, + safe_serialize, +) +from sagemaker.modules.types import DataSourceType +from sagemaker.modules.constants import ( + DEFAULT_INSTANCE_TYPE, + SM_CODE, + SM_CODE_CONTAINER_PATH, + SM_DRIVERS, + SM_DRIVERS_LOCAL_PATH, + TRAIN_SCRIPT, + DEFAULT_CONTAINER_ENTRYPOINT, + DEFAULT_CONTAINER_ARGUMENTS, + SOURCE_CODE_JSON, + DISTRIBUTED_JSON, +) +from sagemaker.modules.templates import ( + TRAIN_SCRIPT_TEMPLATE, + EXECUTE_BASE_COMMANDS, + EXECUTE_MPI_DRIVER, + EXEUCTE_TORCHRUN_DRIVER, + EXECUTE_BASIC_SCRIPT_DRIVER, +) +from sagemaker.telemetry.telemetry_logging import _telemetry_emitter +from sagemaker.telemetry.constants import Feature +from sagemaker.modules import logger +from sagemaker.modules.train.sm_recipes.utils import _get_args_from_recipe, _determine_device_type + + +class Mode(Enum): + """Enum class for training mode.""" + + LOCAL_CONTAINER = "LOCAL_CONTAINER" + SAGEMAKER_TRAINING_JOB = "SAGEMAKER_TRAINING_JOB" + + +class ModelTrainer(BaseModel): + """Class that trains a model using AWS SageMaker. + + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + from sagemaker.modules.configs import SourceCode, Compute, InputData + + source_code = SourceCode(source_dir="source", entry_script="train.py") + training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image" + model_trainer = ModelTrainer( + training_image=training_image, + source_code=source_code, + ) + + train_data = InputData(channel_name="train", data_source="s3://bucket/train") + model_trainer.train(input_data_config=[train_data]) + + training_job = model_trainer._latest_training_job + + Parameters: + training_mode (Mode): + The training mode. Valid values are "Mode.LOCAL_CONTAINER" or + "Mode.SAGEMAKER_TRAINING_JOB". 
+        sagemaker_session (Optional[Session]): + The SageMakerCore session. For convenience, it can be imported like: + ``from sagemaker.modules import Session``. + If not specified, a new session will be created. + If the default bucket for the artifacts needs to be updated, it can be done by + passing it in the Session object. + role (Optional[str]): + The IAM role ARN for the training job. + If not specified, the default SageMaker execution role will be used. + base_job_name (Optional[str]): + The base name for the training job. + If not specified, a default name will be generated using the algorithm name + or training image. + source_code (Optional[SourceCode]): + The source code configuration. This is used to configure the source code for + running the training job. + distributed (Optional[Union[MPI, Torchrun]]): + The distributed runner for the training job. This is used to configure + a distributed training job. If specified, ``source_code`` must also + be provided. + compute (Optional[Compute]): + The compute configuration. This is used to specify the compute resources for + the training job. If not specified, will default to 1 instance of ml.m5.xlarge. + networking (Optional[Networking]): + The networking configuration. This is used to specify the networking settings + for the training job. + stopping_condition (Optional[StoppingCondition]): + The stopping condition. This is used to specify the different stopping + conditions for the training job. + If not specified, will default to 1 hour max run time. + algorithm_name (Optional[str]): + The SageMaker marketplace algorithm name/ARN to use for the training job. + algorithm_name cannot be specified if training_image is specified. + training_image (Optional[str]): + The training image URI to use for the training job container. + training_image cannot be specified if algorithm_name is specified. + To find available SageMaker-distributed images, + see: https://docs.aws.amazon.com/sagemaker/latest/dg-ecr-paths/sagemaker-algo-docker-registry-paths + training_image_config (Optional[TrainingImageConfig]): + Training image Config. This is the configuration to use an image from a private + Docker registry for a training job. + output_data_config (Optional[OutputDataConfig]): + The output data configuration. This is used to specify the output data location + for the training job. + If not specified in the session, will default to + ``s3://<default_bucket>/<default_bucket_prefix>/<base_job_name>/``. + input_data_config (Optional[List[Union[Channel, InputData]]]): + The input data config for the training job. + Takes a list of Channel or InputData objects. An input data source can be an S3 URI + string, a local file path string, an S3DataSource object, or a FileSystemDataSource object. + checkpoint_config (Optional[CheckpointConfig]): + Contains information about the output location for managed spot training checkpoint + data. + training_input_mode (Optional[str]): + The input mode for the training job. Valid values are "Pipe", "File", "FastFile". + Defaults to "File". + environment (Optional[Dict[str, str]]): + The environment variables for the training job. + hyperparameters (Optional[Dict[str, Any]]): + The hyperparameters for the training job. + tags (Optional[List[Tag]]): + An array of key-value pairs. You can use tags to categorize your AWS resources + in different ways, for example, by purpose, owner, or environment. + local_container_root (Optional[str]): + The local root directory to store artifacts from a training job launched in + "LOCAL_CONTAINER" mode.
+ """ + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + + training_mode: Mode = Mode.SAGEMAKER_TRAINING_JOB + sagemaker_session: Optional[Session] = None + role: Optional[str] = None + base_job_name: Optional[str] = None + source_code: Optional[SourceCode] = None + distributed: Optional[Union[MPI, Torchrun]] = None + compute: Optional[Compute] = None + networking: Optional[Networking] = None + stopping_condition: Optional[StoppingCondition] = None + training_image: Optional[str] = None + training_image_config: Optional[TrainingImageConfig] = None + algorithm_name: Optional[str] = None + output_data_config: Optional[OutputDataConfig] = None + input_data_config: Optional[List[Union[Channel, InputData]]] = None + checkpoint_config: Optional[CheckpointConfig] = None + training_input_mode: Optional[str] = "File" + environment: Optional[Dict[str, str]] = {} + hyperparameters: Optional[Dict[str, Any]] = {} + tags: Optional[List[Tag]] = None + local_container_root: Optional[str] = os.getcwd() + + # Created Artifacts + _latest_training_job: Optional[resources.TrainingJob] = PrivateAttr(default=None) + + # Private TrainingJob Parameters + _tensorboard_output_config: Optional[TensorBoardOutputConfig] = PrivateAttr(default=None) + _retry_strategy: Optional[RetryStrategy] = PrivateAttr(default=None) + _infra_check_config: Optional[InfraCheckConfig] = PrivateAttr(default=None) + _session_chaining_config: Optional[SessionChainingConfig] = PrivateAttr(default=None) + _remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None) + + _temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None) + + CONFIGURABLE_ATTRIBUTES: ClassVar[List[str]] = [ + "role", + "base_job_name", + "source_code", + "compute", + "networking", + "stopping_condition", + "training_image", + "training_image_config", + "algorithm_name", + "output_data_config", + "checkpoint_config", + "training_input_mode", + "environment", + "hyperparameters", + ] + + SERIALIZABLE_CONFIG_ATTRIBUTES: ClassVar[Any] = { + "source_code": SourceCode, + "compute": Compute, + "networking": Networking, + "stopping_condition": StoppingCondition, + "training_image_config": TrainingImageConfig, + "output_data_config": OutputDataConfig, + "checkpoint_config": CheckpointConfig, + } + + def _populate_intelligent_defaults(self): + """Function to populate all the possible default configs + + Model Trainer specific configs take precedence over the generic training job ones. 
+ """ + self._populate_intelligent_defaults_from_model_trainer_space() + self._populate_intelligent_defaults_from_training_job_space() + + def _populate_intelligent_defaults_from_training_job_space(self): + """Function to populate all the possible default configs from Training Job Space""" + if not self.environment: + self.environment = resolve_value_from_config( + config_path=TRAINING_JOB_ENVIRONMENT_PATH, sagemaker_session=self.sagemaker_session + ) + + default_enable_network_isolation = resolve_value_from_config( + config_path=TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + sagemaker_session=self.sagemaker_session, + ) + default_vpc_config = resolve_value_from_config( + config_path=TRAINING_JOB_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session + ) + + if not self.networking: + if default_enable_network_isolation is not None or default_vpc_config is not None: + self.networking = Networking( + default_enable_network_isolation=default_enable_network_isolation, + subnets=resolve_value_from_config(config_path=TRAINING_JOB_SUBNETS_PATH), + security_group_ids=resolve_value_from_config( + config_path=TRAINING_JOB_SECURITY_GROUP_IDS_PATH + ), + ) + else: + if self.networking.enable_network_isolation is None: + self.networking.enable_network_isolation = default_enable_network_isolation + if self.networking.subnets is None: + self.networking.subnets = resolve_value_from_config( + config_path=TRAINING_JOB_SUBNETS_PATH + ) + if self.networking.security_group_ids is None: + self.networking.subnets = resolve_value_from_config( + config_path=TRAINING_JOB_SUBNETS_PATH + ) + + if not self.output_data_config: + default_output_data_config = resolve_value_from_config( + config_path=TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH + ) + if default_output_data_config: + self.output_data_config = OutputDataConfig( + **self._convert_keys_to_snake(default_output_data_config) + ) + + if not self.compute: + default_resource_config = resolve_value_from_config( + config_path=TRAINING_JOB_RESOURCE_CONFIG_PATH + ) + if default_resource_config: + self.compute = Compute(**self._convert_keys_to_snake(default_resource_config)) + + if not self.role: + self.role = resolve_value_from_config(config_path=TRAINING_JOB_ROLE_ARN_PATH) + + if not self.tags: + self.tags = resolve_value_from_config(config_path=TRAINING_JOB_TAGS_PATH) + + def _convert_keys_to_snake(self, config: dict) -> dict: + """Utility helper function that converts the keys of a dictionary into snake case""" + return {to_snake_case(key): value for key, value in config.items()} + + def _populate_intelligent_defaults_from_model_trainer_space(self): + """Function to populate all the possible default configs from Model Trainer Space""" + + for configurable_attribute in self.CONFIGURABLE_ATTRIBUTES: + if getattr(self, configurable_attribute) is None: + default_config = resolve_value_from_config( + config_path=_simple_path( + SAGEMAKER, + PYTHON_SDK, + MODULES, + MODEL_TRAINER, + to_camel_case(configurable_attribute), + ), + sagemaker_session=self.sagemaker_session, + ) + if default_config is not None: + if configurable_attribute in self.SERIALIZABLE_CONFIG_ATTRIBUTES: + default_config = self.SERIALIZABLE_CONFIG_ATTRIBUTES.get( + configurable_attribute + )( + **default_config # pylint: disable=E1134 + ) + setattr(self, configurable_attribute, default_config) + + def __del__(self): + """Destructor method to clean up the temporary directory.""" + # Clean up the temporary directory if it exists + if self._temp_recipe_train_dir is not None: + 
+    def _validate_training_image_and_algorithm_name( + self, training_image: Optional[str], algorithm_name: Optional[str] + ): + """Validate that only one of 'training_image' or 'algorithm_name' is provided.""" + if not training_image and not algorithm_name: + raise ValueError( + "At least one of 'training_image' or 'algorithm_name' must be provided.", + ) + if training_image and algorithm_name: + raise ValueError( + "Only one of 'training_image' or 'algorithm_name' must be provided.", + ) + + def _validate_distributed_config( + self, + source_code: Optional[SourceCode], + distributed: Optional[DistributedConfig], + ): + """Validate the distributed configuration.""" + if distributed and not (source_code and source_code.entry_script): + raise ValueError( + "Must provide 'entry_script' in 'source_code' if 'distributed' is provided.", + ) + + # TODO: Move to use pydantic model validators + def _validate_source_code(self, source_code: Optional[SourceCode]): + """Validate the source code configuration.""" + if source_code: + if source_code.requirements or source_code.entry_script: + source_dir = source_code.source_dir + requirements = source_code.requirements + entry_script = source_code.entry_script + if not source_dir: + raise ValueError( + "If 'requirements' or 'entry_script' is provided in 'source_code', " + + "'source_dir' must also be provided.", + ) + if not _is_valid_path(source_dir, path_type="Directory"): + raise ValueError( + f"Invalid 'source_dir' path: {source_dir}. " + "Must be a valid directory.", + ) + if requirements: + if not _is_valid_path( + f"{source_dir}/{requirements}", + path_type="File", + ): + raise ValueError( + f"Invalid 'requirements': {requirements}. " + + "Must be a valid file within the 'source_dir'.", + ) + if entry_script: + if not _is_valid_path( + f"{source_dir}/{entry_script}", + path_type="File", + ): + raise ValueError( + f"Invalid 'entry_script': {entry_script}. " + + "Must be a valid file within the 'source_dir'.", + ) + + def model_post_init(self, __context: Any): + """Post init method to perform custom validation and set default values.""" + self._validate_training_image_and_algorithm_name(self.training_image, self.algorithm_name) + self._validate_source_code(self.source_code) + self._validate_distributed_config(self.source_code, self.distributed) + + if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB: + if self.sagemaker_session is None: + self.sagemaker_session = Session() + logger.warning("SageMaker session not provided. Using default Session.") + + if self.role is None: + self.role = get_execution_role(sagemaker_session=self.sagemaker_session) + logger.warning(f"Role not provided. Using default role:\n{self.role}") + + if self.base_job_name is None: + if self.algorithm_name: + self.base_job_name = f"{self.algorithm_name}-job" + elif self.training_image: + self.base_job_name = f"{_get_repo_name_from_image(self.training_image)}-job" + logger.warning(f"Base name not provided. Using default name:\n{self.base_job_name}") + + if self.compute is None: + self.compute = Compute( + instance_type=DEFAULT_INSTANCE_TYPE, + instance_count=1, + volume_size_in_gb=30, + ) + logger.warning(f"Compute not provided. Using default:\n{self.compute}") + + if self.stopping_condition is None: + self.stopping_condition = StoppingCondition( + max_runtime_in_seconds=3600, + max_pending_time_in_seconds=None, + max_wait_time_in_seconds=None, + ) + logger.warning( + f"StoppingCondition not provided.
Using default:\n{self.stopping_condition}" + ) + + if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None: + session = self.sagemaker_session + base_job_name = self.base_job_name + self.output_data_config = OutputDataConfig( + s3_output_path=f"s3://{self._fetch_bucket_name_and_prefix(session)}" + f"/{base_job_name}", + compression_type="GZIP", + kms_key_id=None, + ) + logger.warning( + f"OutputDataConfig not provided. Using default:\n{self.output_data_config}" + ) + + # TODO: Autodetect which image to use if source_code is provided + if self.training_image: + logger.info(f"Training image URI: {self.training_image}") + + def _fetch_bucket_name_and_prefix(self, session: Session) -> str: + """Helper function to get the bucket name with the corresponding prefix if applicable""" + if session.default_bucket_prefix is not None: + return f"{session.default_bucket()}/{session.default_bucket_prefix}" + return session.default_bucket() + + @_telemetry_emitter(feature=Feature.MODEL_TRAINER, func_name="model_trainer.train") + @validate_call + def train( + self, + input_data_config: Optional[List[Union[Channel, InputData]]] = None, + wait: Optional[bool] = True, + logs: Optional[bool] = True, + ): + """Train a model using AWS SageMaker. + + Args: + input_data_config (Optional[Union[List[Channel], Dict[str, DataSourceType]]]): + The input data config for the training job. + Takes a list of Channel objects or a dictionary of channel names to DataSourceType. + DataSourceType can be an S3 URI string, local file path string, + S3DataSource object, or FileSystemDataSource object. + wait (Optional[bool]): + Whether to wait for the training job to complete before returning. + Defaults to True. + logs (Optional[bool]): + Whether to display the training container logs while training. + Defaults to True. 
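+
+        Example:
+            A minimal usage sketch; the image URI, paths, and bucket below are
+            illustrative placeholders, and the import locations assume this
+            module's public layout:
+
+            .. code-block:: python
+
+                from sagemaker.modules.train import ModelTrainer
+                from sagemaker.modules.configs import SourceCode, InputData
+
+                trainer = ModelTrainer(
+                    training_image="<training-image-uri>",
+                    source_code=SourceCode(source_dir="./src", entry_script="train.py"),
+                )
+                trainer.train(
+                    input_data_config=[
+                        InputData(channel_name="train", data_source="s3://my-bucket/train/")
+                    ],
+                    wait=True,
+                )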
+ """ + self._populate_intelligent_defaults() + current_training_job_name = _get_unique_name(self.base_job_name) + input_data_key_prefix = f"{self.base_job_name}/{current_training_job_name}/input" + if input_data_config: + self.input_data_config = input_data_config + + input_data_config = [] + if self.input_data_config: + input_data_config = self._get_input_data_config( + self.input_data_config, input_data_key_prefix + ) + + string_hyper_parameters = {} + if self.hyperparameters: + for hyper_parameter, value in self.hyperparameters.items(): + string_hyper_parameters[hyper_parameter] = safe_serialize(value) + + container_entrypoint = None + container_arguments = None + if self.source_code: + if self.training_mode == Mode.LOCAL_CONTAINER: + drivers_dir = TemporaryDirectory( + prefix=os.path.join(self.local_container_root + "/") + ) + else: + drivers_dir = TemporaryDirectory() + shutil.copytree(SM_DRIVERS_LOCAL_PATH, drivers_dir.name, dirs_exist_ok=True) + + # If source code is provided, create a channel for the source code + # The source code will be mounted at /opt/ml/input/data/code in the container + if self.source_code.source_dir: + source_code_channel = self.create_input_data_channel( + channel_name=SM_CODE, + data_source=self.source_code.source_dir, + key_prefix=input_data_key_prefix, + ) + input_data_config.append(source_code_channel) + + self._prepare_train_script( + tmp_dir=drivers_dir, + source_code=self.source_code, + distributed=self.distributed, + ) + + if isinstance(self.distributed, Torchrun) and self.distributed.smp: + mp_parameters = self.distributed.smp._to_mp_hyperparameters() + string_hyper_parameters.update(mp_parameters) + + self._write_source_code_json(tmp_dir=drivers_dir, source_code=self.source_code) + self._write_distributed_json(tmp_dir=drivers_dir, distributed=self.distributed) + + # Create an input channel for drivers packaged by the sdk + sm_drivers_channel = self.create_input_data_channel( + channel_name=SM_DRIVERS, + data_source=drivers_dir.name, + key_prefix=input_data_key_prefix, + ) + input_data_config.append(sm_drivers_channel) + + # If source_code is provided, we will always use + # the default container entrypoint and arguments + # to execute the sm_train.sh script. + # Any commands generated from the source_code will be + # executed from the sm_train.sh script. 
+            container_entrypoint = DEFAULT_CONTAINER_ENTRYPOINT
+            container_arguments = DEFAULT_CONTAINER_ARGUMENTS
+
+        algorithm_specification = AlgorithmSpecification(
+            algorithm_name=self.algorithm_name,
+            training_image=self.training_image,
+            training_input_mode=self.training_input_mode,
+            training_image_config=self.training_image_config,
+            container_entrypoint=container_entrypoint,
+            container_arguments=container_arguments,
+        )
+
+        resource_config = self.compute._to_resource_config()
+        vpc_config = self.networking._to_vpc_config() if self.networking else None
+
+        if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB:
+            training_job = TrainingJob.create(
+                training_job_name=current_training_job_name,
+                algorithm_specification=algorithm_specification,
+                hyper_parameters=string_hyper_parameters,
+                input_data_config=input_data_config,
+                resource_config=resource_config,
+                vpc_config=vpc_config,
+                # Public Instance Attributes
+                session=self.sagemaker_session.boto_session,
+                role_arn=self.role,
+                tags=self.tags,
+                stopping_condition=self.stopping_condition,
+                output_data_config=self.output_data_config,
+                checkpoint_config=self.checkpoint_config,
+                environment=self.environment,
+                enable_managed_spot_training=self.compute.enable_managed_spot_training,
+                enable_inter_container_traffic_encryption=(
+                    self.networking.enable_inter_container_traffic_encryption
+                    if self.networking
+                    else None
+                ),
+                enable_network_isolation=(
+                    self.networking.enable_network_isolation if self.networking else None
+                ),
+                # Private Instance Attributes
+                remote_debug_config=self._remote_debug_config,
+                tensor_board_output_config=self._tensorboard_output_config,
+                retry_strategy=self._retry_strategy,
+                infra_check_config=self._infra_check_config,
+                session_chaining_config=self._session_chaining_config,
+            )
+            self._latest_training_job = training_job
+
+            if wait:
+                training_job.wait(logs=logs)
+            if logs and not wait:
+                logger.warning(
+                    "Not displaying the training container logs as 'wait' is set to False."
+                )
+        else:
+            local_container = _LocalContainer(
+                training_job_name=_get_unique_name(self.base_job_name),
+                instance_type=resource_config.instance_type,
+                instance_count=resource_config.instance_count,
+                image=algorithm_specification.training_image,
+                container_root=self.local_container_root,
+                sagemaker_session=self.sagemaker_session,
+                container_entrypoint=algorithm_specification.container_entrypoint,
+                container_arguments=algorithm_specification.container_arguments,
+                input_data_config=input_data_config,
+                hyper_parameters=string_hyper_parameters,
+                environment=self.environment,
+            )
+            local_container.train(wait)
+
+    def create_input_data_channel(
+        self, channel_name: str, data_source: DataSourceType, key_prefix: Optional[str] = None
+    ) -> Channel:
+        """Create an input data channel for the training job.
+
+        Args:
+            channel_name (str): The name of the input data channel.
+            data_source (DataSourceType): The data source for the input data channel.
+                DataSourceType can be an S3 URI string, local file path string,
+                S3DataSource object, or FileSystemDataSource object.
+            key_prefix (Optional[str]): The key prefix to use when uploading data to S3.
+                Only applicable when data_source is a local file path string.
+                If not specified, local data will be uploaded to:
+                ``s3://<default_bucket_path>/<base_job_name>/input/<channel_name>/``
+
+                If specified, local data will be uploaded to:
+                ``s3://<default_bucket_path>/<key_prefix>/<channel_name>/``
+        """
+        channel = None
+        if isinstance(data_source, str):
+            if _is_valid_s3_uri(data_source):
+                channel = Channel(
+                    channel_name=channel_name,
+                    data_source=DataSource(
+                        s3_data_source=S3DataSource(
+                            s3_data_type="S3Prefix",
+                            s3_uri=data_source,
+                            s3_data_distribution_type="FullyReplicated",
+                        ),
+                    ),
+                    input_mode="File",
+                )
+                if key_prefix:
+                    logger.warning(
+                        "key_prefix is only applicable when data_source is a local file path."
+                    )
+            elif _is_valid_path(data_source):
+                if self.training_mode == Mode.LOCAL_CONTAINER:
+                    channel = Channel(
+                        channel_name=channel_name,
+                        data_source=DataSource(
+                            file_system_data_source=FileSystemDataSource.model_construct(
+                                directory_path=data_source,
+                                file_system_type="EFS",
+                            ),
+                        ),
+                        input_mode="File",
+                    )
+                else:
+                    key_prefix = (
+                        f"{key_prefix}/{channel_name}"
+                        if key_prefix
+                        else f"{self.base_job_name}/input/{channel_name}"
+                    )
+                    if self.sagemaker_session.default_bucket_prefix:
+                        key_prefix = f"{self.sagemaker_session.default_bucket_prefix}/{key_prefix}"
+                    s3_uri = self.sagemaker_session.upload_data(
+                        path=data_source,
+                        bucket=self.sagemaker_session.default_bucket(),
+                        key_prefix=key_prefix,
+                    )
+                    channel = Channel(
+                        channel_name=channel_name,
+                        data_source=DataSource(
+                            s3_data_source=S3DataSource(
+                                s3_data_type="S3Prefix",
+                                s3_uri=s3_uri,
+                                s3_data_distribution_type="FullyReplicated",
+                            ),
+                        ),
+                        input_mode="File",
+                    )
+            else:
+                raise ValueError(f"Not a valid S3 URI or local file path: {data_source}.")
+        elif isinstance(data_source, S3DataSource):
+            channel = Channel(
+                channel_name=channel_name, data_source=DataSource(s3_data_source=data_source)
+            )
+        elif isinstance(data_source, FileSystemDataSource):
+            channel = Channel(
+                channel_name=channel_name,
+                data_source=DataSource(file_system_data_source=data_source),
+            )
+        return channel
+
+    def _get_input_data_config(
+        self,
+        input_data_channels: Optional[List[Union[Channel, InputData]]],
+        key_prefix: Optional[str] = None,
+    ) -> List[Channel]:
+        """Get the input data configuration for the training job.
+
+        Args:
+            input_data_channels (Optional[List[Union[Channel, InputData]]]):
+                The input data config for the training job.
+                Takes a list of Channel or InputData objects. An InputData data source
+                can be an S3 URI string, local file path string, S3DataSource object,
+                or FileSystemDataSource object.
+        """
+        if input_data_channels is None:
+            return []
+
+        channels = []
+        for input_data in input_data_channels:
+            if isinstance(input_data, Channel):
+                channels.append(input_data)
+            elif isinstance(input_data, InputData):
+                channel = self.create_input_data_channel(
+                    input_data.channel_name, input_data.data_source, key_prefix=key_prefix
+                )
+                channels.append(channel)
+            else:
+                raise ValueError(
+                    f"Invalid input data channel: {input_data}. "
+                    + "Must be a Channel or InputData object."
+                )
+        return channels
+
+    def _write_source_code_json(self, tmp_dir: TemporaryDirectory, source_code: SourceCode):
+        """Write the source code configuration to a JSON file."""
+        file_path = os.path.join(tmp_dir.name, SOURCE_CODE_JSON)
+        with open(file_path, "w") as f:
+            dump = source_code.model_dump(exclude_none=True) if source_code else {}
+            f.write(json.dumps(dump))
+
+    def _write_distributed_json(
+        self,
+        tmp_dir: TemporaryDirectory,
+        distributed: Optional[DistributedConfig] = None,
+    ):
+        """Write the distributed runner configuration to a JSON file."""
+        file_path = os.path.join(tmp_dir.name, DISTRIBUTED_JSON)
+        with open(file_path, "w") as f:
+            dump = distributed.model_dump(exclude_none=True) if distributed else {}
+            f.write(json.dumps(dump))
+
+    def _prepare_train_script(
+        self,
+        tmp_dir: TemporaryDirectory,
+        source_code: SourceCode,
+        distributed: Optional[DistributedConfig] = None,
+    ):
+        """Prepare the training script to be executed in the training job container.
+
+        Args:
+            tmp_dir (TemporaryDirectory): The directory the train script is written to.
+            source_code (SourceCode): The source code configuration.
+            distributed (Optional[DistributedConfig]): The distributed runner configuration.
+        """
+
+        base_command = ""
+        if source_code.command:
+            if source_code.entry_script:
+                logger.warning(
+                    "Both 'command' and 'entry_script' are provided in the SourceCode. "
+                    + "Defaulting to 'command'."
+                )
+            base_command = " ".join(source_code.command.split())
+
+        install_requirements = ""
+        if source_code.requirements:
+            install_requirements = "echo 'Installing requirements'\n"
+            install_requirements += f"$SM_PIP_CMD install -r {source_code.requirements}"
+
+        working_dir = ""
+        if source_code.source_dir:
+            working_dir = f"cd {SM_CODE_CONTAINER_PATH}"
+
+        if base_command:
+            execute_driver = EXECUTE_BASE_COMMANDS.format(base_command=base_command)
+        elif distributed:
+            distribution_type = distributed._type
+            if distribution_type == "mpi":
+                execute_driver = EXECUTE_MPI_DRIVER
+            elif distribution_type == "torchrun":
+                execute_driver = EXEUCTE_TORCHRUN_DRIVER
+            else:
+                raise ValueError(f"Unsupported distribution type: {distribution_type}.")
+        elif source_code.entry_script and not source_code.command and not distributed:
+            if not source_code.entry_script.endswith((".py", ".sh")):
+                raise ValueError(
+                    f"Unsupported entry script: {source_code.entry_script}. "
+                    + "Only .py and .sh scripts are supported."
+                )
+            execute_driver = EXECUTE_BASIC_SCRIPT_DRIVER
+        else:
+            raise ValueError(
+                "Unable to determine the execute driver. Provide a 'command' or "
+                + "'entry_script' in 'source_code', or a 'distributed' config."
+            )
+
+        train_script = TRAIN_SCRIPT_TEMPLATE.format(
+            working_dir=working_dir,
+            install_requirements=install_requirements,
+            execute_driver=execute_driver,
+        )
+
+        with open(os.path.join(tmp_dir.name, TRAIN_SCRIPT), "w") as f:
+            f.write(train_script)
+
+    @classmethod
+    def from_recipe(
+        cls,
+        training_recipe: str,
+        compute: Compute,
+        recipe_overrides: Optional[Dict[str, Any]] = None,
+        networking: Optional[Networking] = None,
+        stopping_condition: Optional[StoppingCondition] = None,
+        requirements: Optional[str] = None,
+        training_image: Optional[str] = None,
+        training_image_config: Optional[TrainingImageConfig] = None,
+        output_data_config: Optional[OutputDataConfig] = None,
+        input_data_config: Optional[List[Union[Channel, InputData]]] = None,
+        checkpoint_config: Optional[CheckpointConfig] = None,
+        training_input_mode: Optional[str] = "File",
+        environment: Optional[Dict[str, str]] = None,
+        tags: Optional[List[Tag]] = None,
+        sagemaker_session: Optional[Session] = None,
+        role: Optional[str] = None,
+        base_job_name: Optional[str] = None,
+    ) -> "ModelTrainer":
+        """Create a ModelTrainer from a training recipe.
+
+        Args:
+            training_recipe (str):
+                The training recipe to use for training the model. This must be the name of
+                a sagemaker training recipe or a path to a local training recipe .yaml file.
+            compute (Compute):
+                The compute configuration. This is used to specify the compute resources for
+                the training job. If not specified, will default to 1 instance of ml.m5.xlarge.
+            recipe_overrides (Optional[Dict[str, Any]]):
+                The recipe overrides. This is used to override the default recipe parameters.
+            networking (Optional[Networking]):
+                The networking configuration. This is used to specify the networking settings
+                for the training job.
+            stopping_condition (Optional[StoppingCondition]):
+                The stopping condition. This is used to specify the different stopping
+                conditions for the training job.
+                If not specified, will default to 1 hour max run time.
+            requirements (Optional[str]):
+                The path to a requirements file to install in the training job container.
+            training_image (Optional[str]):
+                The training image URI to use for the training job container. If not specified,
+                the training image will be determined from the recipe.
+            training_image_config (Optional[TrainingImageConfig]):
+                Training image Config. This is the configuration to use an image from a private
+                Docker registry for a training job.
+            output_data_config (Optional[OutputDataConfig]):
+                The output data configuration. This is used to specify the output data location
+                for the training job. If not specified, will default to
+                ``s3://<default_bucket_path>/<base_job_name>/output/``.
+            input_data_config (Optional[List[Union[Channel, InputData]]]):
+                The input data config for the training job.
+                Takes a list of Channel or InputData objects. An InputData data source
+                can be an S3 URI string, local file path string, S3DataSource object,
+                or FileSystemDataSource object.
+            checkpoint_config (Optional[CheckpointConfig]):
+                Contains information about the output location for managed spot training
+                checkpoint data.
+            training_input_mode (Optional[str]):
+                The input mode for the training job. Valid values are "Pipe", "File", "FastFile".
+                Defaults to "File".
+            environment (Optional[Dict[str, str]]):
+                The environment variables for the training job.
+            tags (Optional[List[Tag]]):
+                An array of key-value pairs. You can use tags to categorize your AWS resources
+                in different ways, for example, by purpose, owner, or environment.
+            sagemaker_session (Optional[Session]):
+                The SageMakerCore session.
+                If not specified, a new session will be created.
+            role (Optional[str]):
+                The IAM role ARN for the training job.
+                If not specified, the default SageMaker execution role will be used.
+            base_job_name (Optional[str]):
+                The base name for the training job.
+                If not specified, a default name will be generated using the algorithm name
+                or training image.
+        """
+        if compute.instance_type is None:
+            raise ValueError(
+                "Must set ``instance_type`` in ``compute`` when using training recipes."
+            )
+        device_type = _determine_device_type(compute.instance_type)
+        if device_type == "cpu":
+            raise ValueError(
+                "Training recipes are not supported for CPU instances. "
+                + "Please provide a GPU or Trainium instance type."
+            )
+
+        if training_image_config and training_image is None:
+            raise ValueError("training_image must be provided when using training_image_config.")
+
+        if sagemaker_session is None:
+            sagemaker_session = Session()
+            logger.warning("SageMaker session not provided. 
Using default Session.") + if role is None: + role = get_execution_role(sagemaker_session=sagemaker_session) + logger.warning(f"Role not provided. Using default role:\n{role}") + + # The training recipe is used to prepare the following args: + # - source_code + # - training_image + # - distributed + # - compute + # - hyperparameters + model_trainer_args, recipe_train_dir = _get_args_from_recipe( + training_recipe=training_recipe, + recipe_overrides=recipe_overrides, + requirements=requirements, + compute=compute, + region_name=sagemaker_session.boto_region_name, + ) + if training_image is not None: + model_trainer_args["training_image"] = training_image + + model_trainer = cls( + sagemaker_session=sagemaker_session, + role=role, + base_job_name=base_job_name, + networking=networking, + stopping_condition=stopping_condition, + training_image_config=training_image_config, + output_data_config=output_data_config, + input_data_config=input_data_config, + checkpoint_config=checkpoint_config, + training_input_mode=training_input_mode, + environment=environment, + tags=tags, + **model_trainer_args, + ) + + model_trainer._temp_recipe_train_dir = recipe_train_dir + return model_trainer + + def with_tensorboard_output_config( + self, tensorboard_output_config: TensorBoardOutputConfig + ) -> "ModelTrainer": + """Set the TensorBoard output configuration. + + Args: + tensorboard_output_config (sagemaker.modules.configs.TensorBoardOutputConfig): + The TensorBoard output configuration. + """ + self._tensorboard_output_config = tensorboard_output_config + return self + + def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer": + """Set the retry strategy for the training job. + + Args: + retry_strategy (RetryStrategy): + The retry strategy for the training job. + """ + self._retry_strategy = retry_strategy + return self + + def with_infra_check_config(self, infra_check_config: InfraCheckConfig) -> "ModelTrainer": + """Set the infra check configuration for the training job. + + Args: + infra_check_config (InfraCheckConfig): + The infra check configuration for the training job. + """ + self._infra_check_config = infra_check_config + return self + + def with_session_chaining_config( + self, session_chaining_config: SessionChainingConfig + ) -> "ModelTrainer": + """Set the session chaining configuration for the training job. + + Args: + session_chaining_config (SessionChainingConfig): + The session chaining configuration for the training job. + """ + self._session_chaining_config = session_chaining_config + return self + + def with_remote_debug_config(self, remote_debug_config: RemoteDebugConfig) -> "ModelTrainer": + """Set the remote debug configuration for the training job. + + Args: + remote_debug_config (RemoteDebugConfig): + The remote debug configuration for the training job. 
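+
+        Example:
+            The ``with_*`` helpers return ``self`` and can be chained. A sketch,
+            assuming the config classes are imported from ``sagemaker.modules.configs``
+            and the values shown are illustrative:
+
+            .. code-block:: python
+
+                trainer = trainer.with_retry_strategy(
+                    RetryStrategy(maximum_retry_attempts=3)
+                ).with_remote_debug_config(RemoteDebugConfig(enable_remote_debug=True))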
+ """ + self._remote_debug_config = remote_debug_config + return self diff --git a/src/sagemaker/modules/train/sm_recipes/__init__.py b/src/sagemaker/modules/train/sm_recipes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/modules/train/sm_recipes/training_recipes.json b/src/sagemaker/modules/train/sm_recipes/training_recipes.json new file mode 100644 index 0000000000..400e13f08a --- /dev/null +++ b/src/sagemaker/modules/train/sm_recipes/training_recipes.json @@ -0,0 +1,15 @@ +{ + "adapter_repo": "https://github.com/aws/sagemaker-training-adapter-for-nemo.git", + "launcher_repo": "https://github.com/aws/sagemaker-hyperpod-recipes.git", + "neuron_dist_repo": "https://github.com/aws-neuron/neuronx-distributed-training.git", + "gpu_image" : { + "framework": "pytorch-smp", + "version": "2.4.1", + "additional_args": {} + }, + "neuron_image": { + "framework": "hyperpod-recipes-neuron", + "version": "2.1.2", + "additional_args": {} + } +} \ No newline at end of file diff --git a/src/sagemaker/modules/train/sm_recipes/utils.py b/src/sagemaker/modules/train/sm_recipes/utils.py new file mode 100644 index 0000000000..ff38bcbde8 --- /dev/null +++ b/src/sagemaker/modules/train/sm_recipes/utils.py @@ -0,0 +1,317 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Utility functions for SageMaker training recipes.""" +from __future__ import absolute_import + +import math +import os +import json +import shutil +import tempfile +from urllib.request import urlretrieve +from typing import Dict, Any, Optional, Tuple + +import omegaconf +from omegaconf import OmegaConf, dictconfig + +from sagemaker.image_uris import retrieve + +from sagemaker.modules import logger +from sagemaker.modules.utils import _run_clone_command_silent +from sagemaker.modules.configs import Compute, SourceCode +from sagemaker.modules.distributed import Torchrun, SMP + + +def _try_resolve_recipe(recipe, key=None): + """Try to resolve recipe and return resolved recipe.""" + if key is not None: + recipe = dictconfig.DictConfig({key: recipe}) + try: + OmegaConf.resolve(recipe) + except omegaconf.errors.OmegaConfBaseException: + return None + if key is None: + return recipe + return recipe[key] + + +def _determine_device_type(instance_type: str) -> str: + """Determine device type (gpu, cpu, trainium) based on instance type.""" + instance_family = instance_type.split(".")[1] + if instance_family.startswith(("p", "g")): + return "gpu" + if instance_family.startswith("trn"): + return "trainium" + return "cpu" + + +def _load_recipes_cfg() -> str: + """Load training recipes configuration json.""" + training_recipes_cfg_filename = os.path.join(os.path.dirname(__file__), "training_recipes.json") + with open(training_recipes_cfg_filename) as training_recipes_cfg_file: + training_recipes_cfg = json.load(training_recipes_cfg_file) + return training_recipes_cfg + + +def _load_base_recipe( + training_recipe: str, + recipe_overrides: Optional[Dict[str, Any]] = None, + training_recipes_cfg: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """Load recipe and apply overrides.""" + if recipe_overrides is None: + recipe_overrides = dict() + + temp_local_recipe = tempfile.NamedTemporaryFile(prefix="recipe_original", suffix=".yaml").name + + if training_recipe.endswith(".yaml"): + if os.path.isfile(training_recipe): + shutil.copy(training_recipe, temp_local_recipe) + else: + try: + urlretrieve(training_recipe, temp_local_recipe) + except Exception as e: + raise ValueError( + f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}" + ) + else: + recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_") + + launcher_repo = os.environ.get("TRAINING_LAUNCHER_GIT", None) or training_recipes_cfg.get( + "launcher_repo" + ) + _run_clone_command_silent(launcher_repo, recipe_launcher_dir.name) + + recipe = os.path.join( + recipe_launcher_dir.name, + "recipes_collection", + "recipes", + training_recipe + ".yaml", + ) + if os.path.isfile(recipe): + shutil.copy(recipe, temp_local_recipe) + else: + raise ValueError(f"Recipe {training_recipe} not found.") + + recipe = OmegaConf.load(temp_local_recipe) + os.unlink(temp_local_recipe) + recipe = OmegaConf.merge(recipe, recipe_overrides) + return recipe + + +def _register_custom_resolvers(): + """Register custom resolvers for OmegaConf.""" + if not OmegaConf.has_resolver("multiply"): + OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) + if not OmegaConf.has_resolver("divide_ceil"): + OmegaConf.register_new_resolver( + "divide_ceil", lambda x, y: int(math.ceil(x / y)), replace=True + ) + if not OmegaConf.has_resolver("divide_floor"): + OmegaConf.register_new_resolver( + "divide_floor", lambda x, y: int(math.floor(x / y)), replace=True + ) + if not OmegaConf.has_resolver("add"): + 
OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers)) + + +def _configure_gpu_args( + training_recipes_cfg: Dict[str, Any], + region_name: str, + recipe: OmegaConf, + recipe_train_dir: tempfile.TemporaryDirectory, +) -> Dict[str, Any]: + """Configure arguments specific to GPU.""" + source_code = SourceCode() + args = dict() + + adapter_repo = os.environ.get("TRAINING_ADAPTER_GIT", None) or training_recipes_cfg.get( + "adapter_repo" + ) + _run_clone_command_silent(adapter_repo, recipe_train_dir.name) + + model_type_to_entry = { + "llama_v3": ("llama", "llama_pretrain.py"), + "mistral": ("mistral", "mistral_pretrain.py"), + "mixtral": ("mixtral", "mixtral_pretrain.py"), + } + + if "model" not in recipe: + raise ValueError("Supplied recipe does not contain required field model.") + if "model_type" not in recipe["model"]: + raise ValueError("Supplied recipe does not contain required field model_type.") + model_type = recipe["model"]["model_type"] + if model_type not in model_type_to_entry: + raise ValueError(f"Model type {model_type} not supported") + + source_code.source_dir = os.path.join( + recipe_train_dir.name, "examples", model_type_to_entry[model_type][0] + ) + source_code.entry_script = model_type_to_entry[model_type][1] + + gpu_image_cfg = training_recipes_cfg.get("gpu_image") + if isinstance(gpu_image_cfg, str): + training_image = gpu_image_cfg + else: + training_image = retrieve( + gpu_image_cfg.get("framework"), + region=region_name, + version=gpu_image_cfg.get("version"), + image_scope="training", + **gpu_image_cfg.get("additional_args"), + ) + + # Setting dummy parameters for now + torch_distributed = Torchrun(smp=SMP(random_seed="123456")) + args.update( + { + "source_code": source_code, + "training_image": training_image, + "distributed": torch_distributed, + } + ) + return args + + +def _configure_trainium_args( + training_recipes_cfg: Dict[str, Any], + region_name: str, + recipe_train_dir: tempfile.TemporaryDirectory, +) -> Dict[str, Any]: + """Configure arguments specific to Trainium.""" + source_code = SourceCode() + args = dict() + + _run_clone_command_silent(training_recipes_cfg.get("neuron_dist_repo"), recipe_train_dir.name) + + source_code.source_dir = os.path.join(recipe_train_dir.name, "examples") + source_code.entry_script = "training_orchestrator.py" + neuron_image_cfg = training_recipes_cfg.get("neuron_image") + if isinstance(neuron_image_cfg, str): + training_image = neuron_image_cfg + else: + training_image = retrieve( + neuron_image_cfg.get("framework"), + region=region_name, + version=neuron_image_cfg.get("version"), + image_scope="training", + **neuron_image_cfg.get("additional_args"), + ) + + args.update( + { + "source_code": source_code, + "training_image": training_image, + "distributed": Torchrun(), + } + ) + return args + + +def _get_args_from_recipe( + training_recipe: str, + compute: Compute, + region_name: str, + recipe_overrides: Optional[Dict[str, Any]], + requirements: Optional[str], +) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]: + """Get arguments for ModelTrainer from a training recipe. + + Returns a dictionary of arguments to be used with ModelTrainer like: + ```python + { + "source_code": SourceCode, + "training_image": str, + "distributed": DistributedConfig, + "compute": Compute, + "hyperparameters": Dict[str, Any], + } + ``` + + Args: + training_recipe (str): + Name of the training recipe or path to the recipe file. + compute (Compute): + Compute configuration for training. 
+ region_name (str): + Name of the AWS region. + recipe_overrides (Optional[Dict[str, Any]]): + Overrides for the training recipe. + requirements (Optional[str]): + Path to the requirements file. + """ + if compute.instance_type is None: + raise ValueError("Must set `instance_type` in compute when using training recipes.") + + training_recipes_cfg = _load_recipes_cfg() + recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg) + + if "trainer" not in recipe: + raise ValueError("Supplied recipe does not contain required field trainer.") + + # Set instance_count + if compute.instance_count and "num_nodes" in recipe["trainer"]: + logger.warning( + f"Using Compute to set instance_count:\n{compute}." + "\nIgnoring trainer -> num_nodes in recipe." + ) + if compute.instance_count is None: + if "num_nodes" not in recipe["trainer"]: + raise ValueError( + "Must provide Compute with instance_count or" " set trainer -> num_nodes in recipe." + ) + compute.instance_count = recipe["trainer"]["num_nodes"] + + if requirements and not os.path.isfile(requirements): + raise ValueError(f"Recipe requirements file {requirements} not found.") + + # Get Training Image, SourceCode, and distributed args + device_type = _determine_device_type(compute.instance_type) + recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_") + if device_type == "gpu": + args = _configure_gpu_args(training_recipes_cfg, region_name, recipe, recipe_train_dir) + elif device_type == "trainium": + args = _configure_trainium_args(training_recipes_cfg, region_name, recipe_train_dir) + else: + raise ValueError(f"Devices of type {device_type} are not supported with training recipes.") + + _register_custom_resolvers() + + # Resolve Final Recipe + final_recipe = _try_resolve_recipe(recipe) + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "recipes") + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "training") + if final_recipe is None: + raise RuntimeError("Could not resolve provided recipe.") + + # Save Final Recipe to source_dir + OmegaConf.save( + config=final_recipe, f=os.path.join(args["source_code"].source_dir, "recipe.yaml") + ) + + # If recipe_requirements is provided, copy it to source_dir + if requirements: + shutil.copy(requirements, args["source_code"].source_dir) + args["source_code"].requirements = os.path.basename(requirements) + + # Update args with compute and hyperparameters + args.update( + { + "compute": compute, + "hyperparameters": {"config-path": ".", "config-name": "recipe.yaml"}, + } + ) + + return args, recipe_train_dir diff --git a/src/sagemaker/modules/types.py b/src/sagemaker/modules/types.py new file mode 100644 index 0000000000..18bdcce3bd --- /dev/null +++ b/src/sagemaker/modules/types.py @@ -0,0 +1,19 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Types module.""" +from __future__ import absolute_import + +from typing import Union +from sagemaker.modules.configs import S3DataSource, FileSystemDataSource + +DataSourceType = Union[str, S3DataSource, FileSystemDataSource] diff --git a/src/sagemaker/modules/utils.py b/src/sagemaker/modules/utils.py new file mode 100644 index 0000000000..502f1bbc74 --- /dev/null +++ b/src/sagemaker/modules/utils.py @@ -0,0 +1,194 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Utils module.""" +from __future__ import absolute_import + +import os +import json +import subprocess +import tempfile +from pathlib import Path + +from datetime import datetime +from typing import Literal, Any + +from sagemaker_core.shapes import Unassigned +from sagemaker.modules import logger + + +def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool: + """Check if the path is a valid S3 URI. + + This method checks if the path is a valid S3 URI. If the path_type is specified, + it will also check if the path is a file or a directory. + This method does not check if the S3 bucket or object exists. + + Args: + path (str): S3 URI to validate + path_type (Optional(Literal["File", "Directory", "Any"])): The type of the path to validate. + Defaults to "Any". + + Returns: + bool: True if the path is a valid S3 URI, False otherwise + """ + # Check if the path is a valid S3 URI + if not path.startswith("s3://"): + return False + + if path_type == "File": + # If it's a file, it should not end with a slash + return not path.endswith("/") + if path_type == "Directory": + # If it's a directory, it should end with a slash + return path.endswith("/") + + return path_type == "Any" + + +def _is_valid_path(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool: + """Check if the path is a valid local path. + + Args: + path (str): Local path to validate + path_type (Optional(Literal["File", "Directory", "Any"])): The type of the path to validate. + Defaults to "Any". + + Returns: + bool: True if the path is a valid local path, False otherwise + """ + if not os.path.exists(path): + return False + + if path_type == "File": + return os.path.isfile(path) + if path_type == "Directory": + return os.path.isdir(path) + + return path_type == "Any" + + +def _get_unique_name(base, max_length=63): + """Generate a unique name based on the base name. + + This method generates a unique name based on the base name. + The unique name is generated by appending the current timestamp + to the base name. + + Args: + base (str): The base name to use + max_length (int): The maximum length of the unique name. Defaults to 63. + + Returns: + str: The unique name + """ + current_time = datetime.now().strftime("%Y%m%d%H%M%S") + base = base.replace("_", "-") + unique_name = f"{base}-{current_time}" + unique_name = unique_name[:max_length] # Truncate to max_length + return unique_name + + +def _get_repo_name_from_image(image: str) -> str: + """Get the repository name from the image URI. 
+ + Example: + ``` python + _get_repo_name_from_image("123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo:latest") + # Returns "my-repo" + ``` + + Args: + image (str): The image URI + + Returns: + str: The repository name + """ + return image.split("/")[-1].split(":")[0] + + +def convert_unassigned_to_none(instance) -> Any: + """Convert Unassigned values to None for any instance.""" + for name, value in instance.__dict__.items(): + if isinstance(value, Unassigned): + setattr(instance, name, None) + return instance + + +def safe_serialize(data): + """Serialize the data without wrapping strings in quotes. + + This function handles the following cases: + 1. If `data` is a string, it returns the string as-is without wrapping in quotes. + 2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns + the JSON-encoded string using `json.dumps()`. + 3. If `data` cannot be serialized (e.g., a custom object), it returns the string + representation of the data using `str(data)`. + + Args: + data (Any): The data to serialize. + + Returns: + str: The serialized JSON-compatible string or the string representation of the input. + """ + if isinstance(data, str): + return data + try: + return json.dumps(data) + except TypeError: + return str(data) + + +def _run_clone_command_silent(repo_url, dest_dir): + """Run the 'git clone' command with the repo url and the directory to clone the repo into. + + Args: + repo_url (str): Git repo url to be cloned. + dest_dir: (str): Local path where the repo should be cloned into. + + Raises: + CalledProcessError: If failed to clone git repo. + """ + my_env = os.environ.copy() + if repo_url.startswith("https://"): + try: + my_env["GIT_TERMINAL_PROMPT"] = "0" + subprocess.check_call( + ["git", "clone", repo_url, dest_dir], + env=my_env, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + except subprocess.CalledProcessError as e: + logger.error(f"Failed to clone repository: {repo_url}") + logger.error(f"Error output:\n{e}") + raise + elif repo_url.startswith("git@") or repo_url.startswith("ssh://"): + try: + with tempfile.TemporaryDirectory() as tmp_dir: + custom_ssh_executable = Path(tmp_dir) / "ssh_batch" + with open(custom_ssh_executable, "w") as pipe: + print("#!/bin/sh", file=pipe) + print("ssh -oBatchMode=yes $@", file=pipe) + os.chmod(custom_ssh_executable, 0o511) + my_env["GIT_SSH"] = str(custom_ssh_executable) + subprocess.check_call( + ["git", "clone", repo_url, dest_dir], + env=my_env, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + except subprocess.CalledProcessError as e: + del my_env["GIT_SSH"] + logger.error(f"Failed to clone repository: {repo_url}") + logger.error(f"Error output:\n{e}") + raise diff --git a/src/sagemaker/partner_app/__init__.py b/src/sagemaker/partner_app/__init__.py new file mode 100644 index 0000000000..b9ef202bc7 --- /dev/null +++ b/src/sagemaker/partner_app/__init__.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""__init__ file for sagemaker.partner_app.auth_provider""" +from __future__ import absolute_import + +from sagemaker.partner_app.auth_provider import PartnerAppAuthProvider # noqa: F401 diff --git a/src/sagemaker/partner_app/auth_provider.py b/src/sagemaker/partner_app/auth_provider.py new file mode 100644 index 0000000000..2e0d7da94c --- /dev/null +++ b/src/sagemaker/partner_app/auth_provider.py @@ -0,0 +1,129 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +"""The SageMaker partner application SDK auth module""" +from __future__ import absolute_import + +import os +import re +from typing import Dict, Tuple + +import boto3 +from botocore.auth import SigV4Auth +from botocore.credentials import Credentials +from requests.auth import AuthBase +from requests.models import PreparedRequest +from sagemaker.partner_app.auth_utils import PartnerAppAuthUtils + +SERVICE_NAME = "sagemaker" +AWS_PARTNER_APP_ARN_REGEX = r"arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]*:[0-9]{12}:partner-app\/.*" + + +class RequestsAuth(AuthBase): + """Requests authentication class for SigV4 header generation. + + This class is used to generate the SigV4 header and add it to the request headers. + """ + + def __init__(self, sigv4: SigV4Auth, app_arn: str): + """Initialize the RequestsAuth class. + + Args: + sigv4 (SigV4Auth): SigV4Auth object + app_arn (str): Application ARN + """ + self.sigv4 = sigv4 + self.app_arn = app_arn + + def __call__(self, request: PreparedRequest) -> PreparedRequest: + """Callback function to generate the SigV4 header and add it to the request headers. + + Args: + request (PreparedRequest): PreparedRequest object + + Returns: + PreparedRequest: PreparedRequest object with the SigV4 header added + """ + url, signed_headers = PartnerAppAuthUtils.get_signed_request( + sigv4=self.sigv4, + app_arn=self.app_arn, + url=request.url, + method=request.method, + headers=request.headers, + body=request.body, + ) + request.url = url + request.headers.update(signed_headers) + + return request + + +class PartnerAppAuthProvider: + """The SageMaker partner application SDK auth provider class""" + + def __init__(self, credentials: Credentials = None): + """Initialize the PartnerAppAuthProvider class. + + Args: + credentials (Credentials, optional): AWS credentials. Defaults to None. + Raises: + ValueError: If the AWS_PARTNER_APP_ARN environment variable is not set or is invalid. 
+ """ + self.app_arn = os.getenv("AWS_PARTNER_APP_ARN") + if self.app_arn is None: + raise ValueError("Must specify the AWS_PARTNER_APP_ARN environment variable") + + app_arn_regex_match = re.search(AWS_PARTNER_APP_ARN_REGEX, self.app_arn) + if app_arn_regex_match is None: + raise ValueError("Must specify a valid AWS_PARTNER_APP_ARN environment variable") + + split_arn = self.app_arn.split(":") + self.region = split_arn[3] + + self.credentials = ( + credentials if credentials is not None else boto3.Session().get_credentials() + ) + self.sigv4 = SigV4Auth(self.credentials, SERVICE_NAME, self.region) + + def get_signed_request( + self, url: str, method: str, headers: dict, body: object + ) -> Tuple[str, Dict[str, str]]: + """Generate the SigV4 header and add it to the request headers. + + Args: + url (str): Request URL + method (str): HTTP method + headers (dict): Request headers + body (object): Request body + + Returns: + tuple: (url, headers) + """ + return PartnerAppAuthUtils.get_signed_request( + sigv4=self.sigv4, + app_arn=self.app_arn, + url=url, + method=method, + headers=headers, + body=body, + ) + + def get_auth(self) -> RequestsAuth: + """Returns the callback class (RequestsAuth) used for generating the SigV4 header. + + Returns: + RequestsAuth: Callback Object which will calculate the header just before + request submission. + """ + + return RequestsAuth(self.sigv4, os.environ["AWS_PARTNER_APP_ARN"]) diff --git a/src/sagemaker/partner_app/auth_utils.py b/src/sagemaker/partner_app/auth_utils.py new file mode 100644 index 0000000000..eb1dcacaa9 --- /dev/null +++ b/src/sagemaker/partner_app/auth_utils.py @@ -0,0 +1,122 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +"""Partner App Auth Utils Module""" + +from __future__ import absolute_import + +from hashlib import sha256 +import functools +from typing import Tuple, Dict + +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest + +HEADER_CONNECTION = "Connection" +HEADER_X_AMZ_TARGET = "X-Amz-Target" +HEADER_AUTHORIZATION = "Authorization" +HEADER_PARTNER_APP_SERVER_ARN = "X-SageMaker-Partner-App-Server-Arn" +HEADER_PARTNER_APP_AUTHORIZATION = "X-Amz-Partner-App-Authorization" +HEADER_X_AMZ_CONTENT_SHA_256 = "X-Amz-Content-SHA256" +CALL_PARTNER_APP_API_ACTION = "SageMaker.CallPartnerAppApi" + +PAYLOAD_BUFFER = 1024 * 1024 +EMPTY_SHA256_HASH = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" +UNSIGNED_PAYLOAD = "UNSIGNED-PAYLOAD" + + +class PartnerAppAuthUtils: + """Partner App Auth Utils Class""" + + @staticmethod + def get_signed_request( + sigv4: SigV4Auth, app_arn: str, url: str, method: str, headers: dict, body: object + ) -> Tuple[str, Dict[str, str]]: + """Generate the SigV4 header and add it to the request headers. 
+ + Args: + sigv4 (SigV4Auth): SigV4Auth object + app_arn (str): Application ARN + url (str): Request URL + method (str): HTTP method + headers (dict): Request headers + body (object): Request body + Returns: + tuple: (url, headers) + """ + # Move API key to X-Amz-Partner-App-Authorization + if HEADER_AUTHORIZATION in headers: + headers[HEADER_PARTNER_APP_AUTHORIZATION] = headers[HEADER_AUTHORIZATION] + + # App Arn + headers[HEADER_PARTNER_APP_SERVER_ARN] = app_arn + + # IAM Action + headers[HEADER_X_AMZ_TARGET] = CALL_PARTNER_APP_API_ACTION + + # Body + headers[HEADER_X_AMZ_CONTENT_SHA_256] = PartnerAppAuthUtils.get_body_header(body) + + # Connection header is excluded from server-side signature calculation + connection_header = headers[HEADER_CONNECTION] if HEADER_CONNECTION in headers else None + + if HEADER_CONNECTION in headers: + del headers[HEADER_CONNECTION] + + # Spaces are encoded as %20 + url = url.replace("+", "%20") + + # Calculate SigV4 header + aws_request = AWSRequest( + method=method, + url=url, + headers=headers, + data=body, + ) + sigv4.add_auth(aws_request) + + # Reassemble headers + final_headers = dict(aws_request.headers.items()) + if connection_header is not None: + final_headers[HEADER_CONNECTION] = connection_header + + return (url, final_headers) + + @staticmethod + def get_body_header(body: object): + """Calculate the body header for the SigV4 header. + + Args: + body (object): Request body + """ + if body and hasattr(body, "seek"): + position = body.tell() + read_chunksize = functools.partial(body.read, PAYLOAD_BUFFER) + checksum = sha256() + for chunk in iter(read_chunksize, b""): + checksum.update(chunk) + hex_checksum = checksum.hexdigest() + body.seek(position) + return hex_checksum + + if body and not isinstance(body, bytes): + # Body is of a class we don't recognize, so don't sign the payload + return UNSIGNED_PAYLOAD + + if body: + # The request serialization has ensured that + # request.body is a bytes() type. 
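+            # Hashing the raw bytes yields the X-Amz-Content-SHA256 value that
+            # is signed alongside the other headers.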
+ return sha256(body).hexdigest() + + # Body is None + return EMPTY_SHA256_HASH diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 412926279c..46c57581d1 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -13,9 +13,17 @@ """Placeholder docstring""" from __future__ import absolute_import +import json import logging +import math +import os +import shutil +import tempfile from typing import Union, Optional, Dict +from urllib.request import urlretrieve +import omegaconf +from omegaconf import OmegaConf, dictconfig from packaging.version import Version from sagemaker.estimator import Framework, EstimatorBase @@ -27,15 +35,90 @@ validate_distribution, profiler_config_deprecation_warning, ) +from sagemaker.git_utils import _run_clone_command +from sagemaker.image_uris import retrieve from sagemaker.pytorch import defaults from sagemaker.pytorch.model import PyTorchModel from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig +from sagemaker.session import Session from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") +def _setup_omegaconf_resolvers(): + """Set up omegaconf resolvers for training recipes.""" + if not OmegaConf.has_resolver("multiply"): + OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) + if not OmegaConf.has_resolver("divide_ceil"): + OmegaConf.register_new_resolver( + "divide_ceil", lambda x, y: int(math.ceil(x / y)), replace=True + ) + if not OmegaConf.has_resolver("divide_floor"): + OmegaConf.register_new_resolver( + "divide_floor", lambda x, y: int(math.floor(x / y)), replace=True + ) + if not OmegaConf.has_resolver("add"): + OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers)) + + +def _try_resolve_recipe(recipe, key=None): + """Try to resolve recipe and return resolved recipe.""" + if key is not None: + recipe = dictconfig.DictConfig({key: recipe}) + try: + OmegaConf.resolve(recipe) + except omegaconf.errors.OmegaConfBaseException: + return None + if key is None: + return recipe + return recipe[key] + + +def _get_training_recipe_image_uri(image_cfg, region_name): + """Fetch image uri given image spec and region name to use for training.""" + if isinstance(image_cfg, str): + return image_cfg + return retrieve( + image_cfg.get("framework"), + region=region_name, + version=image_cfg.get("version"), + image_scope="training", + **image_cfg.get("additional_args"), + ) + + +def _get_training_recipe_gpu_script(code_dir, recipe, source_dir): + """Return path to training script (entry point) when running a gpu recipe.""" + model_type_to_script = { + "llama_v3": ("llama", "llama_pretrain.py"), + "mistral": ("mistral", "mistral_pretrain.py"), + "mixtral": ("mixtral", "mixtral_pretrain.py"), + } + + if "model" not in recipe: + raise ValueError("Supplied recipe does not contain required field model.") + if "model_type" not in recipe["model"]: + raise ValueError("Supplied recipe does not contain required field model_type.") + model_type = recipe["model"]["model_type"] + if model_type not in model_type_to_script: + raise ValueError(f"Model type {model_type} not supported") + + script_dir = os.path.join(code_dir, "examples", model_type_to_script[model_type][0]) + script = model_type_to_script[model_type][1] + shutil.copyfile(os.path.join(script_dir, script), os.path.join(source_dir, script)) + return script + + +def 
_get_training_recipe_trainium_script(code_dir, source_dir): + """Return path to training script (entry point) when running a trainium recipe.""" + script_dir = os.path.join(code_dir, "examples") + script = "training_orchestrator.py" + shutil.copytree(script_dir, source_dir, dirs_exist_ok=True) + return script + + class PyTorch(Framework): """Handle end-to-end training and deployment of custom PyTorch code.""" @@ -44,9 +127,12 @@ class PyTorch(Framework): LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled" INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type" + # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve + # to retrieve the image uri below before GA. + def __init__( self, - entry_point: Union[str, PipelineVariable], + entry_point: Optional[Union[str, PipelineVariable]] = None, framework_version: Optional[str] = None, py_version: Optional[str] = None, source_dir: Optional[Union[str, PipelineVariable]] = None, @@ -54,6 +140,8 @@ def __init__( image_uri: Optional[Union[str, PipelineVariable]] = None, distribution: Optional[Dict] = None, compiler_config: Optional[TrainingCompilerConfig] = None, + training_recipe: Optional[str] = None, + recipe_overrides: Optional[Dict] = None, **kwargs, ): """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment. @@ -89,7 +177,7 @@ def __init__( a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. + when training on Amazon SageMaker. Must be a local path when using training_recipe. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on @@ -246,6 +334,14 @@ def __init__( compiler_config (:class:`~sagemaker.pytorch.TrainingCompilerConfig`): Configures SageMaker Training Compiler to accelerate training. + training_recipe (str): Training recipe to use. This is a local file path, a url, + or a recipe provided by Amazon SageMaker HyperPod recipes, + such as training/llama/hf_llama3_70b_seq8k_gpu_p5x64_pretrain. + This is required when using recipes. + recipe_overrides (Dict): Dictionary specifying key values to override in the + training_recipe. This is optional when using + Amazon SageMaker HyperPod recipes. + **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor. @@ -255,6 +351,26 @@ def __init__( :class:`~sagemaker.estimator.Framework` and :class:`~sagemaker.estimator.EstimatorBase`. 
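+
+        Example:
+            A sketch of launching a recipe-based job; the recipe name, overrides,
+            role, and instance settings below are illustrative:
+
+            .. code-block:: python
+
+                from sagemaker.pytorch import PyTorch
+
+                estimator = PyTorch(
+                    training_recipe="training/llama/hf_llama3_70b_seq8k_gpu_p5x64_pretrain",
+                    recipe_overrides={"trainer": {"num_nodes": 2}},
+                    role="<execution-role-arn>",
+                    instance_type="ml.p5.48xlarge",
+                    instance_count=2,
+                )
+                estimator.fit()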
""" + if training_recipe is not None: + if entry_point is not None: + logger.warning("Argument entry_point will be ignored with training_recipe.") + if hyperparameters is not None: + logger.warning("Argument hyperparameters will be ignored with training recipe.") + if distribution is not None: + logger.warning("Argument distribution will be ignored with training_recipe.") + args = self._setup_for_training_recipe( + training_recipe, recipe_overrides, source_dir, kwargs + ) + entry_point = args["entry_point"] + source_dir = args["source_dir"] + hyperparameters = args["hyperparameters"] + if image_uri is None: + image_uri = args["default_image_uri"] + distribution = args["distribution"] + elif entry_point is None: + raise ValueError( + "Argument entry_point must be set when training_recipe is not provided" + ) validate_version_or_image_args(framework_version, py_version, image_uri) if py_version == "py2": logger.warning( @@ -480,3 +596,172 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na ) return init_params + + @classmethod + def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, source_dir, kwargs): + """Performs training recipe specific setup and returns recipe specific args. + + Updates kwargs and returns a dictionary of args to use for estimator + initialization and setup when using a training recipe. Updates the paths in + the recipe for Sagemaker Jobs environment. + + Args: + training_recipe (str): A recipe which is a local file path, a url or a + sagemaker training recipe. + recipe_overrides (Dict): Dictionary specifying key values to override in the + source_dir (str): Path (absolute, or relative) to a directory where to copy + the scripts for training recipe. requirements.txt can also + go here. + kwargs (dict): Dictionary of args used for estimator initializaiton. + Returns: + dict containing arg values for estimator initialization and setup. + + """ + if kwargs.get("sagemaker_session") is not None: + region_name = kwargs.get("sagemaker_session").boto_region_name + else: + region_name = Session().boto_region_name + + training_recipes_cfg_filename = os.path.join( + os.path.dirname(__file__), "training_recipes.json" + ) + with open(training_recipes_cfg_filename) as training_recipes_cfg_file: + training_recipes_cfg = json.load(training_recipes_cfg_file) + + if recipe_overrides is None: + recipe_overrides = dict() + recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_") + recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_") + args = dict() + if source_dir is None: + args["source_dir"] = "." + else: + if not os.path.exists(source_dir): + raise ValueError( + "When using training_recipe, source_dir must be a local directory." 
+                )
+            args["source_dir"] = source_dir
+
+        recipe_name = os.path.splitext(os.path.basename(training_recipe))[0]
+        temp_local_recipe = tempfile.NamedTemporaryFile(prefix=recipe_name, suffix=".yaml").name
+        if training_recipe.endswith(".yaml"):
+            if os.path.isfile(training_recipe):
+                shutil.copy(training_recipe, temp_local_recipe)
+            else:
+                try:
+                    urlretrieve(training_recipe, temp_local_recipe)
+                except Exception as e:
+                    raise ValueError(
+                        f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}"
+                    )
+        else:
+            launcher_repo = os.environ.get(
+                "TRAINING_LAUNCHER_GIT", None
+            ) or training_recipes_cfg.get("launcher_repo")
+            _run_clone_command(launcher_repo, recipe_launcher_dir.name)
+            recipe = os.path.join(
+                recipe_launcher_dir.name,
+                "recipes_collection",
+                "recipes",
+                training_recipe + ".yaml",
+            )
+            if os.path.isfile(recipe):
+                shutil.copy(recipe, temp_local_recipe)
+            else:
+                raise ValueError(f"Recipe {training_recipe} not found.")
+
+        recipe = OmegaConf.load(temp_local_recipe)
+        os.unlink(temp_local_recipe)
+        recipe = OmegaConf.merge(recipe, recipe_overrides)
+
+        if "instance_type" not in kwargs:
+            raise ValueError("Must pass instance type to estimator when using training recipes.")
+        instance_type = kwargs["instance_type"].split(".")[1]
+        if instance_type.startswith(("p", "g")):
+            device_type = "gpu"
+        elif instance_type.startswith("trn"):
+            device_type = "trainium"
+        else:
+            device_type = "cpu"
+
+        if "trainer" not in recipe:
+            raise ValueError("Supplied recipe does not contain required field trainer.")
+        if "instance_count" in kwargs and "num_nodes" in recipe["trainer"]:
+            logger.warning(
+                "Using instance_count argument to estimator to set number "
+                "of nodes. Ignoring trainer -> num_nodes in recipe."
+            )
+        if "instance_count" not in kwargs:
+            if "num_nodes" not in recipe["trainer"]:
+                raise ValueError(
+                    "Must set either instance_count argument for estimator or "
+                    "set trainer -> num_nodes in recipe."
+                )
+            kwargs["instance_count"] = recipe["trainer"]["num_nodes"]
+
+        # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
+        # to retrieve the image uri below before we go GA.
+        if device_type == "gpu":
+            adapter_repo = os.environ.get("TRAINING_ADAPTER_GIT", None) or training_recipes_cfg.get(
+                "adapter_repo"
+            )
+            _run_clone_command(adapter_repo, recipe_train_dir.name)
+            script = _get_training_recipe_gpu_script(
+                recipe_train_dir.name, recipe, args["source_dir"]
+            )
+            args["default_image_uri"] = _get_training_recipe_image_uri(
+                training_recipes_cfg.get("gpu_image"), region_name
+            )
+            smp_options = {
+                "enabled": True,
+                "parameters": {
+                    "placement_strategy": "cluster",
+                },
+            }
+            args["distribution"] = {
+                "smdistributed": {"modelparallel": smp_options},
+                "torch_distributed": {"enabled": True},
+            }
+        elif device_type == "trainium":
+            _run_clone_command(training_recipes_cfg.get("neuron_dist_repo"), recipe_train_dir.name)
+            script = _get_training_recipe_trainium_script(recipe_train_dir.name, args["source_dir"])
+            args["default_image_uri"] = _get_training_recipe_image_uri(
+                training_recipes_cfg.get("neuron_image"), region_name
+            )
+            args["distribution"] = {
+                "torch_distributed": {"enabled": True},
+            }
+        else:
+            raise ValueError(
+                f"Devices of type {device_type} are not supported with training recipes."
+            )
+        args["entry_point"] = os.path.basename(script)
+
+        recipe_train_dir.cleanup()
+        recipe_launcher_dir.cleanup()
+
+        if "container" in recipe and recipe["container"]:
+            logger.warning(
+                "Ignoring container from training_recipe. 
Use image_uri arg for estimator." + ) + + _setup_omegaconf_resolvers() + final_recipe = _try_resolve_recipe(recipe) + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "recipes") + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "training") + if final_recipe is None: + raise RuntimeError("Could not resolve provided recipe.") + cls.training_recipe_file = tempfile.NamedTemporaryFile( + dir=args["source_dir"], + prefix=recipe_name + "_", + suffix=".yaml", + ) + OmegaConf.save(config=final_recipe, f=cls.training_recipe_file.name) + args["hyperparameters"] = { + "config-path": ".", + "config-name": os.path.basename(cls.training_recipe_file.name), + } + + return args diff --git a/src/sagemaker/pytorch/training_recipes.json b/src/sagemaker/pytorch/training_recipes.json new file mode 100644 index 0000000000..df60f95df9 --- /dev/null +++ b/src/sagemaker/pytorch/training_recipes.json @@ -0,0 +1,15 @@ +{ + "adapter_repo": "https://github.com/aws/sagemaker-training-adapter-for-nemo.git", + "launcher_repo": "https://github.com/aws/sagemaker-hyperpod-recipes.git", + "neuron_dist_repo": "https://github.com/aws-neuron/neuronx-distributed-training.git", + "gpu_image" : { + "framework": "pytorch-smp", + "version": "2.4.1", + "additional_args": {} + }, + "neuron_image" : { + "framework": "hyperpod-recipes-neuron", + "version": "2.1.2", + "additional_args": {} + } +} diff --git a/src/sagemaker/serve/app.py b/src/sagemaker/serve/app.py deleted file mode 100644 index fd9dd6a93a..0000000000 --- a/src/sagemaker/serve/app.py +++ /dev/null @@ -1,100 +0,0 @@ -"""FastAPI requests""" - -from __future__ import absolute_import - -import asyncio -import logging -import threading -from typing import Optional - - -logger = logging.getLogger(__name__) - - -try: - import uvicorn -except ImportError: - logger.error("Unable to import uvicorn, check if uvicorn is installed.") - - -try: - from transformers import pipeline -except ImportError: - logger.error("Unable to import transformers, check if transformers is installed.") - - -try: - from fastapi import FastAPI, Request, APIRouter -except ImportError: - logger.error("Unable to import fastapi, check if fastapi is installed.") - - -class InProcessServer: - """Placeholder docstring""" - - def __init__(self, model_id: Optional[str] = None, task: Optional[str] = None): - self._thread = None - self._loop = None - self._stop_event = asyncio.Event() - self._router = APIRouter() - self._model_id = model_id - self._task = task - self.server = None - self.port = None - self.host = None - # TODO: Pick up device automatically. 
- self._generator = pipeline(task, model=model_id, device="cpu") - - # pylint: disable=unused-variable - @self._router.post("/generate") - async def generate_text(prompt: Request): - """Placeholder docstring""" - str_prompt = await prompt.json() - str_prompt = str_prompt["inputs"] if "inputs" in str_prompt else str_prompt - - generated_text = self._generator( - str_prompt, max_length=30, num_return_sequences=1, truncation=True - ) - return generated_text - - self._create_server() - - def _create_server(self): - """Placeholder docstring""" - app = FastAPI() - app.include_router(self._router) - - config = uvicorn.Config( - app, - host="127.0.0.1", - port=9007, - log_level="info", - loop="asyncio", - reload=True, - use_colors=True, - ) - - self.server = uvicorn.Server(config) - self.host = config.host - self.port = config.port - - def start_server(self): - """Starts the uvicorn server.""" - if not (self._thread and self._thread.is_alive()): - logger.info("Waiting for a connection...") - self._thread = threading.Thread(target=self._start_run_async_in_thread, daemon=True) - self._thread.start() - - def stop_server(self): - """Destroys the uvicorn server.""" - # TODO: Implement me. - - def _start_run_async_in_thread(self): - """Placeholder docstring""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(self._serve()) - - async def _serve(self): - """Placeholder docstring""" - await self.server.serve() diff --git a/src/sagemaker/serve/builder/djl_builder.py b/src/sagemaker/serve/builder/djl_builder.py index 608e1c604f..9b1ebf1257 100644 --- a/src/sagemaker/serve/builder/djl_builder.py +++ b/src/sagemaker/serve/builder/djl_builder.py @@ -47,7 +47,7 @@ from sagemaker.serve.model_server.djl_serving.prepare import ( _create_dir_structure, ) -from sagemaker.serve.utils.predictors import DjlLocalModePredictor +from sagemaker.serve.utils.predictors import InProcessModePredictor, DjlLocalModePredictor from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.utils.telemetry_logger import _capture_telemetry @@ -55,6 +55,7 @@ from sagemaker.base_predictor import PredictorBase logger = logging.getLogger(__name__) +LOCAL_MODES = [Mode.LOCAL_CONTAINER, Mode.IN_PROCESS] # Match JumpStart DJL entrypoint format _CODE_FOLDER = "code" @@ -96,11 +97,11 @@ def __init__(self): @abstractmethod def _prepare_for_mode(self): - """Placeholder docstring""" + """Abstract method""" @abstractmethod def _get_client_translators(self): - """Placeholder docstring""" + """Abstract method""" def _is_djl(self): """Placeholder docstring""" @@ -146,7 +147,7 @@ def _create_djl_model(self) -> Type[Model]: @_capture_telemetry("djl.deploy") def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]: - """Placeholder docstring""" + """Returns predictor depending on local mode or endpoint mode""" timeout = kwargs.get("model_data_download_timeout") if timeout: self.env_vars.update({"MODEL_LOADING_TIMEOUT": str(timeout)}) @@ -189,6 +190,18 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa serializer = self.schema_builder.input_serializer deserializer = self.schema_builder._output_deserializer + + if self.mode == Mode.IN_PROCESS: + + predictor = InProcessModePredictor( + self.modes[str(Mode.IN_PROCESS)], serializer, deserializer + ) + + self.modes[str(Mode.IN_PROCESS)].create_server( + predictor, + ) + return predictor + if self.mode == Mode.LOCAL_CONTAINER: timeout = 
kwargs.get("model_data_download_timeout") @@ -250,8 +263,9 @@ def _build_for_hf_djl(self): _create_dir_structure(self.model_path) if not hasattr(self, "pysdk_model"): self.env_vars.update({"HF_MODEL_ID": self.model}) + self.hf_model_config = _get_model_config_properties_from_hf( - self.model, self.env_vars.get("HF_TOKEN") + self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HF_TOKEN") ) default_djl_configurations, _default_max_new_tokens = _get_default_djl_configurations( self.model, self.hf_model_config, self.schema_builder @@ -260,9 +274,10 @@ def _build_for_hf_djl(self): self.schema_builder.sample_input["parameters"][ "max_new_tokens" ] = _default_max_new_tokens + self.pysdk_model = self._create_djl_model() - if self.mode == Mode.LOCAL_CONTAINER: + if self.mode in LOCAL_MODES: self._prepare_for_mode() return self.pysdk_model @@ -451,7 +466,6 @@ def _build_for_djl(self): """Placeholder docstring""" self._validate_djl_serving_sample_data() self.secret_key = None - self.pysdk_model = self._build_for_hf_djl() self.pysdk_model.tune = self._tune_for_hf_djl if self.role_arn: diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 802711e427..e5e850b885 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import importlib.util +import json import uuid from typing import Any, Type, List, Dict, Optional, Union from dataclasses import dataclass, field @@ -23,9 +24,17 @@ from pathlib import Path -from sagemaker.enums import Tag -from sagemaker.s3 import S3Downloader +from sagemaker_core.main.resources import TrainingJob +from sagemaker.transformer import Transformer +from sagemaker.async_inference import AsyncInferenceConfig +from sagemaker.batch_inference.batch_transform_inference_config import BatchTransformInferenceConfig +from sagemaker.compute_resource_requirements import ResourceRequirements +from sagemaker.enums import Tag, EndpointType +from sagemaker.estimator import Estimator +from sagemaker.jumpstart.accessors import JumpStartS3PayloadAccessor +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +from sagemaker.s3 import S3Downloader from sagemaker import Session from sagemaker.model import Model from sagemaker.base_predictor import PredictorBase @@ -78,7 +87,10 @@ _extract_speculative_draft_model_provider, _jumpstart_speculative_decoding, ) -from sagemaker.serve.utils.predictors import _get_local_mode_predictor +from sagemaker.serve.utils.predictors import ( + _get_local_mode_predictor, + _get_in_process_mode_predictor, +) from sagemaker.serve.utils.hardware_detector import ( _get_gpu_info, _get_gpu_info_fallback, @@ -99,15 +111,16 @@ from sagemaker.serve.validations.check_image_and_hardware_type import ( validate_image_uri_and_hardware, ) -from sagemaker.utils import Tags +from sagemaker.serverless import ServerlessInferenceConfig +from sagemaker.utils import Tags, unique_name_from_base from sagemaker.workflow.entities import PipelineVariable from sagemaker.huggingface.llm_utils import ( get_huggingface_model_metadata, download_huggingface_model_metadata, ) from sagemaker.serve.validations.optimization import _validate_optimization_configuration - -logger = logging.getLogger(__name__) +from sagemaker.modules.train import ModelTrainer +from sagemaker.modules import logger # Any new server type should be added here supported_model_servers = { @@ -137,6 +150,7 @@ class ModelBuilder(Triton, DJL, JumpStart, 
TGI, Transformers, TensorflowServing, * ``Mode.SAGEMAKER_ENDPOINT``: Launch on a SageMaker endpoint * ``Mode.LOCAL_CONTAINER``: Launch locally with a container + * ``Mode.IN_PROCESS``: Launch locally to a FastAPI server instead of using a container. shared_libs (List[str]): Any shared libraries you want to bring into the model packaging. dependencies (Optional[Dict[str, Any]): The dependencies of the model @@ -173,8 +187,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, The schema builder can be omitted for HuggingFace models with task types TextGeneration, TextClassification, and QuestionAnswering. Omitting SchemaBuilder is in beta for FillMask, and AutomaticSpeechRecognition use-cases. - model (Optional[Union[object, str]): Model object (with ``predict`` method to perform - inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or ``inference_spec`` + model (Optional[Union[object, str, ModelTrainer, TrainingJob, Estimator]]): + Define object from which training artifacts can be extracted. + Either ``model`` or ``inference_spec`` is required for the model builder to build the artifact. inference_spec (InferenceSpec): The inference spec file with your customized ``invoke`` and ``load`` functions. @@ -265,14 +280,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, schema_builder: Optional[SchemaBuilder] = field( default=None, metadata={"help": "Defines the i/o schema of the model"} ) - model: Optional[Union[object, str]] = field( + model: Optional[Union[object, str, ModelTrainer, TrainingJob, Estimator]] = field( default=None, - metadata={ - "help": ( - 'Model object with "predict" method to perform inference ' - "or HuggingFace/JumpStart Model ID" - ) - }, + metadata={"help": "Define object from which training artifacts can be extracted"}, ) inference_spec: InferenceSpec = field( default=None, @@ -429,11 +439,11 @@ def _prepare_for_mode( # init the InProcessMode object self.modes[str(Mode.IN_PROCESS)] = InProcessMode( inference_spec=self.inference_spec, + model=self.model, schema_builder=self.schema_builder, session=self.sagemaker_session, model_path=self.model_path, env_vars=self.env_vars, - model_server=self.model_server, ) self.modes[str(Mode.IN_PROCESS)].prepare() return None @@ -471,7 +481,9 @@ def _get_client_translators(self): return serializer, deserializer - def _get_predictor(self, endpoint_name: str, sagemaker_session: Session) -> Predictor: + def _get_predictor( + self, endpoint_name: str, sagemaker_session: Session, component_name: Optional[str] = None + ) -> Predictor: """Placeholder docstring""" serializer, deserializer = self._get_client_translators() @@ -480,6 +492,7 @@ def _get_predictor(self, endpoint_name: str, sagemaker_session: Session) -> Pred sagemaker_session=sagemaker_session, serializer=serializer, deserializer=deserializer, + component_name=component_name, ) def _create_model(self): @@ -563,6 +576,18 @@ def _model_builder_deploy_wrapper( if mode and mode != self.mode: self._overwrite_mode_in_deploy(overwrite_mode=mode) + if self.mode == Mode.IN_PROCESS: + serializer, deserializer = self._get_client_translators() + + predictor = _get_in_process_mode_predictor( + self.modes[str(Mode.IN_PROCESS)], serializer, deserializer + ) + + self.modes[str(Mode.IN_PROCESS)].create_server( + predictor, + ) + return predictor + if self.mode == Mode.LOCAL_CONTAINER: serializer, deserializer = self._get_client_translators() predictor = _get_local_mode_predictor( @@ -576,14 +601,11 @@ def 
_model_builder_deploy_wrapper( self.image_uri, container_timeout_in_second, self.secret_key, predictor ) return predictor + if self.mode == Mode.SAGEMAKER_ENDPOINT: # Validate parameters - if not instance_type: - raise ValueError("Missing required parameter `instance_type`") - - if not initial_instance_count: - raise ValueError("Missing required parameter `initial_instance_count`") - + # Instance type and instance count parameter validation is done based on deployment type + # and will be done inside Model.deploy() if is_1p_image_uri(image_uri=self.image_uri): validate_image_uri_and_hardware( image_uri=self.image_uri, @@ -592,7 +614,9 @@ def _model_builder_deploy_wrapper( ) if "endpoint_logging" not in kwargs: - kwargs["endpoint_logging"] = False + kwargs["endpoint_logging"] = True + kwargs.pop("mode", None) + self.pysdk_model.role = kwargs.pop("role", self.pysdk_model.role) predictor = self._original_deploy( *args, instance_type=instance_type, @@ -633,20 +657,21 @@ def _build_for_torchserve(self) -> Type[Model]: """Build the model for torchserve""" self._save_model_inference_spec() - self._auto_detect_container() + if self.mode != Mode.IN_PROCESS: + self._auto_detect_container() - self.secret_key = prepare_for_torchserve( - model_path=self.model_path, - shared_libs=self.shared_libs, - dependencies=self.dependencies, - session=self.sagemaker_session, - image_uri=self.image_uri, - inference_spec=self.inference_spec, - ) + self.secret_key = prepare_for_torchserve( + model_path=self.model_path, + shared_libs=self.shared_libs, + dependencies=self.dependencies, + session=self.sagemaker_session, + image_uri=self.image_uri, + inference_spec=self.inference_spec, + ) self._prepare_for_mode() - - return self._create_model() + self.model = self._create_model() + return self.model def _user_agent_decorator(self, func): """Placeholder docstring""" @@ -814,11 +839,27 @@ def _initialize_for_mlflow(self, artifact_path: str) -> None: self.env_vars.update({"MLFLOW_MODEL_FLAVOR": f"{deployment_flavor}"}) self.dependencies.update({"requirements": mlflow_model_dependency_path}) + @_capture_telemetry("ModelBuilder.build_training_job") + def _collect_training_job_model_telemetry(self): + """Dummy method to collect telemetry for training job handshake""" + return + + @_capture_telemetry("ModelBuilder.build_model_trainer") + def _collect_model_trainer_model_telemetry(self): + """Dummy method to collect telemetry for model trainer handshake""" + return + + @_capture_telemetry("ModelBuilder.build_estimator") + def _collect_estimator_model_telemetry(self): + """Dummy method to collect telemetry for estimator handshake""" + return + # Model Builder is a class to build the model for deployment. # It supports three modes of deployment # 1/ SageMaker Endpoint # 2/ Local launch with container # 3/ In process mode with Transformers server in beta release + @_capture_telemetry("ModelBuilder.build") def build( # pylint: disable=R0911 self, mode: Type[Mode] = None, @@ -837,6 +878,7 @@ def build( # pylint: disable=R0911 Returns: Type[Model]: A deployable ``Model`` object. 
""" + self.modes = dict() if mode: @@ -844,6 +886,21 @@ def build( # pylint: disable=R0911 if role_arn: self.role_arn = role_arn + self.serve_settings = self._get_serve_setting() + + if isinstance(self.model, TrainingJob): + self.model_path = self.model.model_artifacts.s3_model_artifacts + self.model = None + self._collect_training_job_model_telemetry() + elif isinstance(self.model, ModelTrainer): + self.model_path = self.model._latest_training_job.model_artifacts.s3_model_artifacts + self.model = None + self._collect_model_trainer_model_telemetry() + elif isinstance(self.model, Estimator): + self.model_path = self.model.output_path + self.model = None + self._collect_estimator_model_telemetry() + self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() self.sagemaker_session.settings._local_download_dir = self.model_path @@ -865,7 +922,6 @@ def build( # pylint: disable=R0911 self.sagemaker_session.sagemaker_client._user_agent_creator.to_string ) - self.serve_settings = self._get_serve_setting() self._is_custom_image_uri = self.image_uri is not None self._handle_mlflow_input() @@ -875,13 +931,30 @@ def build( # pylint: disable=R0911 if ( not (isinstance(self.model, str) and self._is_jumpstart_model_id()) ) and self.model_server: - return self._build_for_model_server() + self.built_model = self._build_for_model_server() + return self.built_model if isinstance(self.model, str): model_task = None + if self._is_jumpstart_model_id(): + if self.mode == Mode.IN_PROCESS: + raise ValueError( + f"{self.mode} is not supported for Jumpstart models. " + "Please use LOCAL_CONTAINER mode to deploy a Jumpstart model" + " on your local machine." + ) self.model_hub = ModelHub.JUMPSTART - return self._build_for_jumpstart() + logger.debug("Building for Jumpstart model Id...") + self.built_model = self._build_for_jumpstart() + return self.built_model + + if self.mode != Mode.IN_PROCESS: + if self._use_jumpstart_equivalent(): + self.model_hub = ModelHub.JUMPSTART + logger.debug("Building for Jumpstart equiavalent model Id...") + self.built_model = self._build_for_jumpstart() + return self.built_model self.model_hub = ModelHub.HUGGINGFACE if self.model_metadata: @@ -899,28 +972,28 @@ def build( # pylint: disable=R0911 if self.schema_builder is None and model_task is not None: self._hf_schema_builder_init(model_task) if model_task == "text-generation": - return self._build_for_tgi() - if model_task == "sentence-similarity": - return self._build_for_tei() + self.built_model = self._build_for_tgi() + return self.built_model + if model_task in ["sentence-similarity", "feature-extraction"]: + self.built_model = self._build_for_tei() + return self.built_model elif self._can_fit_on_single_gpu(): - return self._build_for_transformers() + self.built_model = self._build_for_transformers() + return self.built_model else: - return self._build_for_transformers() + self.built_model = self._build_for_transformers() + return self.built_model # Set TorchServe as default model server if not self.model_server: self.model_server = ModelServer.TORCHSERVE - return self._build_for_torchserve() + self.built_model = self._build_for_torchserve() + return self.built_model raise ValueError("%s model server is not supported" % self.model_server) def _build_validations(self): """Validations needed for model server overrides, or auto-detection or fallback""" - if self.mode == Mode.IN_PROCESS and self.model_server is not ModelServer.MMS: - raise ValueError( - "IN_PROCESS mode is only supported for MMS/Transformers 
server in beta release." - ) - if self.inference_spec and self.model: raise ValueError("Can only set one of the following: model, inference_spec.") @@ -966,6 +1039,7 @@ def _build_for_model_server(self): # pylint: disable=R0911, R1710 if self.model_server == ModelServer.MMS: return self._build_for_transformers() + @_capture_telemetry("ModelBuilder.save") def save( self, save_path: Optional[str] = None, @@ -1514,3 +1588,211 @@ def _optimize_prepare_for_hf(self): should_upload_artifacts=True, ) self.pysdk_model.env.update(env) + + @_capture_telemetry("ModelBuilder.deploy") + def deploy( + self, + endpoint_name: str = None, + initial_instance_count: Optional[int] = 1, + inference_config: Optional[ + Union[ + ServerlessInferenceConfig, + AsyncInferenceConfig, + BatchTransformInferenceConfig, + ResourceRequirements, + ] + ] = None, + ) -> Union[Predictor, Transformer]: + """Deploys the built Model. + + Depending on the type of config provided, this function will call deployment accordingly. + Args: + endpoint_name (str): Name of the endpoint to deploy. + The supplied base name is used as a prefix and + a unique ID is appended to guarantee uniqueness. + initial_instance_count (int): Number of instances to deploy. + inference_config (Optional[Union[ServerlessInferenceConfig, + AsyncInferenceConfig, BatchTransformInferenceConfig, ResourceRequirements]]) : + Additional Config for different deployment types such as + serverless, async, batch and multi-model/container + Returns: + Transformer for Batch Deployments + Predictors for all others + """ + if not hasattr(self, "built_model"): + raise ValueError("Model Needs to be built before deploying") + endpoint_name = unique_name_from_base(endpoint_name) + if not inference_config: # Real-time Deployment + return self.built_model.deploy( + instance_type=self.instance_type, + initial_instance_count=initial_instance_count, + endpoint_name=endpoint_name, + ) + + if isinstance(inference_config, ServerlessInferenceConfig): + return self.built_model.deploy( + serverless_inference_config=inference_config, + endpoint_name=endpoint_name, + ) + + if isinstance(inference_config, AsyncInferenceConfig): + return self.built_model.deploy( + instance_type=self.instance_type, + initial_instance_count=initial_instance_count, + async_inference_config=inference_config, + endpoint_name=endpoint_name, + ) + + if isinstance(inference_config, BatchTransformInferenceConfig): + transformer = self.built_model.transformer( + instance_type=inference_config.instance_type, + output_path=inference_config.output_path, + instance_count=inference_config.instance_count, + ) + return transformer + + if isinstance(inference_config, ResourceRequirements): + # Multi Model and MultiContainer endpoints with Inference Component + return self.built_model.deploy( + instance_type=self.instance_type, + mode=Mode.SAGEMAKER_ENDPOINT, + endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, + resources=inference_config, + initial_instance_count=initial_instance_count, + role=self.role_arn, + ) + + raise ValueError("Deployment Options not supported") + + def display_benchmark_metrics(self, **kwargs): + """Display Markdown Benchmark Metrics for deployment configs.""" + if not isinstance(self.model, str): + raise ValueError("Benchmarking is only supported for JumpStart or HuggingFace models") + if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent(): + return super().display_benchmark_metrics(**kwargs) + else: + raise ValueError("This model does not have benchmark metrics yet") + + def 
get_deployment_config(self) -> Optional[Dict[str, Any]]: + """Gets the deployment config to apply to the model. + + Returns: + Optional[Dict[str, Any]]: Deployment config to apply to this model. + """ + if not isinstance(self.model, str): + raise ValueError( + "Deployment config is only supported for JumpStart or HuggingFace models" + ) + if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent(): + return super().get_deployment_config() + else: + raise ValueError("This model does not have any deployment config yet") + + def list_deployment_configs(self) -> List[Dict[str, Any]]: + """List deployment configs for the model in the current region. + + Returns: + List[Dict[str, Any]]: A list of deployment configs. + """ + if not isinstance(self.model, str): + raise ValueError( + "Deployment config is only supported for JumpStart or HuggingFace models" + ) + if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent(): + return super().list_deployment_configs() + else: + raise ValueError("This model does not have any deployment config yet") + + def set_deployment_config(self, config_name: str, instance_type: str) -> None: + """Sets the deployment config to apply to the model. + + Args: + config_name (str): + The name of the deployment config to apply to the model. + Call list_deployment_configs to see the list of config names. + instance_type (str): + The instance_type that the model will use after setting + the config. + """ + if not isinstance(self.model, str): + raise ValueError( + "Deployment config is only supported for JumpStart or HuggingFace models" + ) + if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent(): + logger.warning( + "If there are existing deployment configurations, " + "they will be overwritten by the config %s", + config_name, + ) + return super().set_deployment_config(config_name, instance_type) + else: + raise ValueError(f"The deployment config {config_name} cannot be set on this model") + + def _use_jumpstart_equivalent(self): + """Check if the HuggingFace model has a JumpStart equivalent. + + Replace it with the equivalent if there's one + """ + # Do not use the equivalent JS model if image_uri or env_vars is provided + if self.image_uri or self.env_vars: + return False + if not hasattr(self, "_has_jumpstart_equivalent"): + self._jumpstart_mapping = self._retrieve_hugging_face_model_mapping() + self._has_jumpstart_equivalent = self.model in self._jumpstart_mapping + if self._has_jumpstart_equivalent: + # Use schema builder from HF model metadata + if not self.schema_builder: + model_task = None + if self.model_metadata: + model_task = self.model_metadata.get("HF_TASK") + hf_model_md = get_huggingface_model_metadata(self.model) + if not model_task: + model_task = hf_model_md.get("pipeline_tag") + if model_task: + self._hf_schema_builder_init(model_task) + + huggingface_model_id = self.model + jumpstart_model_id = self._jumpstart_mapping[huggingface_model_id]["jumpstart-model-id"] + self.model = jumpstart_model_id + merged_date = self._jumpstart_mapping[huggingface_model_id].get("merged-at") + self._build_for_jumpstart() + compare_model_diff_message = ( + "If you want to identify the differences between the two, " + "please use model_uris.retrieve() to retrieve the model " + "artifact S3 URI and compare them." 
+ ) + logger.warning( # pylint: disable=logging-fstring-interpolation + "Please note that for this model we are using the JumpStart's " + f'local copy "{jumpstart_model_id}" ' + f'of the HuggingFace model "{huggingface_model_id}" you chose. ' + "We strive to keep our local copy synced with the HF model hub closely. " + "This model was synced " + f"{f'on {merged_date}' if merged_date else 'before 11/04/2024'}. " + f"{compare_model_diff_message if not self._is_gated_model() else ''}" + ) + return True + return False + + def _retrieve_hugging_face_model_mapping(self): + """Retrieve the HuggingFace/JumpStart model mapping and preprocess it.""" + converted_mapping = {} + region = self.sagemaker_session.boto_region_name + try: + mapping_json_object = JumpStartS3PayloadAccessor.get_object_cached( + bucket=get_jumpstart_content_bucket(region), + key="hf_model_id_map_cache.json", + region=region, + s3_client=self.sagemaker_session.s3_client, + ) + mapping = json.loads(mapping_json_object) + except Exception: # pylint: disable=broad-except + return converted_mapping + + for k, v in mapping.items(): + converted_mapping[v["hf-model-id"]] = { + "jumpstart-model-id": k, + "jumpstart-model-version": v["jumpstart-model-version"], + "merged-at": v.get("merged-at"), + "hf-model-repo-sha": v.get("hf-model-repo-sha"), + } + return converted_mapping diff --git a/src/sagemaker/serve/builder/tei_builder.py b/src/sagemaker/serve/builder/tei_builder.py index c77a57f1a7..72ecef9448 100644 --- a/src/sagemaker/serve/builder/tei_builder.py +++ b/src/sagemaker/serve/builder/tei_builder.py @@ -26,13 +26,14 @@ ) from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure from sagemaker.serve.utils.optimize_utils import _is_optimized -from sagemaker.serve.utils.predictors import TeiLocalModePredictor +from sagemaker.serve.utils.predictors import InProcessModePredictor, TeiLocalModePredictor from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.utils.telemetry_logger import _capture_telemetry from sagemaker.base_predictor import PredictorBase logger = logging.getLogger(__name__) +LOCAL_MODES = [Mode.LOCAL_CONTAINER, Mode.IN_PROCESS] _CODE_FOLDER = "code" @@ -141,6 +142,17 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa serializer = self.schema_builder.input_serializer deserializer = self.schema_builder._output_deserializer + if self.mode == Mode.IN_PROCESS: + self._prepare_for_mode() + predictor = InProcessModePredictor( + self.modes[str(Mode.IN_PROCESS)], serializer, deserializer + ) + + self.modes[str(Mode.IN_PROCESS)].create_server( + predictor, + ) + return predictor + if self.mode == Mode.LOCAL_CONTAINER: timeout = kwargs.get("model_data_download_timeout") @@ -222,7 +234,7 @@ def _build_for_hf_tei(self): self.pysdk_model = self._create_tei_model() - if self.mode == Mode.LOCAL_CONTAINER: + if self.mode in LOCAL_MODES: self._prepare_for_mode() return self.pysdk_model diff --git a/src/sagemaker/serve/builder/tgi_builder.py b/src/sagemaker/serve/builder/tgi_builder.py index 3614e90914..032056cfec 100644 --- a/src/sagemaker/serve/builder/tgi_builder.py +++ b/src/sagemaker/serve/builder/tgi_builder.py @@ -49,13 +49,14 @@ _get_gpu_info_fallback, ) from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure -from sagemaker.serve.utils.predictors import TgiLocalModePredictor +from sagemaker.serve.utils.predictors import TgiLocalModePredictor, InProcessModePredictor from 
sagemaker.serve.utils.types import ModelServer from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.utils.telemetry_logger import _capture_telemetry from sagemaker.base_predictor import PredictorBase logger = logging.getLogger(__name__) +LOCAL_MODES = [Mode.LOCAL_CONTAINER, Mode.IN_PROCESS] _CODE_FOLDER = "code" _INVALID_SAMPLE_DATA_EX = ( @@ -176,6 +177,17 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa serializer = self.schema_builder.input_serializer deserializer = self.schema_builder._output_deserializer + + if self.mode == Mode.IN_PROCESS: + predictor = InProcessModePredictor( + self.modes[str(Mode.IN_PROCESS)], serializer, deserializer + ) + + self.modes[str(Mode.IN_PROCESS)].create_server( + predictor, + ) + return predictor + if self.mode == Mode.LOCAL_CONTAINER: timeout = kwargs.get("model_data_download_timeout") @@ -280,7 +292,7 @@ def _build_for_hf_tgi(self): ] = _default_max_new_tokens self.pysdk_model = self._create_tgi_model() - if self.mode == Mode.LOCAL_CONTAINER: + if self.mode in LOCAL_MODES: self._prepare_for_mode() return self.pysdk_model diff --git a/src/sagemaker/serve/builder/transformers_builder.py b/src/sagemaker/serve/builder/transformers_builder.py index b7baf6b513..0388a9a05d 100644 --- a/src/sagemaker/serve/builder/transformers_builder.py +++ b/src/sagemaker/serve/builder/transformers_builder.py @@ -38,7 +38,7 @@ from sagemaker.serve.utils.optimize_utils import _is_optimized from sagemaker.serve.utils.predictors import ( TransformersLocalModePredictor, - TransformersInProcessModePredictor, + InProcessModePredictor, ) from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.mode.function_pointers import Mode @@ -237,7 +237,7 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr if self.mode == Mode.IN_PROCESS: timeout = kwargs.get("model_data_download_timeout") - predictor = TransformersInProcessModePredictor( + predictor = InProcessModePredictor( self.modes[str(Mode.IN_PROCESS)], serializer, deserializer ) diff --git a/src/sagemaker/serve/mode/in_process_mode.py b/src/sagemaker/serve/mode/in_process_mode.py index 60d4f91e34..0c262da6f3 100644 --- a/src/sagemaker/serve/mode/in_process_mode.py +++ b/src/sagemaker/serve/mode/in_process_mode.py @@ -4,35 +4,29 @@ from pathlib import Path import logging -from typing import Dict, Type +from typing import Dict, Type, Optional import time from datetime import datetime, timedelta from sagemaker.base_predictor import PredictorBase from sagemaker.serve.spec.inference_spec import InferenceSpec from sagemaker.serve.builder.schema_builder import SchemaBuilder -from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.utils.exceptions import InProcessDeepPingException -from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer +from sagemaker.serve.model_server.in_process_model_server.in_process_server import InProcessServing from sagemaker.session import Session logger = logging.getLogger(__name__) -_PING_HEALTH_CHECK_FAIL_MSG = ( - "Ping health check did not pass. " - + "Please increase container_timeout_seconds or review your inference code." -) +_PING_HEALTH_CHECK_FAIL_MSG = "Ping health check did not pass. Please review your inference code." 
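+# Illustrative wiring (a sketch inferred from the ModelBuilder._prepare_for_mode
+# hunk in this same change, not an additional API): the class below is now
+# constructed as
+#     InProcessMode(
+#         inference_spec=self.inference_spec,
+#         model=self.model,
+#         schema_builder=self.schema_builder,
+#         session=self.sagemaker_session,
+#         model_path=self.model_path,
+#         env_vars=self.env_vars,
+#     )
+# after which create_server() starts the FastAPI app defined in
+# model_server/in_process_model_server/app.py.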
-class InProcessMode( - InProcessMultiModelServer, -): +class InProcessMode(InProcessServing): """A class that holds methods to deploy model to a container in process environment""" def __init__( self, - model_server: ModelServer, - inference_spec: Type[InferenceSpec], + model: Optional[str], + inference_spec: Optional[InferenceSpec], schema_builder: Type[SchemaBuilder], session: Session, model_path: str = None, @@ -41,12 +35,12 @@ def __init__( # pylint: disable=bad-super-call super().__init__() + self.model = model self.inference_spec = inference_spec self.model_path = model_path self.env_vars = env_vars self.session = session self.schema_builder = schema_builder - self.model_server = model_server self._ping_local_server = None def load(self, model_path: str = None): @@ -66,12 +60,15 @@ def create_server( self, predictor: PredictorBase, ): - """Creating the server and checking ping health.""" - logger.info("Waiting for model server %s to start up...", self.model_server) + """Creating the fast api server and checking ping health.""" - if self.model_server == ModelServer.MMS: - self._ping_local_server = self._multi_model_server_deep_ping - self._start_serving() + logger.info("Waiting for fastapi server to start up...") + + logger.warning("Note: This is not a standard model server.") + logger.warning("The model is being hosted directly on the FastAPI server.") + + self._ping_local_server = self._deep_ping + self._start_serving() # allow some time for server to be ready. time.sleep(1) diff --git a/src/sagemaker/serve/model_server/in_process_model_server/app.py b/src/sagemaker/serve/model_server/in_process_model_server/app.py new file mode 100644 index 0000000000..18fe63a5fc --- /dev/null +++ b/src/sagemaker/serve/model_server/in_process_model_server/app.py @@ -0,0 +1,150 @@ +"""FastAPI requests""" + +from __future__ import absolute_import + +import asyncio +import io +import logging +import threading +import torch +from typing import Optional, Type + +from sagemaker.serve.spec.inference_spec import InferenceSpec +from sagemaker.serve.builder.schema_builder import SchemaBuilder + +logger = logging.getLogger(__name__) + + +try: + import uvicorn +except ImportError: + logger.error("Unable to import uvicorn, check if uvicorn is installed.") + + +try: + from fastapi import FastAPI, Request, APIRouter +except ImportError: + logger.error("Unable to import fastapi, check if fastapi is installed.") + + +class InProcessServer: + """Generic In-Process Server for Serving Models using InferenceSpec""" + + def __init__( + self, + model: Optional[str] = None, + inference_spec: Optional[InferenceSpec] = None, + schema_builder: Type[SchemaBuilder] = None, + task: Optional[str] = None, + ): + self._thread = None + self._loop = None + self._stop_event = asyncio.Event() + self._shutdown_event = threading.Event() + self._router = APIRouter() + self._task = task + self.server = None + self.port = None + self.host = None + self.model = model + self.inference_spec = inference_spec + self.schema_builder = schema_builder + + if self.inference_spec: + # Use inference_spec to load the model + self._load_model = self.inference_spec.load(model_dir=None) + elif isinstance(self.model, str): + try: + # Use transformers pipeline to load the model + try: + from transformers import pipeline, Pipeline + except ImportError: + logger.error( + "Unable to import transformers, check if transformers is installed." 
+ ) + + device = 0 if torch.cuda.is_available() else -1 + + self._load_model = pipeline(task, model=self.model, device=device) + except Exception: + logger.info("Falling back to SentenceTransformer for model loading.") + try: + from sentence_transformers import SentenceTransformer + except ImportError: + logger.error( + "Unable to import sentence-transformers, check if sentence-transformers is installed." + ) + + self._load_model = SentenceTransformer(self.model) + else: + raise ValueError("Either inference_spec or model must be provided.") + + @self._router.post("/invoke") + async def invoke(request: Request): + """Generate text based on the provided prompt""" + + request_header = request.headers + request_body = await request.body() + content_type = request_header.get("Content-Type", None) + input_data = schema_builder.input_deserializer.deserialize( + io.BytesIO(request_body), content_type[0] + ) + logger.debug(f"Received request: {input_data}") + if self.inference_spec: + response = self.inference_spec.invoke(input_data, self._load_model) + else: + input_data = input_data["inputs"] if "inputs" in input_data else input_data + if isinstance(self._load_model, Pipeline): + response = self._load_model(input_data, max_length=30, num_return_sequences=1) + else: + embeddings = self._load_model.encode(input_data, normalize_embeddings=True) + response = {"embeddings": embeddings.tolist()} + return response + + self._create_server() + + def _create_server(self): + """Placeholder docstring""" + app = FastAPI() + app.include_router(self._router) + + config = uvicorn.Config( + app, + host="127.0.0.1", + port=9007, + log_level="info", + loop="asyncio", + reload=True, + use_colors=True, + ) + + self.server = uvicorn.Server(config) + self.host = config.host + self.port = config.port + + def start_server(self): + """Starts the uvicorn server.""" + if not (self._thread and self._thread.is_alive()): + logger.info("Waiting for a connection...") + self._thread = threading.Thread(target=self._start_run_async_in_thread, daemon=True) + self._thread.start() + + def stop_server(self): + """Stops the Uvicorn server by setting the shutdown event.""" + if self._thread and self._thread.is_alive(): + logger.info("Shutting down the server...") + self._shutdown_event.set() + self.server.handle_exit(sig=0, frame=None) + self._thread.join() + + logger.info("Server shutdown complete.") + + def _start_run_async_in_thread(self): + """Placeholder docstring""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self._serve()) + + async def _serve(self): + """Placeholder docstring""" + await self.server.serve() diff --git a/src/sagemaker/serve/model_server/in_process_model_server/in_process_server.py b/src/sagemaker/serve/model_server/in_process_model_server/in_process_server.py new file mode 100644 index 0000000000..d391fe50a0 --- /dev/null +++ b/src/sagemaker/serve/model_server/in_process_model_server/in_process_server.py @@ -0,0 +1,60 @@ +"""Module for In_process Serving""" + +from __future__ import absolute_import + +import requests +import logging +from sagemaker.serve.utils.exceptions import LocalModelInvocationException +from sagemaker.base_predictor import PredictorBase + +logger = logging.getLogger(__name__) + + +class InProcessServing: + """In Process Mode server instance""" + + def _start_serving(self): + """Initializes the start of the server""" + from sagemaker.serve.model_server.in_process_model_server.app import InProcessServer + + self.server = InProcessServer( + 
inference_spec=self.inference_spec, model=self.model, schema_builder=self.schema_builder
+        )
+        self.server.start_server()
+
+    def _stop_serving(self):
+        """Stops the server"""
+        self.server.stop_server()
+
+    def _invoke_serving(self, request: object, content_type: str, accept: str):
+        """Sends an invocation request to the in-process server"""
+        try:
+            response = requests.post(
+                f"http://{self.server.host}:{self.server.port}/invoke",
+                data=request,
+                headers={"Content-Type": content_type, "Accept": accept},
+                timeout=600,
+            )
+            response.raise_for_status()
+
+            return response.content
+        except Exception as e:
+            if "Connection refused" in str(e):
+                raise Exception(
+                    "Unable to send request to the local server: Connection refused."
+                ) from e
+            raise Exception(f"Unable to send request to the in-process server: {str(e)}") from e
+
+    def _deep_ping(self, predictor: PredictorBase):
+        """Sends a deep ping to ensure prediction"""
+        healthy = False
+        response = None
+        try:
+            response = predictor.predict(self.schema_builder.sample_input)
+            healthy = response is not None
+        # pylint: disable=broad-except
+        except Exception as e:
+            if "422 Client Error: Unprocessable Entity for url" in str(e):
+                raise LocalModelInvocationException(str(e))
+
+        return healthy, response
diff --git a/src/sagemaker/serve/model_server/multi_model_server/inference.py b/src/sagemaker/serve/model_server/multi_model_server/inference.py
index 1ee7b5e4dc..595b9d9c39 100644
--- a/src/sagemaker/serve/model_server/multi_model_server/inference.py
+++ b/src/sagemaker/serve/model_server/multi_model_server/inference.py
@@ -44,13 +44,19 @@ def input_fn(input_data, content_type):
     """Deserializes the bytes that were received from the model server"""
     try:
         if hasattr(schema_builder, "custom_input_translator"):
-            return schema_builder.custom_input_translator.deserialize(
+            deserialized_data = schema_builder.custom_input_translator.deserialize(
                 io.BytesIO(input_data), content_type
             )
         else:
-            return schema_builder.input_deserializer.deserialize(
+            deserialized_data = schema_builder.input_deserializer.deserialize(
                 io.BytesIO(input_data), content_type[0]
             )
+
+        # Check if preprocess method is defined and call it
+        if hasattr(inference_spec, "preprocess"):
+            return inference_spec.preprocess(deserialized_data)
+
+        return deserialized_data
     except Exception as e:
         logger.error("Encountered error: %s in deserialize_response." 
% e) raise Exception("Encountered error in deserialize_request.") from e @@ -64,6 +70,8 @@ def predict_fn(input_data, predict_callable): def output_fn(predictions, accept_type): """Prediction is serialized to bytes and sent back to the customer""" try: + if hasattr(inference_spec, "postprocess"): + predictions = inference_spec.postprocess(predictions) if hasattr(schema_builder, "custom_output_translator"): return schema_builder.custom_output_translator.serialize(predictions, accept_type) else: diff --git a/src/sagemaker/serve/model_server/multi_model_server/server.py b/src/sagemaker/serve/model_server/multi_model_server/server.py index 69d5d2e5e7..2fab727c05 100644 --- a/src/sagemaker/serve/model_server/multi_model_server/server.py +++ b/src/sagemaker/serve/model_server/multi_model_server/server.py @@ -2,8 +2,6 @@ from __future__ import absolute_import -import json - import requests import logging import platform @@ -11,7 +9,6 @@ from sagemaker import Session, fw_utils from sagemaker.serve.utils.exceptions import LocalModelInvocationException -from sagemaker.serve.utils.exceptions import InProcessDeepPingException from sagemaker.base_predictor import PredictorBase from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join from sagemaker.s3 import S3Uploader @@ -24,62 +21,6 @@ logger = logging.getLogger(__name__) -class InProcessMultiModelServer: - """In Process Mode Multi Model server instance""" - - def _start_serving(self): - """Initializes the start of the server""" - from sagemaker.serve.app import InProcessServer - - if hasattr(self, "inference_spec"): - model_id = self.inference_spec.get_model() - if not model_id: - raise ValueError("Model id was not provided in Inference Spec.") - else: - model_id = None - self.server = InProcessServer(model_id=model_id) - - self.server.start_server() - - def _stop_serving(self): - """Stops the server""" - self.server.stop_server() - - def _invoke_multi_model_server_serving(self, request: bytes, content_type: str, accept: str): - """Placeholder docstring""" - try: - response = requests.post( - f"http://{self.server.host}:{self.server.port}/generate", - data=request, - headers={"Content-Type": content_type, "Accept": accept}, - timeout=600, - ) - response.raise_for_status() - if isinstance(response.content, bytes): - return json.loads(response.content.decode("utf-8")) - return response.content - except Exception as e: - if "Connection refused" in str(e): - raise Exception( - "Unable to send request to the local server: Connection refused." 
- ) from e - raise Exception("Unable to send request to the local server.") from e - - def _multi_model_server_deep_ping(self, predictor: PredictorBase): - """Sends a deep ping to ensure prediction""" - healthy = False - response = None - try: - response = predictor.predict(self.schema_builder.sample_input) - healthy = response is not None - # pylint: disable=broad-except - except Exception as e: - if "422 Client Error: Unprocessable Entity for url" in str(e): - raise InProcessDeepPingException(str(e)) - - return healthy, response - - class LocalMultiModelServer: """Local Multi Model server instance""" diff --git a/src/sagemaker/serve/model_server/torchserve/inference.py b/src/sagemaker/serve/model_server/torchserve/inference.py index 2675f6ea6a..cad94cc817 100644 --- a/src/sagemaker/serve/model_server/torchserve/inference.py +++ b/src/sagemaker/serve/model_server/torchserve/inference.py @@ -66,13 +66,19 @@ def input_fn(input_data, content_type): """Placeholder docstring""" try: if hasattr(schema_builder, "custom_input_translator"): - return schema_builder.custom_input_translator.deserialize( + deserialized_data = schema_builder.custom_input_translator.deserialize( io.BytesIO(input_data), content_type ) else: - return schema_builder.input_deserializer.deserialize( + deserialized_data = schema_builder.input_deserializer.deserialize( io.BytesIO(input_data), content_type[0] ) + + # Check if preprocess method is defined and call it + if hasattr(inference_spec, "preprocess"): + return inference_spec.preprocess(deserialized_data) + + return deserialized_data except Exception as e: raise Exception("Encountered error in deserialize_request.") from e @@ -85,6 +91,8 @@ def predict_fn(input_data, predict_callable): def output_fn(predictions, accept_type): """Placeholder docstring""" try: + if hasattr(inference_spec, "postprocess"): + predictions = inference_spec.postprocess(predictions) if hasattr(schema_builder, "custom_output_translator"): return schema_builder.custom_output_translator.serialize(predictions, accept_type) else: diff --git a/src/sagemaker/serve/model_server/triton/triton_builder.py b/src/sagemaker/serve/model_server/triton/triton_builder.py index a19235767f..c47991fa09 100644 --- a/src/sagemaker/serve/model_server/triton/triton_builder.py +++ b/src/sagemaker/serve/model_server/triton/triton_builder.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) -SUPPORTED_TRITON_MODE = {Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT} +SUPPORTED_TRITON_MODE = {Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT, Mode.IN_PROCESS} SUPPORTED_TRITON_FRAMEWORK = {"pytorch", "tensorflow"} INPUT_NAME = "input_1" OUTPUT_NAME = "output_1" diff --git a/src/sagemaker/serve/spec/inference_spec.py b/src/sagemaker/serve/spec/inference_spec.py index 2598a38d01..0397e84975 100644 --- a/src/sagemaker/serve/spec/inference_spec.py +++ b/src/sagemaker/serve/spec/inference_spec.py @@ -28,6 +28,12 @@ def invoke(self, input_object: object, model: object): model (object): The model object """ + def preprocess(self, input_data: object): + """Custom pre-processing function""" + + def postprocess(self, predictions: object): + """Custom post-processing function""" + def prepare(self, *args, **kwargs): """Custom prepare function""" diff --git a/src/sagemaker/serve/utils/predictors.py b/src/sagemaker/serve/utils/predictors.py index 89ec2253f1..af05de6425 100644 --- a/src/sagemaker/serve/utils/predictors.py +++ b/src/sagemaker/serve/utils/predictors.py @@ -212,42 +212,6 @@ def delete_predictor(self): 
self._mode_obj.destroy_server() -class TransformersInProcessModePredictor(PredictorBase): - """Lightweight Transformers predictor for in process mode deployment""" - - def __init__( - self, - mode_obj: Type[InProcessMode], - serializer=JSONSerializer(), - deserializer=JSONDeserializer(), - ): - self._mode_obj = mode_obj - self.serializer = serializer - self.deserializer = deserializer - - def predict(self, data): - """Placeholder docstring""" - return self._mode_obj._invoke_multi_model_server_serving( - self.serializer.serialize(data), - self.content_type, - self.deserializer.ACCEPT[0], - ) - - @property - def content_type(self): - """The MIME type of the data sent to the inference endpoint.""" - return self.serializer.CONTENT_TYPE - - @property - def accept(self): - """The content type(s) that are expected from the inference endpoint.""" - return self.deserializer.ACCEPT - - def delete_predictor(self): - """Shut down and remove the container that you created in LOCAL_CONTAINER mode""" - self._mode_obj.destroy_server() - - class TeiLocalModePredictor(PredictorBase): """Lightweight Tei predictor for local deployment in IN_PROCESS and LOCAL_CONTAINER modes""" @@ -354,6 +318,58 @@ def _get_local_mode_predictor( raise ValueError("%s model server is not supported yet!" % model_server) +class InProcessModePredictor(PredictorBase): + """Lightweight predictor for in process mode deployment""" + + def __init__( + self, + mode_obj: Type[InProcessMode], + serializer=JSONSerializer(), + deserializer=JSONDeserializer(), + ): + self._mode_obj = mode_obj + self.serializer = serializer + self.deserializer = deserializer + + def predict(self, data): + """Placeholder docstring""" + return self.deserializer.deserialize( + io.BytesIO( + self._mode_obj._invoke_serving( + self.serializer.serialize(data), + self.content_type, + self.accept[0], + ) + ) + ) + + @property + def content_type(self): + """The MIME type of the data sent to the inference endpoint.""" + return self.serializer.CONTENT_TYPE + + @property + def accept(self): + """The content type(s) that are expected from the inference endpoint.""" + return self.deserializer.ACCEPT + + def delete_predictor(self): + """Shut down and remove the container that you created in IN_PROCESS mode""" + self._mode_obj.destroy_server() + + +def _get_in_process_mode_predictor( + # model_server: ModelServer, + mode_obj: Type[InProcessMode], + serializer=JSONSerializer(), + deserializer=JSONDeserializer(), +) -> Type[PredictorBase]: + """Returns Predictor for IN_PROCESS mode""" + return InProcessModePredictor( + mode_obj=mode_obj, serializer=serializer, deserializer=deserializer + ) + + def retrieve_predictor( endpoint_name: str, schema_builder: SchemaBuilder, diff --git a/src/sagemaker/serve/utils/telemetry_logger.py b/src/sagemaker/serve/utils/telemetry_logger.py index 6a1228ba40..a1a0408718 100644 --- a/src/sagemaker/serve/utils/telemetry_logger.py +++ b/src/sagemaker/serve/utils/telemetry_logger.py @@ -122,13 +122,12 @@ def wrapper(self, *args, **kwargs): extra += f"&x-modelServer={MODEL_SERVER_TO_CODE[str(self.model_server)]}" if self.image_uri: - image_uri_tail = self.image_uri.split("/")[1] image_uri_option = _get_image_uri_option( self.image_uri, getattr(self, "_is_custom_image_uri", False) ) - - if self.image_uri: - extra += f"&x-imageTag={image_uri_tail}" + split_image_uri = self.image_uri.split("/") + if len(split_image_uri) > 1: + extra += f"&x-imageTag={split_image_uri[1]}" extra += f"&x-sdkVersion={SDK_VERSION}" @@ -171,7 +170,7 @@ def wrapper(self, *args, 
**kwargs):
 
         extra += f"&x-latency={round(elapsed, 2)}"
 
-        if not self.serve_settings.telemetry_opt_out:
+        if hasattr(self, "serve_settings") and not self.serve_settings.telemetry_opt_out:
             _send_telemetry(
                 status,
                 MODE_TO_CODE[str(self.mode)],
diff --git a/src/sagemaker/telemetry/constants.py b/src/sagemaker/telemetry/constants.py
index 332d706351..2108ff9fd6 100644
--- a/src/sagemaker/telemetry/constants.py
+++ b/src/sagemaker/telemetry/constants.py
@@ -25,6 +25,8 @@ class Feature(Enum):
     SDK_DEFAULTS = 1
     LOCAL_MODE = 2
     REMOTE_FUNCTION = 3
+    MODEL_TRAINER = 4
+    ESTIMATOR = 5
 
     def __str__(self):  # pylint: disable=E0307
         """Return the feature name."""
diff --git a/src/sagemaker/telemetry/telemetry_logging.py b/src/sagemaker/telemetry/telemetry_logging.py
index d2b91a321c..b45550b2c2 100644
--- a/src/sagemaker/telemetry/telemetry_logging.py
+++ b/src/sagemaker/telemetry/telemetry_logging.py
@@ -52,6 +52,8 @@
     str(Feature.SDK_DEFAULTS): 1,
     str(Feature.LOCAL_MODE): 2,
     str(Feature.REMOTE_FUNCTION): 3,
+    str(Feature.MODEL_TRAINER): 4,
+    str(Feature.ESTIMATOR): 5,
 }
 
 STATUS_TO_CODE = {
@@ -61,7 +63,14 @@
 
 
 def _telemetry_emitter(feature: str, func_name: str):
-    """Decorator to emit telemetry logs for SageMaker Python SDK functions"""
+    """Telemetry Emitter
+
+    Decorator to emit telemetry logs for SageMaker Python SDK functions. The class that
+    owns the decorated method needs a sagemaker_session object as a member. The default
+    session object is a pysdk v2 Session object in this repo. When collecting telemetry
+    for classes that use the sagemaker-core Session object, be aware of its differences;
+    for example, sagemaker_session.sagemaker_config does not exist in the new Session class.
+    """
 
     def decorator(func):
         @functools.wraps(func)
@@ -95,10 +104,18 @@ def wrapper(*args, **kwargs):
 
             # Construct the feature list to track feature combinations
             feature_list: List[int] = [FEATURE_TO_CODE[str(feature)]]
-            if sagemaker_session.sagemaker_config and feature != Feature.SDK_DEFAULTS:
+            if (
+                hasattr(sagemaker_session, "sagemaker_config")
+                and sagemaker_session.sagemaker_config
+                and feature != Feature.SDK_DEFAULTS
+            ):
                 feature_list.append(FEATURE_TO_CODE[str(Feature.SDK_DEFAULTS)])
 
-            if sagemaker_session.local_mode and feature != Feature.LOCAL_MODE:
+            if (
+                hasattr(sagemaker_session, "local_mode")
+                and sagemaker_session.local_mode
+                and feature != Feature.LOCAL_MODE
+            ):
                 feature_list.append(FEATURE_TO_CODE[str(Feature.LOCAL_MODE)])
 
             # Construct the extra info to track platform and environment usage metadata
@@ -111,7 +128,7 @@ def wrapper(*args, **kwargs):
             )
 
             # Add endpoint ARN to the extra info if available
-            if sagemaker_session.endpoint_arn:
+            if hasattr(sagemaker_session, "endpoint_arn") and sagemaker_session.endpoint_arn:
                 extra += f"&x-endpointArn={sagemaker_session.endpoint_arn}"
 
             start_timer = perf_counter()
@@ -171,8 +188,9 @@ def _send_telemetry_request(
 ) -> None:
     """Make GET request to an empty object in S3 bucket"""
     try:
-        accountId = _get_accountId(session)
-        region = _get_region_or_default(session)
+        accountId = _get_accountId(session) if session else "NotAvailable"
+        # telemetry will be sent to us-west-2 if no session is available
+        region = _get_region_or_default(session) if session else DEFAULT_AWS_REGION
         url = _construct_url(
             accountId,
             region,
diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py
index 3f640bbe33..e8602de8d7 100644
--- a/src/sagemaker/utils.py
+++ b/src/sagemaker/utils.py
@@ -1160,7 +1160,7 @@ def get_sagemaker_config_value(sagemaker_session, key, sagemaker_config: dict = 
     Returns:
         object: The corresponding default 
value in the configuration file. """ - if sagemaker_session: + if sagemaker_session and hasattr(sagemaker_session, "sagemaker_config"): config_to_check = sagemaker_session.sagemaker_config else: config_to_check = sagemaker_config diff --git a/tests/data/modules/local_script/data/test/x_test.npy b/tests/data/modules/local_script/data/test/x_test.npy new file mode 100644 index 0000000000..a9977e39c0 Binary files /dev/null and b/tests/data/modules/local_script/data/test/x_test.npy differ diff --git a/tests/data/modules/local_script/data/test/y_test.npy b/tests/data/modules/local_script/data/test/y_test.npy new file mode 100644 index 0000000000..a7191945ee Binary files /dev/null and b/tests/data/modules/local_script/data/test/y_test.npy differ diff --git a/tests/data/modules/local_script/data/train/x_train.npy b/tests/data/modules/local_script/data/train/x_train.npy new file mode 100644 index 0000000000..d267502e65 Binary files /dev/null and b/tests/data/modules/local_script/data/train/x_train.npy differ diff --git a/tests/data/modules/local_script/data/train/y_train.npy b/tests/data/modules/local_script/data/train/y_train.npy new file mode 100644 index 0000000000..b8c17c4972 Binary files /dev/null and b/tests/data/modules/local_script/data/train/y_train.npy differ diff --git a/tests/data/modules/local_script/local_training_script.py b/tests/data/modules/local_script/local_training_script.py new file mode 100644 index 0000000000..6bb73343c0 --- /dev/null +++ b/tests/data/modules/local_script/local_training_script.py @@ -0,0 +1,147 @@ +# flake8: noqa +import argparse +import numpy as np +import os +import sys +import logging +import json +import shutil +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, TensorDataset +from pytorch_model_def import get_model + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) +logger.addHandler(logging.StreamHandler(sys.stdout)) +current_dir = os.path.dirname(os.path.abspath(__file__)) +data_dir = "/opt/ml/input/data" + + +def get_train_data(train_dir): + """ + Get the training data and convert to tensors + """ + + x_train = np.load(os.path.join(train_dir, "x_train.npy")) + y_train = np.load(os.path.join(train_dir, "y_train.npy")) + logger.info(f"x train: {x_train.shape}, y train: {y_train.shape}") + + return torch.from_numpy(x_train), torch.from_numpy(y_train) + + +def get_test_data(test_dir): + """ + Get the testing data and convert to tensors + """ + + x_test = np.load(os.path.join(test_dir, "x_test.npy")) + y_test = np.load(os.path.join(test_dir, "y_test.npy")) + logger.info(f"x test: {x_test.shape}, y test: {y_test.shape}") + + return torch.from_numpy(x_test), torch.from_numpy(y_test) + + +def model_fn(model_dir): + """ + Load the model for inference + """ + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = get_model() + model.load_state_dict(torch.load(model_dir + "/model.pth")) + model.eval() + return model.to(device) + + +def input_fn(request_body, request_content_type): + """ + Deserialize and prepare the prediction input + """ + + if request_content_type == "application/json": + request = json.loads(request_body) + train_inputs = torch.tensor(request) + return train_inputs + + +def predict_fn(input_data, model): + """ + Apply model to the incoming request + """ + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() + with torch.no_grad(): + return model(input_data.float()).numpy()[0] + + +def train(): + """ + Train the 
PyTorch model + """ + # Directories: train, test and model + train_dir = os.path.join(data_dir, "train") + test_dir = os.path.join(data_dir, "test") + model_dir = os.environ.get("SM_MODEL_DIR", os.path.join(current_dir, "data/model")) + + # Load the training and testing data + x_train, y_train = get_train_data(train_dir) + x_test, y_test = get_test_data(test_dir) + train_ds = TensorDataset(x_train, y_train) + + # Training parameters - used to configure the training loop + batch_size = 64 + epochs = 1 + learning_rate = 0.1 + logger.info( + "batch_size = {}, epochs = {}, learning rate = {}".format(batch_size, epochs, learning_rate) + ) + + train_dl = DataLoader(train_ds, batch_size, shuffle=True) + + # Define the model, loss function and optimizer + model = get_model() + model = model.to(device) + criterion = nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) + + # Train the model + for epoch in range(epochs): + for x_train_batch, y_train_batch in train_dl: + y = model(x_train_batch.float()) + loss = criterion(y.flatten(), y_train_batch.float()) + optimizer.zero_grad() + loss.backward() + optimizer.step() + # log 1-based epoch numbers without mutating the loop variable + logger.info(f"epoch: {epoch + 1} -> loss: {loss}") + + # Test the model + with torch.no_grad(): + y = model(x_test.float()).flatten() + mse = ((y - y_test) ** 2).sum() / y_test.shape[0] + print("\nTest MSE:", mse.numpy()) + + # Save the model + os.makedirs(model_dir, exist_ok=True) + torch.save(model.state_dict(), model_dir + "/model.pth") + inference_code_path = model_dir + "/code/" + + if not os.path.exists(inference_code_path): + os.mkdir(inference_code_path) + logger.info("Created a folder at {}!".format(inference_code_path)) + + shutil.copy("local_training_script.py", inference_code_path) + shutil.copy("pytorch_model_def.py", inference_code_path) + logger.info("Saving model files to {}".format(inference_code_path)) + + +if __name__ == "__main__": + print("Running the training job ...\n") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + train() diff --git a/tests/data/modules/local_script/pytorch_model_def.py b/tests/data/modules/local_script/pytorch_model_def.py new file mode 100644 index 0000000000..2440b22f88 --- /dev/null +++ b/tests/data/modules/local_script/pytorch_model_def.py @@ -0,0 +1,23 @@ +# flake8: noqa +import torch +import torch.nn as nn + + +class NeuralNet(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(8, 8) + self.fc2 = nn.Linear(8, 6) + self.fc3 = nn.Linear(6, 1) + + def forward(self, x): + x = torch.tanh(self.fc1(x)) + x = torch.sigmoid(self.fc2(x)) + x = self.fc3(x) + return x + + +def get_model(): + + model = NeuralNet() + return model diff --git a/tests/data/modules/params_script/train.py b/tests/data/modules/params_script/train.py new file mode 100644 index 0000000000..8d3924a325 --- /dev/null +++ b/tests/data/modules/params_script/train.py @@ -0,0 +1,141 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License.
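The contract script that follows exercises the three routes by which hyperparameters reach the training container: parsed command-line arguments, per-key SM_HP_* environment variables, and the aggregated SM_HPS JSON object. A minimal consumption sketch, assuming only the variable spellings asserted in train.py below:

    import json
    import os

    # scalar hyperparameters arrive as strings and must be cast
    integer = int(os.environ["SM_HP_INTEGER"])         # "1" -> 1
    # complex values arrive JSON-encoded and must be decoded
    boolean = json.loads(os.environ["SM_HP_BOOLEAN"])  # "true" -> True
    # SM_HPS bundles every hyperparameter into one typed dictionary
    all_hps = json.loads(os.environ["SM_HPS"])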
+"""Script to test hyperparameters contract.""" +from __future__ import absolute_import + +import argparse +import json +import os + +EXPECTED_HYPERPARAMETERS = { + "integer": 1, + "boolean": True, + "float": 3.14, + "string": "Hello World", + "list": [1, 2, 3], + "dict": { + "string": "value", + "integer": 3, + "list": [1, 2, 3], + "dict": {"key": "value"}, + "boolean": True, + }, +} + + +def parse_args(): + parser = argparse.ArgumentParser(description="Test Hyperparameters") + parser.add_argument( + "--string", + type=str, + default=None, + required=True, + ) + parser.add_argument( + "--integer", + type=int, + default=None, + required=True, + ) + parser.add_argument( + "--float", + type=float, + default=None, + required=True, + ) + parser.add_argument( + "--boolean", + type=lambda x: json.loads(x), + default=None, + required=True, + ) + parser.add_argument( + "--list", + type=lambda x: json.loads(x), + default=None, + required=True, + ) + parser.add_argument( + "--dict", + type=lambda x: json.loads(x), + default=None, + required=True, + ) + return parser.parse_args() + + +def main(): + args = parse_args() + print(args) + + assert isinstance(args.string, str) + assert isinstance(args.integer, int) + assert isinstance(args.boolean, bool) + assert isinstance(args.float, float) + assert isinstance(args.list, list) + assert isinstance(args.dict, dict) + + assert args.string == EXPECTED_HYPERPARAMETERS["string"] + assert args.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert args.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert args.float == EXPECTED_HYPERPARAMETERS["float"] + assert args.list == EXPECTED_HYPERPARAMETERS["list"] + assert args.dict == EXPECTED_HYPERPARAMETERS["dict"] + + assert os.environ["SM_HP_STRING"] == EXPECTED_HYPERPARAMETERS["string"] + assert int(os.environ["SM_HP_INTEGER"]) == EXPECTED_HYPERPARAMETERS["integer"] + assert float(os.environ["SM_HP_FLOAT"]) == EXPECTED_HYPERPARAMETERS["float"] + assert json.loads(os.environ["SM_HP_BOOLEAN"]) == EXPECTED_HYPERPARAMETERS["boolean"] + assert json.loads(os.environ["SM_HP_LIST"]) == EXPECTED_HYPERPARAMETERS["list"] + assert json.loads(os.environ["SM_HP_DICT"]) == EXPECTED_HYPERPARAMETERS["dict"] + + params = json.loads(os.environ["SM_HPS"]) + print(f"SM_HPS: {params}") + assert params["string"] == EXPECTED_HYPERPARAMETERS["string"] + assert params["integer"] == EXPECTED_HYPERPARAMETERS["integer"] + assert params["boolean"] == EXPECTED_HYPERPARAMETERS["boolean"] + assert params["float"] == EXPECTED_HYPERPARAMETERS["float"] + assert params["list"] == EXPECTED_HYPERPARAMETERS["list"] + assert params["dict"] == EXPECTED_HYPERPARAMETERS["dict"] + + assert isinstance(params, dict) + assert isinstance(params["string"], str) + assert isinstance(params["integer"], int) + assert isinstance(params["boolean"], bool) + assert isinstance(params["float"], float) + assert isinstance(params["list"], list) + assert isinstance(params["dict"], dict) + + params = json.loads(os.environ["SM_TRAINING_ENV"])["hyperparameters"] + print(params) + assert params["string"] == EXPECTED_HYPERPARAMETERS["string"] + assert params["integer"] == EXPECTED_HYPERPARAMETERS["integer"] + assert params["boolean"] == EXPECTED_HYPERPARAMETERS["boolean"] + assert params["float"] == EXPECTED_HYPERPARAMETERS["float"] + assert params["list"] == EXPECTED_HYPERPARAMETERS["list"] + assert params["dict"] == EXPECTED_HYPERPARAMETERS["dict"] + + assert isinstance(params, dict) + assert isinstance(params["string"], str) + assert isinstance(params["integer"], int) + 
assert isinstance(params["boolean"], bool) + assert isinstance(params["float"], float) + assert isinstance(params["list"], list) + assert isinstance(params["dict"], dict) + print(f"SM_TRAINING_ENV -> hyperparameters: {params}") + + print("Test passed.") + + +if __name__ == "__main__": + main() diff --git a/tests/data/modules/params_script/train.sh b/tests/data/modules/params_script/train.sh new file mode 100644 index 0000000000..20f9a3c57a --- /dev/null +++ b/tests/data/modules/params_script/train.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -e + +echo "Do some extra work here..." + +CMD="python train.py $@" +echo "Executing Command: $CMD" + +python train.py "$@" + +echo "Done!" diff --git a/tests/data/modules/script_mode/custom_script.py b/tests/data/modules/script_mode/custom_script.py new file mode 100644 index 0000000000..26e5826267 --- /dev/null +++ b/tests/data/modules/script_mode/custom_script.py @@ -0,0 +1,145 @@ +# flake8: noqa +import argparse +import numpy as np +import os +import sys +import logging +import json +import shutil +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, TensorDataset +from pytorch_model_def import get_model + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) +logger.addHandler(logging.StreamHandler(sys.stdout)) +current_dir = os.path.dirname(os.path.abspath(__file__)) + + +def get_train_data(train_dir): + """ + Get the training data and convert to tensors + """ + + x_train = np.load(os.path.join(train_dir, "x_train.npy")) + y_train = np.load(os.path.join(train_dir, "y_train.npy")) + logger.info(f"x train: {x_train.shape}, y train: {y_train.shape}") + + return torch.from_numpy(x_train), torch.from_numpy(y_train) + + +def get_test_data(test_dir): + """ + Get the testing data and convert to tensors + """ + + x_test = np.load(os.path.join(test_dir, "x_test.npy")) + y_test = np.load(os.path.join(test_dir, "y_test.npy")) + logger.info(f"x test: {x_test.shape}, y test: {y_test.shape}") + + return torch.from_numpy(x_test), torch.from_numpy(y_test) + + +def model_fn(model_dir): + """ + Load the model for inference + """ + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = get_model() + model.load_state_dict(torch.load(model_dir + "/model.pth")) + model.eval() + return model.to(device) + + +def input_fn(request_body, request_content_type): + """ + Deserialize and prepare the prediction input + """ + + if request_content_type == "application/json": + request = json.loads(request_body) + train_inputs = torch.tensor(request) + return train_inputs + + +def predict_fn(input_data, model): + """ + Apply model to the incoming request + """ + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() + with torch.no_grad(): + return model(input_data.float()).numpy()[0] + + +def train(): + """ + Train the PyTorch model + """ + # Directories: train, test and model + train_dir = os.path.join(current_dir, "data/train") + test_dir = os.path.join(current_dir, "data/test") + model_dir = os.environ.get("SM_MODEL_DIR", os.path.join(current_dir, "data/model")) + + # Load the training and testing data + x_train, y_train = get_train_data(train_dir) + x_test, y_test = get_test_data(test_dir) + train_ds = TensorDataset(x_train, y_train) + + # Training parameters - used to configure the training loop + batch_size = 64 + epochs = 1 + learning_rate = 0.1 + logger.info( + "batch_size = {}, epochs = {}, learning rate = {}".format(batch_size, epochs, learning_rate) + ) + + 
train_dl = DataLoader(train_ds, batch_size, shuffle=True) + + # Define the model, loss function and optimizer + model = get_model() + model = model.to(device) + criterion = nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) + + # Train the model + for epoch in range(epochs): + for x_train_batch, y_train_batch in train_dl: + y = model(x_train_batch.float()) + loss = criterion(y.flatten(), y_train_batch.float()) + optimizer.zero_grad() + loss.backward() + optimizer.step() + # log 1-based epoch numbers without mutating the loop variable + logger.info(f"epoch: {epoch + 1} -> loss: {loss}") + + # Test the model + with torch.no_grad(): + y = model(x_test.float()).flatten() + mse = ((y - y_test) ** 2).sum() / y_test.shape[0] + print("\nTest MSE:", mse.numpy()) + + # Save the model + os.makedirs(model_dir, exist_ok=True) + torch.save(model.state_dict(), model_dir + "/model.pth") + inference_code_path = model_dir + "/code/" + + if not os.path.exists(inference_code_path): + os.mkdir(inference_code_path) + logger.info("Created a folder at {}!".format(inference_code_path)) + + shutil.copy("custom_script.py", inference_code_path) + shutil.copy("pytorch_model_def.py", inference_code_path) + logger.info("Saving model files to {}".format(inference_code_path)) + + +if __name__ == "__main__": + print("Running the training job ...\n") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + train() diff --git a/tests/data/modules/script_mode/data/test/x_test.npy b/tests/data/modules/script_mode/data/test/x_test.npy new file mode 100644 index 0000000000..a9977e39c0 Binary files /dev/null and b/tests/data/modules/script_mode/data/test/x_test.npy differ diff --git a/tests/data/modules/script_mode/data/test/y_test.npy b/tests/data/modules/script_mode/data/test/y_test.npy new file mode 100644 index 0000000000..a7191945ee Binary files /dev/null and b/tests/data/modules/script_mode/data/test/y_test.npy differ diff --git a/tests/data/modules/script_mode/data/train/x_train.npy b/tests/data/modules/script_mode/data/train/x_train.npy new file mode 100644 index 0000000000..d267502e65 Binary files /dev/null and b/tests/data/modules/script_mode/data/train/x_train.npy differ diff --git a/tests/data/modules/script_mode/data/train/y_train.npy b/tests/data/modules/script_mode/data/train/y_train.npy new file mode 100644 index 0000000000..b8c17c4972 Binary files /dev/null and b/tests/data/modules/script_mode/data/train/y_train.npy differ diff --git a/tests/data/modules/script_mode/pytorch_model_def.py b/tests/data/modules/script_mode/pytorch_model_def.py new file mode 100644 index 0000000000..2440b22f88 --- /dev/null +++ b/tests/data/modules/script_mode/pytorch_model_def.py @@ -0,0 +1,23 @@ +# flake8: noqa +import torch +import torch.nn as nn + + +class NeuralNet(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(8, 8) + self.fc2 = nn.Linear(8, 6) + self.fc3 = nn.Linear(6, 1) + + def forward(self, x): + x = torch.tanh(self.fc1(x)) + x = torch.sigmoid(self.fc2(x)) + x = self.fc3(x) + return x + + +def get_model(): + + model = NeuralNet() + return model diff --git a/tests/data/modules/script_mode/requirements.txt b/tests/data/modules/script_mode/requirements.txt new file mode 100644 index 0000000000..da7441eee2 --- /dev/null +++ b/tests/data/modules/script_mode/requirements.txt @@ -0,0 +1,3 @@ +numpy +-f https://download.pytorch.org/whl/torch_stable.html +torch==2.0.1+cpu diff --git a/tests/integ/sagemaker/modules/__init__.py b/tests/integ/sagemaker/modules/__init__.py new file mode 100644 index
0000000000..9d8bffee3f --- /dev/null +++ b/tests/integ/sagemaker/modules/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Placeholder docstring""" diff --git a/tests/integ/sagemaker/modules/conftest.py b/tests/integ/sagemaker/modules/conftest.py new file mode 100644 index 0000000000..c3de81157a --- /dev/null +++ b/tests/integ/sagemaker/modules/conftest.py @@ -0,0 +1,40 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains conftest fixtures for sagemaker.modules integration tests""" +from __future__ import absolute_import + +import pytest + +import os +import boto3 +from sagemaker.modules import Session + +DEFAULT_REGION = "us-west-2" + + +@pytest.fixture(scope="module") +def modules_sagemaker_session(): + region = os.environ.get("AWS_DEFAULT_REGION") + if not region: + os.environ["AWS_DEFAULT_REGION"] = DEFAULT_REGION + region_manual_set = True + else: + region_manual_set = False + + boto_session = boto3.Session(region_name=os.environ["AWS_DEFAULT_REGION"]) + sagemaker_session = Session(boto_session=boto_session) + + yield sagemaker_session + + if region_manual_set and "AWS_DEFAULT_REGION" in os.environ: + del os.environ["AWS_DEFAULT_REGION"] diff --git a/tests/integ/sagemaker/modules/train/test_local_model_trainer.py b/tests/integ/sagemaker/modules/train/test_local_model_trainer.py new file mode 100644 index 0000000000..adb5f85f3e --- /dev/null +++ b/tests/integ/sagemaker/modules/train/test_local_model_trainer.py @@ -0,0 +1,225 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License.
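The local-mode tests that follow drive ModelTrainer end to end through Docker Compose on the test host; the teardown boilerplate in each test exists because the containers write root-owned files into the working directory. The core pattern, reduced to a sketch (the image URI placeholder stands in for the pinned CPU training image, and the source paths are illustrative):

    from sagemaker.modules.configs import Compute, SourceCode
    from sagemaker.modules.train.model_trainer import Mode, ModelTrainer

    trainer = ModelTrainer(
        training_image="<cpu-training-image-uri>",  # assumed: an image this host can pull
        source_code=SourceCode(source_dir="my_code", entry_script="train.py"),
        compute=Compute(instance_type="local_cpu", instance_count=1),
        training_mode=Mode.LOCAL_CONTAINER,
    )
    trainer.train()  # on success, artifacts land under ./compressed_artifacts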
+"""This module contains code to test image builder with local mode""" +from __future__ import absolute_import +import os +import errno + +import shutil +import tempfile + +from tests.integ import DATA_DIR +import tests.integ.lock as lock + +from sagemaker.modules.configs import Compute, InputData, SourceCode +from sagemaker.modules.distributed import Torchrun +from sagemaker.modules.train.model_trainer import Mode, ModelTrainer +import subprocess + +DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310" +CWD = os.getcwd() +SOURCE_DIR = os.path.join(DATA_DIR, "modules/local_script") +LOCK_PATH = os.path.join(tempfile.gettempdir(), "sagemaker_test_local_mode_lock") + + +def delete_local_path(path): + try: + if os.path.exists(path) and os.path.isdir(path): + shutil.rmtree(path) + print(f"Removed directory: {path}") + else: + print(f"Directory does not exist: {path}") + except OSError as exc: + # on Linux, when docker writes to any mounted volume, it uses the container's user. In most + # cases this is root. When the container exits and we try to delete them we can't because + # root owns those files. We expect this to happen, so we handle EACCESS. Any other error + # we will raise the exception up. + if exc.errno == errno.EACCES: + print(f"Failed to delete: {path} Please remove it manually.") + else: + print(f"Failed to delete: {path}") + raise + + +def test_single_container_local_mode_local_data(modules_sagemaker_session): + with lock.lock(LOCK_PATH): + try: + source_code = SourceCode( + source_dir=SOURCE_DIR, + entry_script="local_training_script.py", + ) + + compute = Compute( + instance_type="local_cpu", + instance_count=1, + ) + + train_data = InputData( + channel_name="train", + data_source=os.path.join(SOURCE_DIR, "data/train/"), + ) + + test_data = InputData( + channel_name="test", + data_source=os.path.join(SOURCE_DIR, "data/test/"), + ) + + model_trainer = ModelTrainer( + training_image=DEFAULT_CPU_IMAGE, + sagemaker_session=modules_sagemaker_session, + source_code=source_code, + compute=compute, + input_data_config=[train_data, test_data], + base_job_name="local_mode_single_container_local_data", + training_mode=Mode.LOCAL_CONTAINER, + ) + + model_trainer.train() + assert os.path.exists(os.path.join(CWD, "compressed_artifacts/model.tar.gz")) + finally: + subprocess.run(["docker", "compose", "down", "-v"]) + directories = [ + "compressed_artifacts", + "artifacts", + "model", + "shared", + "input", + "output", + "algo-1", + ] + + for directory in directories: + path = os.path.join(CWD, directory) + delete_local_path(path) + + +def test_single_container_local_mode_s3_data(modules_sagemaker_session): + with lock.lock(LOCK_PATH): + try: + # upload local data to s3 + session = modules_sagemaker_session + bucket = session.default_bucket() + session.upload_data( + path=os.path.join(SOURCE_DIR, "data/train/"), + bucket=bucket, + key_prefix="data/train", + ) + session.upload_data( + path=os.path.join(SOURCE_DIR, "data/test/"), + bucket=bucket, + key_prefix="data/test", + ) + + source_code = SourceCode( + source_dir=SOURCE_DIR, + entry_script="local_training_script.py", + ) + + compute = Compute( + instance_type="local_cpu", + instance_count=1, + ) + + # read input data from s3 + train_data = InputData(channel_name="train", data_source=f"s3://{bucket}/data/train/") + + test_data = InputData(channel_name="test", data_source=f"s3://{bucket}/data/test/") + + model_trainer = ModelTrainer( + training_image=DEFAULT_CPU_IMAGE, + 
sagemaker_session=modules_sagemaker_session, + source_code=source_code, + compute=compute, + input_data_config=[train_data, test_data], + base_job_name="local_mode_single_container_s3_data", + training_mode=Mode.LOCAL_CONTAINER, + ) + + model_trainer.train() + assert os.path.exists(os.path.join(CWD, "compressed_artifacts/model.tar.gz")) + finally: + subprocess.run(["docker", "compose", "down", "-v"]) + directories = [ + "compressed_artifacts", + "artifacts", + "model", + "shared", + "input", + "output", + "algo-1", + ] + + for directory in directories: + path = os.path.join(CWD, directory) + delete_local_path(path) + + +def test_multi_container_local_mode(modules_sagemaker_session): + with lock.lock(LOCK_PATH): + try: + source_code = SourceCode( + source_dir=SOURCE_DIR, + entry_script="local_training_script.py", + ) + + distributed = Torchrun( + process_count_per_node=1, + ) + + compute = Compute( + instance_type="local_cpu", + instance_count=2, + ) + + train_data = InputData( + channel_name="train", + data_source=os.path.join(SOURCE_DIR, "data/train/"), + ) + + test_data = InputData( + channel_name="test", + data_source=os.path.join(SOURCE_DIR, "data/test/"), + ) + + model_trainer = ModelTrainer( + training_image=DEFAULT_CPU_IMAGE, + sagemaker_session=modules_sagemaker_session, + source_code=source_code, + distributed=distributed, + compute=compute, + input_data_config=[train_data, test_data], + base_job_name="local_mode_multi_container", + training_mode=Mode.LOCAL_CONTAINER, + ) + + model_trainer.train() + assert os.path.exists(os.path.join(CWD, "compressed_artifacts/model.tar.gz")) + assert os.path.exists(os.path.join(CWD, "algo-1")) + assert os.path.exists(os.path.join(CWD, "algo-2")) + + finally: + subprocess.run(["docker", "compose", "down", "-v"]) + directories = [ + "compressed_artifacts", + "artifacts", + "model", + "shared", + "input", + "output", + "algo-1", + "algo-2", + ] + + for directory in directories: + path = os.path.join(CWD, directory) + delete_local_path(path) diff --git a/tests/integ/sagemaker/modules/train/test_model_trainer.py b/tests/integ/sagemaker/modules/train/test_model_trainer.py new file mode 100644 index 0000000000..cd298402b2 --- /dev/null +++ b/tests/integ/sagemaker/modules/train/test_model_trainer.py @@ -0,0 +1,108 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
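The hyperparameter-contract tests that follow reuse a single params_script entry point; the only moving part is the distributed runner handed to ModelTrainer. A condensed sketch of that variation, using the no-argument constructors the tests rely on:

    from sagemaker.modules.distributed import MPI, Torchrun

    # same source_code, compute, and hyperparameters throughout; omitting
    # `distributed` runs the script directly, MPI() launches it via mpirun,
    # and Torchrun() launches it via torchrun
    mpi_runner = MPI()
    torchrun_runner = Torchrun()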
+"""This module contains code to test image builder""" +from __future__ import absolute_import + +from tests.integ import DATA_DIR + +from sagemaker.modules.train import ModelTrainer +from sagemaker.modules.configs import SourceCode, Compute +from sagemaker.modules.distributed import MPI, Torchrun + +EXPECTED_HYPERPARAMETERS = { + "integer": 1, + "boolean": True, + "float": 3.14, + "string": "Hello World", + "list": [1, 2, 3], + "dict": { + "string": "value", + "integer": 3, + "list": [1, 2, 3], + "dict": {"key": "value"}, + "boolean": True, + }, +} + +DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310" + + +def test_hp_contract_basic_py_script(modules_sagemaker_session): + source_code = SourceCode( + source_dir=f"{DATA_DIR}/modules/params_script", + entry_script="train.py", + ) + + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + hyperparameters=EXPECTED_HYPERPARAMETERS, + source_code=source_code, + base_job_name="hp-contract-basic-py-script", + ) + + model_trainer.train() + + +def test_hp_contract_basic_sh_script(modules_sagemaker_session): + source_code = SourceCode( + source_dir=f"{DATA_DIR}/modules/params_script", + entry_script="train.sh", + ) + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + hyperparameters=EXPECTED_HYPERPARAMETERS, + source_code=source_code, + base_job_name="hp-contract-basic-sh-script", + ) + + model_trainer.train() + + +def test_hp_contract_mpi_script(modules_sagemaker_session): + source_code = SourceCode( + source_dir=f"{DATA_DIR}/modules/params_script", + entry_script="train.py", + ) + compute = Compute(instance_type="ml.m5.xlarge", instance_count=2) + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + compute=compute, + hyperparameters=EXPECTED_HYPERPARAMETERS, + source_code=source_code, + distributed=MPI(), + base_job_name="hp-contract-mpi-script", + ) + + model_trainer.train() + + +def test_hp_contract_torchrun_script(modules_sagemaker_session): + source_code = SourceCode( + source_dir=f"{DATA_DIR}/modules/params_script", + entry_script="train.py", + ) + compute = Compute(instance_type="ml.m5.xlarge", instance_count=2) + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + compute=compute, + hyperparameters=EXPECTED_HYPERPARAMETERS, + source_code=source_code, + distributed=Torchrun(), + base_job_name="hp-contract-torchrun-script", + ) + + model_trainer.train() diff --git a/tests/integ/sagemaker/serve/conftest.py b/tests/integ/sagemaker/serve/conftest.py index a1086afea7..5eb3a2ea11 100644 --- a/tests/integ/sagemaker/serve/conftest.py +++ b/tests/integ/sagemaker/serve/conftest.py @@ -10,64 +10,48 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
-# from __future__ import absolute_import +from __future__ import absolute_import -# import os -# import pytest -# import platform -# import collections -# from numpy import loadtxt -# from sagemaker.serve.spec.inference_spec import InferenceSpec +import pytest +import os +import boto3 +import sagemaker +import sagemaker_core.helper.session_helper as core_session -# if platform.python_version_tuple()[1] == "8": -# from xgboost import XGBClassifier -# from sklearn.model_selection import train_test_split +DEFAULT_REGION = "us-west-2" -# from tests.integ.sagemaker.serve.constants import XGB_RESOURCE_DIR +@pytest.fixture(scope="module") +def mb_sagemaker_session(): + region = os.environ.get("AWS_DEFAULT_REGION") + if not region: + os.environ["AWS_DEFAULT_REGION"] = DEFAULT_REGION + region_manual_set = True + else: + region_manual_set = False -# XgbTestSplit = collections.namedtuple("XgbTrainTestSplit", "x_test y_test") + boto_session = boto3.Session(region_name=os.environ["AWS_DEFAULT_REGION"]) + sagemaker_session = sagemaker.Session(boto_session=boto_session) + yield sagemaker_session -# @pytest.fixture(scope="session") -# def loaded_xgb_model(): -# model = XGBClassifier() -# model.load_model(XGB_RESOURCE_DIR + "/model.xgb") -# return model + if region_manual_set and "AWS_DEFAULT_REGION" in os.environ: + del os.environ["AWS_DEFAULT_REGION"] -# @pytest.fixture(scope="session") -# def xgb_inference_spec(): -# class MyXGBoostModel(InferenceSpec): -# def load(self, model_dir: str): -# model = XGBClassifier() -# model.load_model(model_dir + "/model.xgb") -# return model +@pytest.fixture(scope="module") +def mb_sagemaker_core_session(): + region = os.environ.get("AWS_DEFAULT_REGION") + if not region: + os.environ["AWS_DEFAULT_REGION"] = DEFAULT_REGION + region_manual_set = True + else: + region_manual_set = False -# def invoke( -# self, -# input: object, -# model: object, -# ): -# y_pred = model.predict(input) -# predictions = [round(value) for value in y_pred] -# return predictions + boto_session = boto3.Session(region_name=os.environ["AWS_DEFAULT_REGION"]) + sagemaker_session = core_session.Session(boto_session=boto_session) -# return MyXGBoostModel() + yield sagemaker_session - -# @pytest.fixture(scope="session") -# def xgb_test_sets(): -# dataset = loadtxt( -# os.path.join(XGB_RESOURCE_DIR, "classification_training_data.data.csv"), delimiter="," -# ) - -# X = dataset[:, 0:8] -# Y = dataset[:, 8] - -# seed = 7 -# test_size = 0.33 - -# _, x_test, _, y_test = train_test_split(X, Y, test_size=test_size, random_state=seed) - -# return XgbTestSplit(x_test, y_test) + if region_manual_set and "AWS_DEFAULT_REGION" in os.environ: + del os.environ["AWS_DEFAULT_REGION"] diff --git a/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py b/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py new file mode 100644 index 0000000000..10f338c4b5 --- /dev/null +++ b/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py @@ -0,0 +1,193 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied.
See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest + +from sagemaker import get_execution_role +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split + +import os + +from sagemaker_core.main.shapes import ( + AlgorithmSpecification, + Channel, + DataSource, + S3DataSource, + OutputDataConfig, + ResourceConfig, + StoppingCondition, +) +import uuid +from sagemaker.serve.builder.model_builder import ModelBuilder +import pandas as pd +import numpy as np +from sagemaker.serve import InferenceSpec, SchemaBuilder +from sagemaker_core.main.resources import TrainingJob +from xgboost import XGBClassifier + +from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig + +from sagemaker.s3_utils import s3_path_join +from sagemaker.async_inference import AsyncInferenceConfig +from tests.integ.utils import cleanup_model_resources + + +@pytest.fixture(scope="module") +def xgboost_model_builder(mb_sagemaker_session): + sagemaker_session = mb_sagemaker_session + role = get_execution_role(sagemaker_session=sagemaker_session) + bucket = sagemaker_session.default_bucket() + + # Get IRIS Data + iris = load_iris() + iris_df = pd.DataFrame(iris.data, columns=iris.feature_names) + iris_df["target"] = iris.target + + # Prepare Data + os.makedirs("data", exist_ok=True) + + iris_df = iris_df[["target"] + [col for col in iris_df.columns if col != "target"]] + + train_data, test_data = train_test_split(iris_df, test_size=0.2, random_state=42) + + train_data.to_csv("data/train.csv", index=False, header=False) + test_data.to_csv("data/test.csv", index=False, header=False) + + # Remove the target column from the testing data. 
We will use this to call invoke_endpoint later. + test_data = test_data.drop("target", axis=1) + + prefix = "DEMO-scikit-iris" + TRAIN_DATA = "train.csv" + DATA_DIRECTORY = "data" + + sagemaker_session.upload_data( + DATA_DIRECTORY, bucket=bucket, key_prefix="{}/{}".format(prefix, DATA_DIRECTORY) + ) + + s3_input_path = "s3://{}/{}/data/{}".format(bucket, prefix, TRAIN_DATA) + s3_output_path = "s3://{}/{}/output".format(bucket, prefix) + + print(s3_input_path) + print(s3_output_path) + + image = "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:1" + + class XGBoostSpec(InferenceSpec): + def load(self, model_dir: str): + print(model_dir) + model = XGBClassifier() + model.load_model(model_dir + "/xgboost-model") + return model + + def invoke(self, input_object: object, model: object): + prediction_probabilities = model.predict_proba(input_object) + predictions = np.argmax(prediction_probabilities, axis=1) + return predictions + + data = {"Name": ["Alice", "Bob", "Charlie"]} + df = pd.DataFrame(data) + training_job_name = str(uuid.uuid4()) + schema_builder = SchemaBuilder(sample_input=df, sample_output=df) + + training_job = TrainingJob.create( + training_job_name=training_job_name, + hyper_parameters={ + "objective": "multi:softmax", + "num_class": "3", + "num_round": "10", + "eval_metric": "merror", + }, + algorithm_specification=AlgorithmSpecification( + training_image=image, training_input_mode="File" + ), + role_arn=role, + input_data_config=[ + Channel( + channel_name="train", + content_type="csv", + compression_type="None", + record_wrapper_type="None", + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=s3_input_path, + s3_data_distribution_type="FullyReplicated", + ) + ), + ) + ], + output_data_config=OutputDataConfig(s3_output_path=s3_output_path), + resource_config=ResourceConfig( + instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=30 + ), + stopping_condition=StoppingCondition(max_runtime_in_seconds=600), + ) + training_job.wait() + + xgboost_model_builder = ModelBuilder( + name="ModelBuilderTest", + model_path=training_job.model_artifacts.s3_model_artifacts, + role_arn=role, + inference_spec=XGBoostSpec(), + image_uri=image, + schema_builder=schema_builder, + instance_type="ml.c6i.xlarge", + ) + xgboost_model_builder.build() + return xgboost_model_builder + + +def test_real_time_deployment(xgboost_model_builder): + real_time_predictor = xgboost_model_builder.deploy( + endpoint_name="test", initial_instance_count=1 + ) + + assert real_time_predictor is not None + cleanup_model_resources( + sagemaker_session=xgboost_model_builder.sagemaker_session, + model_name=xgboost_model_builder.built_model.name, + endpoint_name=xgboost_model_builder.built_model.endpoint_name, + ) + + +def test_serverless_deployment(xgboost_model_builder): + serverless_predictor = xgboost_model_builder.deploy( + endpoint_name="test1", inference_config=ServerlessInferenceConfig() + ) + + assert serverless_predictor is not None + cleanup_model_resources( + sagemaker_session=xgboost_model_builder.sagemaker_session, + model_name=xgboost_model_builder.built_model.name, + endpoint_name=xgboost_model_builder.built_model.endpoint_name, + ) + + +def test_async_deployment(xgboost_model_builder, mb_sagemaker_session): + async_predictor = xgboost_model_builder.deploy( + endpoint_name="test2", + inference_config=AsyncInferenceConfig( + output_path=s3_path_join( + "s3://", mb_sagemaker_session.default_bucket(), "async_inference/output" + ) + ), + ) + + assert async_predictor is
not None + cleanup_model_resources( + sagemaker_session=xgboost_model_builder.sagemaker_session, + model_name=xgboost_model_builder.built_model.name, + endpoint_name=xgboost_model_builder.built_model.endpoint_name, + ) diff --git a/tests/integ/sagemaker/serve/test_schema_builder.py b/tests/integ/sagemaker/serve/test_schema_builder.py index a0c1673ae8..1a2bbe2355 100644 --- a/tests/integ/sagemaker/serve/test_schema_builder.py +++ b/tests/integ/sagemaker/serve/test_schema_builder.py @@ -33,7 +33,9 @@ def test_model_builder_happy_path_with_only_model_id_text_generation(sagemaker_session): - model_builder = ModelBuilder(model="HuggingFaceH4/zephyr-7b-beta") + model_builder = ModelBuilder( + model="HuggingFaceH4/zephyr-7b-beta", sagemaker_session=sagemaker_session + ) model = model_builder.build(sagemaker_session=sagemaker_session) @@ -47,7 +49,9 @@ def test_model_builder_happy_path_with_only_model_id_text_generation(sagemaker_s def test_model_builder_negative_path(sagemaker_session): # A model-task combo unsupported by both the local and remote schema fallback options. (eg: text-to-video) - model_builder = ModelBuilder(model="ByteDance/AnimateDiff-Lightning") + model_builder = ModelBuilder( + model="ByteDance/AnimateDiff-Lightning", sagemaker_session=sagemaker_session + ) with pytest.raises( TaskNotFoundException, match="Error Message: HuggingFace Schema builder samples for text-to-video could not be found locally or " @@ -86,6 +90,7 @@ def test_model_builder_happy_path_with_task_provided_local_schema_mode( model=model_id, model_metadata={"HF_TASK": task_provided}, instance_type=instance_type_provided, + sagemaker_session=sagemaker_session, ) model = model_builder.build(sagemaker_session=sagemaker_session) @@ -111,13 +116,13 @@ def test_model_builder_happy_path_with_task_provided_local_schema_mode( if container_startup_timeout: predictor = model.deploy( role=role_arn, - instance_count=1, + initial_instance_count=1, instance_type=instance_type_provided, container_startup_health_check_timeout=container_startup_timeout, ) else: predictor = model.deploy( - role=role_arn, instance_count=1, instance_type=instance_type_provided + role=role_arn, initial_instance_count=1, instance_type=instance_type_provided ) predicted_outputs = predictor.predict(inputs) @@ -162,6 +167,7 @@ def test_model_builder_happy_path_with_task_provided_remote_schema_mode( model=model_id, model_metadata={"HF_TASK": task_provided}, instance_type=instance_type_provided, + sagemaker_session=sagemaker_session, ) model = model_builder.build(sagemaker_session=sagemaker_session) @@ -181,7 +187,7 @@ def test_model_builder_happy_path_with_task_provided_remote_schema_mode( logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") predictor = model.deploy( - role=role_arn, instance_count=1, instance_type=instance_type_provided + role=role_arn, initial_instance_count=1, instance_type=instance_type_provided ) predicted_outputs = predictor.predict(inputs) @@ -217,6 +223,7 @@ def test_model_builder_with_task_provided_remote_schema_mode_asr( model=model_id, model_metadata={"HF_TASK": task_provided}, instance_type=instance_type_provided, + sagemaker_session=sagemaker_session, ) model = model_builder.build(sagemaker_session=sagemaker_session) @@ -231,7 +238,9 @@ def test_model_builder_with_task_provided_remote_schema_mode_asr( def test_model_builder_negative_path_with_invalid_task(sagemaker_session): model_builder = ModelBuilder( - model="bert-base-uncased", model_metadata={"HF_TASK": "invalid-task"} + model="bert-base-uncased", + 
model_metadata={"HF_TASK": "invalid-task"}, + sagemaker_session=sagemaker_session, ) with pytest.raises( diff --git a/tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py b/tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py index 933c18bacf..8724fc5116 100644 --- a/tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py +++ b/tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py @@ -71,9 +71,12 @@ def model_input(): @pytest.fixture -def model_builder_model_schema_builder(): +def model_builder_model_schema_builder(sagemaker_session): return ModelBuilder( - model_path=HF_DIR, model=model_id, schema_builder=SchemaBuilder(sample_input, sample_output) + sagemaker_session=sagemaker_session, + model_path=HF_DIR, + model=model_id, + schema_builder=SchemaBuilder(sample_input, sample_output), ) diff --git a/tests/integ/sagemaker/serve/test_serve_model_builder_handshake.py b/tests/integ/sagemaker/serve/test_serve_model_builder_handshake.py new file mode 100644 index 0000000000..d024e761a8 --- /dev/null +++ b/tests/integ/sagemaker/serve/test_serve_model_builder_handshake.py @@ -0,0 +1,208 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import os +import uuid + +import numpy as np +import pandas as pd +from sagemaker_core.main.resources import TrainingJob +from xgboost import XGBClassifier + +from sagemaker.serve import ModelBuilder, SchemaBuilder +from sagemaker.serve.spec.inference_spec import InferenceSpec +from sagemaker_core.main.shapes import ( + OutputDataConfig, + StoppingCondition, + Channel, + DataSource, + S3DataSource, + AlgorithmSpecification, + ResourceConfig, +) +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split + +from sagemaker import get_execution_role, image_uris +from sagemaker.modules.train import ModelTrainer + +prefix = "DEMO-scikit-iris" +TRAIN_DATA = "train.csv" +TEST_DATA = "test.csv" +DATA_DIRECTORY = "data" + + +class XGBoostSpec(InferenceSpec): + def load(self, model_dir: str): + print(model_dir) + model = XGBClassifier() + model.load_model(model_dir + "/xgboost-model") + return model + + def invoke(self, input_object: object, model: object): + prediction_probabilities = model.predict_proba(input_object) + predictions = np.argmax(prediction_probabilities, axis=1) + return predictions + + +@pytest.fixture(scope="module") +def data_setup(mb_sagemaker_session): + sagemaker_session = mb_sagemaker_session + bucket = sagemaker_session.default_bucket() + + iris = load_iris() + iris_df = pd.DataFrame(iris.data, columns=iris.feature_names) + iris_df["target"] = iris.target + + os.makedirs("./data", exist_ok=True) + + iris_df = iris_df[["target"] + [col for col in iris_df.columns if col != "target"]] + + train_data, test_data = train_test_split(iris_df, test_size=0.2, random_state=42) + + train_data.to_csv("./data/train.csv", index=False, header=False) + test_data.to_csv("./data/test.csv", index=False, header=False) + + data = {"Name": ["Alice", 
"Bob", "Charlie"]} + df = pd.DataFrame(data) + schema_builder = SchemaBuilder(sample_input=df, sample_output=df) + + # Remove the target column from the testing data. We will use this to call invoke_endpoint later + test_data.drop("target", axis=1) + + sagemaker_session.upload_data( + DATA_DIRECTORY, bucket=bucket, key_prefix="{}/{}".format(prefix, DATA_DIRECTORY) + ) + + s3_input_path = "s3://{}/{}/data/{}".format(bucket, prefix, TRAIN_DATA) + s3_output_path = "s3://{}/{}/output".format(bucket, prefix) + + data_setup = { + "s3_input_path": s3_input_path, + "s3_output_path": s3_output_path, + "schema_builder": schema_builder, + } + return data_setup + + +def test_model_trainer_handshake(mb_sagemaker_session, mb_sagemaker_core_session, data_setup): + sagemaker_session = mb_sagemaker_session + role = get_execution_role(sagemaker_session=sagemaker_session) + xgboost_image = image_uris.retrieve( + framework="xgboost", region="us-west-2", image_scope="training" + ) + + model_trainer = ModelTrainer( + sagemaker_session=mb_sagemaker_core_session, + base_job_name="test-mb-handshake", + hyperparameters={ + "objective": "multi:softmax", + "num_class": "3", + "num_round": "10", + "eval_metric": "merror", + }, + training_image=xgboost_image, + training_input_mode="File", + role=role, + output_data_config=OutputDataConfig(s3_output_path=data_setup["s3_output_path"]), + stopping_condition=StoppingCondition(max_runtime_in_seconds=600), + ) + + model_trainer.train( + input_data_config=[ + Channel( + channel_name="train", + content_type="csv", + compression_type="None", + record_wrapper_type="None", + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=data_setup["s3_input_path"], + s3_data_distribution_type="FullyReplicated", + ) + ), + ) + ] + ) + + model_builder = ModelBuilder( + model=model_trainer, # ModelTrainer object passed onto ModelBuilder directly + sagemaker_session=sagemaker_session, + role_arn=role, + image_uri=xgboost_image, + inference_spec=XGBoostSpec(), + schema_builder=data_setup["schema_builder"], + instance_type="ml.c6i.xlarge", + ) + model = model_builder.build() + assert model.model_data == model_trainer._latest_training_job.model_artifacts.s3_model_artifacts + + +def test_sagemaker_core_handshake(mb_sagemaker_session, data_setup): + sagemaker_session = mb_sagemaker_session + role = get_execution_role(sagemaker_session=sagemaker_session) + xgboost_image = image_uris.retrieve( + framework="xgboost", region="us-west-2", image_scope="training" + ) + + training_job_name = str(uuid.uuid4()) + training_job = TrainingJob.create( + training_job_name=training_job_name, + hyper_parameters={ + "objective": "multi:softmax", + "num_class": "3", + "num_round": "10", + "eval_metric": "merror", + }, + algorithm_specification=AlgorithmSpecification( + training_image=xgboost_image, training_input_mode="File" + ), + role_arn=role, + input_data_config=[ + Channel( + channel_name="train", + content_type="csv", + compression_type="None", + record_wrapper_type="None", + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=data_setup["s3_input_path"], + s3_data_distribution_type="FullyReplicated", + ) + ), + ) + ], + output_data_config=OutputDataConfig(s3_output_path=data_setup["s3_output_path"]), + resource_config=ResourceConfig( + instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=30 + ), + stopping_condition=StoppingCondition(max_runtime_in_seconds=600), + ) + training_job.wait() + + model_builder = ModelBuilder( 
+ sagemaker_session=sagemaker_session, + model=training_job, + role_arn=role, + inference_spec=XGBoostSpec(), + image_uri=xgboost_image, + schema_builder=data_setup["schema_builder"], + instance_type="ml.c6i.xlarge", + ) + model = model_builder.build() + + assert model.model_data == training_job.model_artifacts.s3_model_artifacts diff --git a/tests/integ/sagemaker/serve/test_serve_tei.py b/tests/integ/sagemaker/serve/test_serve_tei.py index 5cf1a3635c..4c824da401 100644 --- a/tests/integ/sagemaker/serve/test_serve_tei.py +++ b/tests/integ/sagemaker/serve/test_serve_tei.py @@ -39,11 +39,16 @@ def model_input(): @pytest.fixture -def model_builder_model_schema_builder(): +def model_builder_model_schema_builder(sagemaker_session): return ModelBuilder( + sagemaker_session=sagemaker_session, model_path=HF_DIR, model="BAAI/bge-m3", schema_builder=SchemaBuilder(sample_input, loaded_response), + env_vars={ + # Add this to bypass JumpStart model mapping + "HF_MODEL_ID": "BAAI/bge-m3" + }, ) diff --git a/tests/integ/sagemaker/serve/test_serve_transformers.py b/tests/integ/sagemaker/serve/test_serve_transformers.py index 33a1ae6708..5f172f3edb 100644 --- a/tests/integ/sagemaker/serve/test_serve_transformers.py +++ b/tests/integ/sagemaker/serve/test_serve_transformers.py @@ -72,11 +72,12 @@ def model_input(): @pytest.fixture -def model_builder_model_schema_builder(): +def model_builder_model_schema_builder(sagemaker_session): return ModelBuilder( model_path=HF_DIR, model="bert-base-uncased", schema_builder=SchemaBuilder(sample_input, loaded_response), + sagemaker_session=sagemaker_session, ) diff --git a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py index 3ad641a321..0eee116e5d 100644 --- a/tests/unit/sagemaker/huggingface/test_estimator.py +++ b/tests/unit/sagemaker/huggingface/test_estimator.py @@ -241,7 +241,7 @@ def test_huggingface( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( huggingface_training_version, f"pytorch{huggingface_pytorch_training_version}" diff --git a/tests/unit/sagemaker/image_uris/test_retrieve.py b/tests/unit/sagemaker/image_uris/test_retrieve.py index fd0bcbd150..360587677f 100644 --- a/tests/unit/sagemaker/image_uris/test_retrieve.py +++ b/tests/unit/sagemaker/image_uris/test_retrieve.py @@ -221,7 +221,6 @@ def test_retrieve_default_version_if_possible(config_for_framework, caplog): image_scope="training", ) assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu-py3" == uri - assert "Ignoring framework/algorithm version: invalid-version." in caplog.text @patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG) @@ -239,18 +238,6 @@ def test_retrieve_unsupported_version(config_for_framework): assert "Unsupported some-framework version: 1." in str(e.value) assert "Supported some-framework version(s): 1.0.0, 1.1.0." in str(e.value) - with pytest.raises(ValueError) as e: - image_uris.retrieve( - framework="some-framework", - py_version="py3", - instance_type="ml.c4.xlarge", - region="us-west-2", - image_scope="training", - ) - - assert "Unsupported some-framework version: None." in str(e.value) - assert "Supported some-framework version(s): 1.0.0, 1.1.0." 
in str(e.value) - @patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG) def test_retrieve_unsupported_region(config_for_framework): @@ -780,3 +767,105 @@ def test_retrieve_with_pipeline_variable(): ), image_scope="training", ) + + +@patch("sagemaker.image_uris.config_for_framework") +def test_get_latest_version_function_with_invalid_framework(config_for_framework): + config_for_framework.side_effect = FileNotFoundError + + with pytest.raises(Exception) as e: + image_uris.retrieve("xgboost", "inference") + assert "No framework config for framework" in str(e.value) + + +@patch("sagemaker.image_uris.config_for_framework") +def test_get_latest_version_function_with_no_framework(config_for_framework): + config_for_framework.side_effect = {} + + with pytest.raises(Exception) as e: + image_uris.retrieve("xgboost", "inference") + assert "No framework config for framework" in str(e.value) + + +@pytest.mark.parametrize( + "framework", + [ + "object-detection", + "instance_gpu_info", + "object2vec", + "pytorch", + "djl-lmi", + "mxnet", + "debugger", + "data-wrangler", + "spark", + "blazingtext", + "pytorch-neuron", + "forecasting-deepar", + "huggingface-neuron", + "ntm", + "neo-mxnet", + "image-classification", + "xgboost", + "autogluon", + "sparkml-serving", + "clarify", + "inferentia-pytorch", + "neo-tensorflow", + "huggingface-tei-cpu", + "huggingface", + "sagemaker-tritonserver", + "pytorch-smp", + "knn", + "linear-learner", + "model-monitor", + "ray-tensorflow", + "djl-neuronx", + "huggingface-llm-neuronx", + "image-classification-neo", + "lda", + "stabilityai", + "ray-pytorch", + "chainer", + "coach-mxnet", + "pca", + "sagemaker-geospatial", + "djl-tensorrtllm", + "huggingface-training-compiler", + "pytorch-training-compiler", + "vw", + "huggingface-neuronx", + "ipinsights", + "detailed-profiler", + "inferentia-tensorflow", + "semantic-segmentation", + "inferentia-mxnet", + "xgboost-neo", + "neo-pytorch", + "djl-deepspeed", + "djl-fastertransformer", + "sklearn", + "tensorflow", + "randomcutforest", + "huggingface-llm", + "factorization-machines", + "huggingface-tei", + "coach-tensorflow", + "seq2seq", + "kmeans", + "sagemaker-base-python", + ], +) +@patch("sagemaker.image_uris.config_for_framework") +@patch("sagemaker.image_uris.retrieve") +def test_retrieve_with_parameterized(mock_image_retrieve, mock_config_for_framework, framework): + try: + image_uris.retrieve( + framework=framework, + region="us-east-1", + version=None, + instance_type="ml.c4.xlarge", + image_scope="inference", + ) + except ValueError as e: + pytest.fail(str(e)) diff --git a/tests/unit/sagemaker/modules/__init__.py b/tests/unit/sagemaker/modules/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/modules/local_core/test_local_container.py b/tests/unit/sagemaker/modules/local_core/test_local_container.py new file mode 100644 index 0000000000..88f6f81707 --- /dev/null +++ b/tests/unit/sagemaker/modules/local_core/test_local_container.py @@ -0,0 +1,179 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied.
See the License for the specific +# language governing permissions and limitations under the License. +"""LocalContainer Tests.""" +from __future__ import absolute_import +import os +import shutil + +import pytest + +from sagemaker.modules.configs import Channel, FileSystemDataSource +from sagemaker.modules.local_core.local_container import DOCKER_COMPOSE_FILENAME, _LocalContainer +from sagemaker_core.shapes import DataSource + +TRAINING_JOB_NAME = "job_name" +INSTANCE_TYPE = "ml.m5.xlarge" +TEST_IMAGE_NAME = "test_image" +CONTAINER_ROOT = os.getcwd() +CONTAINER_ENTRYPOINT = ["/bin/bash"] +CONTAINER_ARGUMENTS = [ + "-c", + ( + "chmod +x /opt/ml/input/data/sm_drivers/sm_train.sh " + + "&& /opt/ml/input/data/sm_drivers/sm_train.sh" + ), +] + + +@pytest.fixture +def input_data_config(): + return [ + Channel( + channel_name="local_input_channel", + data_source=DataSource( + file_system_data_source=FileSystemDataSource.model_construct( + directory_path=CONTAINER_ROOT, + file_system_type="EFS", + ), + ), + input_mode="File", + ) + ] + + +@pytest.fixture +def hyper_parameters(): + return { + "epochs": "1", + "optimizer": "adamw_torch", + } + + +@pytest.fixture +def shared_volumes(): + return [ + f"{CONTAINER_ROOT}/model:/opt/ml/model", + f"{CONTAINER_ROOT}:/opt/ml/input/data/local_input_channel", + ] + + +@pytest.fixture +def environment(): + return { + "SM_OUTPUT_DIR": "/opt/ml/output", + "SM_INPUT_CONFIG_DIR": "/opt/ml/input/config", + "SM_OUTPUT_DATA_DIR": "/opt/ml/output/data", + } + + +@pytest.fixture +def local_container(input_data_config, hyper_parameters, environment): + container = _LocalContainer( + training_job_name=TRAINING_JOB_NAME, + instance_type=INSTANCE_TYPE, + instance_count=2, + image=TEST_IMAGE_NAME, + container_root=CONTAINER_ROOT, + is_studio=False, + input_data_config=input_data_config, + hyper_parameters=hyper_parameters, + environment=environment, + sagemaker_session=None, + container_entrypoint=CONTAINER_ENTRYPOINT, + container_arguments=CONTAINER_ARGUMENTS, + ) + return container + + +def expected_host_config(shared_volumes, host): + return { + "entrypoint": [ + "/bin/bash", + "-c", + "chmod +x /opt/ml/input/data/sm_drivers/sm_train.sh && " + "/opt/ml/input/data/sm_drivers/sm_train.sh", + ], + "environment": [ + "SM_OUTPUT_DIR=/opt/ml/output", + "SM_INPUT_CONFIG_DIR=/opt/ml/input/config", + "SM_OUTPUT_DATA_DIR=/opt/ml/output/data", + ], + "image": "test_image", + "networks": { + "sagemaker-local": { + "aliases": [ + host, + ], + }, + }, + "volumes": shared_volumes + + [ + f"{CONTAINER_ROOT}/{host}/output:/opt/ml/output", + f"{CONTAINER_ROOT}/{host}/output/data:/opt/ml/output/data", + f"{CONTAINER_ROOT}/{host}/input:/opt/ml/input", + ], + } + + +def expected_compose_file(shared_volumes, hosts): + return { + "networks": { + "sagemaker-local": { + "name": "sagemaker-local", + }, + }, + "services": {host: expected_host_config(shared_volumes, host) for host in hosts}, + } + + +def test_write_config_files(local_container, input_data_config, hyper_parameters): + config_path = os.path.join(local_container.container_root, "algo-1", "input", "config") + os.makedirs(config_path, exist_ok=True) + local_container._write_config_files( + host="algo-1", + input_data_config=input_data_config, + hyper_parameters=hyper_parameters, + ) + + assert os.path.exists(os.path.join(config_path, "hyperparameters.json")) + assert os.path.exists(os.path.join(config_path, "resourceconfig.json")) + assert os.path.exists(os.path.join(config_path, "inputdataconfig.json")) + + shutil.rmtree(config_path) + 
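The compose dictionary built by expected_compose_file above is what _generate_compose_file is expected to serialize to DOCKER_COMPOSE_FILENAME. A sketch of that rendering step, with PyYAML usage assumed rather than taken from this change:

    import yaml  # assumption: PyYAML is available in the test environment

    def render_compose(compose_dict: dict, path: str) -> None:
        # one service per training host, all attached to the sagemaker-local network
        with open(path, "w") as f:
            yaml.safe_dump(compose_dict, f, default_flow_style=False)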
+ +def test_prepare_training_volumes( + local_container, input_data_config, hyper_parameters, shared_volumes +): + data_dir = os.path.join(local_container.container_root, "input", "data") + output = local_container._prepare_training_volumes( + data_dir, input_data_config, hyper_parameters + ) + + assert output == shared_volumes + + +def test_create_docker_host(local_container, environment, shared_volumes): + host = "algo-1" + output = local_container._create_docker_host(host, environment, shared_volumes) + assert output == expected_host_config(shared_volumes, host) + + +def test_generate_compose_file(local_container, environment, shared_volumes): + output = local_container._generate_compose_file(environment, shared_volumes) + + assert output == expected_compose_file(shared_volumes, local_container.hosts) + + docker_compose_path = os.path.join(local_container.container_root, DOCKER_COMPOSE_FILENAME) + assert os.path.exists(docker_compose_path) + os.remove(docker_compose_path) diff --git a/tests/unit/sagemaker/modules/test_utils.py b/tests/unit/sagemaker/modules/test_utils.py new file mode 100644 index 0000000000..efe43f1792 --- /dev/null +++ b/tests/unit/sagemaker/modules/test_utils.py @@ -0,0 +1,140 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
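The utility tests that follow pin down a few semantics worth calling out: S3 "Directory" URIs must end with a trailing slash while "File" URIs must not, and generated names are clipped to SageMaker's 63-character limit when no explicit cap is given. A behavioral sketch, with expectations inferred from the parametrized cases rather than from the implementation:

    from sagemaker.modules.utils import _get_unique_name, _is_valid_s3_uri

    assert _is_valid_s3_uri("s3://bucket/key/", "Directory")  # trailing slash required
    assert not _is_valid_s3_uri("s3://bucket/key", "Directory")
    assert _is_valid_s3_uri("s3://bucket/key", "File")
    assert len(_get_unique_name("1111111111" * 7, None)) <= 63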
+"""Utils Tests.""" +from __future__ import absolute_import + +import pytest + +from tests.unit import DATA_DIR +from sagemaker.modules.utils import ( + _is_valid_s3_uri, + _is_valid_path, + _get_unique_name, + _get_repo_name_from_image, +) + + +@pytest.mark.parametrize( + "test_case", + [ + { + "path": "s3://bucket/key", + "path_type": "Any", + "expected": True, + }, + { + "path": "s3://bucket/key", + "path_type": "File", + "expected": True, + }, + { + "path": "s3://bucket/key/", + "path_type": "Directory", + "expected": True, + }, + { + "path": "s3://bucket/key/", + "path_type": "File", + "expected": False, + }, + { + "path": "s3://bucket/key", + "path_type": "Directory", + "expected": False, + }, + { + "path": "/bucket/key", + "path_type": "Any", + "expected": False, + }, + ], +) +def test_is_valid_s3_uri(test_case): + assert _is_valid_s3_uri(test_case["path"], test_case["path_type"]) == test_case["expected"] + + +@pytest.mark.parametrize( + "test_case", + [ + { + "path": DATA_DIR, + "path_type": "Any", + "expected": True, + }, + { + "path": DATA_DIR, + "path_type": "Directory", + "expected": True, + }, + { + "path": f"{DATA_DIR}/dummy_input.txt", + "path_type": "File", + "expected": True, + }, + { + "path": f"{DATA_DIR}/dummy_input.txt", + "path_type": "Directory", + "expected": False, + }, + { + "path": f"{DATA_DIR}/non_existent", + "path_type": "Any", + "expected": False, + }, + ], +) +def test_is_valid_path(test_case): + assert _is_valid_path(test_case["path"], test_case["path_type"]) == test_case["expected"] + + +@pytest.mark.parametrize( + "test_case", + [ + { + "base": "test", + "max_length": 5, + }, + { + "base": "1111111111" * 7, + "max_length": None, + }, + ], +) +def test_get_unique_name(test_case): + assert ( + len(_get_unique_name(test_case["base"], test_case.get("max_length"))) + <= test_case["max_length"] + if test_case.get("max_length") + else 63 + ) + + +@pytest.mark.parametrize( + "test_case", + [ + { + "image": "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:latest", + "expected": "my-custom-image", + }, + { + "image": "my-custom-image:latest", + "expected": "my-custom-image", + }, + { + "image": "public.ecr.aws/docker/library/my-custom-image:latest", + "expected": "my-custom-image", + }, + ], +) +def test_get_repo_name_from_image(test_case): + assert _get_repo_name_from_image(test_case["image"]) == test_case["expected"] diff --git a/tests/unit/sagemaker/modules/train/__init__.py b/tests/unit/sagemaker/modules/train/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/modules/train/container_drivers/__init__.py b/tests/unit/sagemaker/modules/train/container_drivers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py b/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py new file mode 100644 index 0000000000..30d6dfdf6c --- /dev/null +++ b/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py @@ -0,0 +1,186 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Environment Variable Script Unit Tests."""
+from __future__ import absolute_import
+
+import os
+import io
+import logging
+
+from unittest.mock import patch
+
+from sagemaker.modules.train.container_drivers.scripts.environment import (
+    set_env,
+    log_key_value,
+    log_env_variables,
+    mask_sensitive_info,
+    HIDDEN_VALUE,
+)
+from sagemaker.modules.train.container_drivers.utils import safe_serialize, safe_deserialize
+
+RESOURCE_CONFIG = dict(
+    current_host="algo-1",
+    hosts=["algo-1", "algo-2", "algo-3"],
+    current_group_name="train1",
+    current_instance_type="ml.p3.16xlarge",
+    instance_groups=[
+        dict(
+            instance_group_name="train1",
+            instance_type="ml.p3.16xlarge",
+            hosts=["algo-1", "algo-2"],
+        ),
+        dict(
+            instance_group_name="train2",
+            instance_type="ml.p3.8xlarge",
+            hosts=["algo-3"],
+        ),
+    ],
+    network_interface_name="eth0",
+)
+
+INPUT_DATA_CONFIG = {
+    "train": {
+        "ContentType": "trainingContentType",
+        "TrainingInputMode": "File",
+        "S3DistributionType": "FullyReplicated",
+        "RecordWrapperType": "None",
+    },
+    "validation": {
+        "TrainingInputMode": "File",
+        "S3DistributionType": "FullyReplicated",
+        "RecordWrapperType": "None",
+    },
+}
+
+USER_HYPERPARAMETERS = {
+    "batch_size": 32,
+    "learning_rate": 0.001,
+    "hosts": ["algo-1", "algo-2"],
+    "mp_parameters": {
+        "microbatches": 2,
+        "partitions": 2,
+        "pipeline": "interleaved",
+        "optimize": "memory",
+        "horovod": True,
+    },
+}
+
+OUTPUT_FILE = os.path.join(os.path.dirname(__file__), "sm_training.env")
+
+# flake8: noqa
+EXPECTED_ENV = """
+export SM_MODEL_DIR='/opt/ml/model'
+export SM_INPUT_DIR='/opt/ml/input'
+export SM_INPUT_DATA_DIR='/opt/ml/input/data'
+export SM_INPUT_CONFIG_DIR='/opt/ml/input/config'
+export SM_OUTPUT_DIR='/opt/ml/output'
+export SM_OUTPUT_FAILURE='/opt/ml/output/failure'
+export SM_OUTPUT_DATA_DIR='/opt/ml/output/data'
+export SM_LOG_LEVEL='20'
+export SM_MASTER_ADDR='algo-1'
+export SM_MASTER_PORT='7777'
+export SM_CHANNEL_TRAIN='/opt/ml/input/data/train'
+export SM_CHANNEL_VALIDATION='/opt/ml/input/data/validation'
+export SM_CHANNELS='["train", "validation"]'
+export SM_HP_BATCH_SIZE='32'
+export SM_HP_LEARNING_RATE='0.001'
+export SM_HP_HOSTS='["algo-1", "algo-2"]'
+export SM_HP_MP_PARAMETERS='{"microbatches": 2, "partitions": 2, "pipeline": "interleaved", "optimize": "memory", "horovod": true}'
+export SM_HPS='{"batch_size": 32, "learning_rate": 0.001, "hosts": ["algo-1", "algo-2"], "mp_parameters": {"microbatches": 2, "partitions": 2, "pipeline": "interleaved", "optimize": "memory", "horovod": true}}'
+export SM_CURRENT_HOST='algo-1'
+export SM_CURRENT_INSTANCE_TYPE='ml.p3.16xlarge'
+export SM_HOSTS='["algo-1", "algo-2", "algo-3"]'
+export SM_NETWORK_INTERFACE_NAME='eth0'
+export SM_HOST_COUNT='3'
+export SM_CURRENT_HOST_RANK='0'
+export SM_NUM_CPUS='8'
+export SM_NUM_GPUS='0'
+export SM_NUM_NEURONS='0'
+export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3"], "current_group_name": "train1", "current_instance_type": "ml.p3.16xlarge", "instance_groups": [{"instance_group_name": "train1", "instance_type": "ml.p3.16xlarge", "hosts": ["algo-1", "algo-2"]}, {"instance_group_name": "train2", "instance_type": "ml.p3.8xlarge", "hosts": ["algo-3"]}], "network_interface_name": "eth0"}'
+export SM_INPUT_DATA_CONFIG='{"train": 
{"ContentType": "trainingContentType", "TrainingInputMode": "File", "S3DistributionType": "FullyReplicated", "RecordWrapperType": "None"}, "validation": {"TrainingInputMode": "File", "S3DistributionType": "FullyReplicated", "RecordWrapperType": "None"}}' +export SM_TRAINING_ENV='{"channel_input_dirs": {"train": "/opt/ml/input/data/train", "validation": "/opt/ml/input/data/validation"}, "current_host": "algo-1", "current_instance_type": "ml.p3.16xlarge", "hosts": ["algo-1", "algo-2", "algo-3"], "master_addr": "algo-1", "master_port": 7777, "hyperparameters": {"batch_size": 32, "learning_rate": 0.001, "hosts": ["algo-1", "algo-2"], "mp_parameters": {"microbatches": 2, "partitions": 2, "pipeline": "interleaved", "optimize": "memory", "horovod": true}}, "input_data_config": {"train": {"ContentType": "trainingContentType", "TrainingInputMode": "File", "S3DistributionType": "FullyReplicated", "RecordWrapperType": "None"}, "validation": {"TrainingInputMode": "File", "S3DistributionType": "FullyReplicated", "RecordWrapperType": "None"}}, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "log_level": 20, "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 8, "num_gpus": 0, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3"], "current_group_name": "train1", "current_instance_type": "ml.p3.16xlarge", "instance_groups": [{"instance_group_name": "train1", "instance_type": "ml.p3.16xlarge", "hosts": ["algo-1", "algo-2"]}, {"instance_group_name": "train2", "instance_type": "ml.p3.8xlarge", "hosts": ["algo-3"]}], "network_interface_name": "eth0"}}' +""" + + +@patch("sagemaker.modules.train.container_drivers.scripts.environment.num_cpus", return_value=8) +@patch("sagemaker.modules.train.container_drivers.scripts.environment.num_gpus", return_value=0) +@patch("sagemaker.modules.train.container_drivers.scripts.environment.num_neurons", return_value=0) +@patch( + "sagemaker.modules.train.container_drivers.scripts.environment.safe_serialize", + side_effect=safe_serialize, +) +@patch( + "sagemaker.modules.train.container_drivers.scripts.environment.safe_deserialize", + side_effect=safe_deserialize, +) +def test_set_env( + mock_safe_deserialize, mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + set_env( + resource_config=RESOURCE_CONFIG, + input_data_config=INPUT_DATA_CONFIG, + hyperparameters_config=USER_HYPERPARAMETERS, + output_file=OUTPUT_FILE, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(OUTPUT_FILE, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + os.remove(OUTPUT_FILE) + assert not os.path.exists(OUTPUT_FILE) + + +@patch.dict(os.environ, {"SECRET_TOKEN": "122345678", "CLEAR_DATA": "123456789"}, clear=True) +def test_log_env_variables(): + log_stream = io.StringIO() + handler = logging.StreamHandler(log_stream) + + logger = logging.getLogger("sagemaker.modules.train.container_drivers.scripts.environment") + logger.addHandler(handler) + logger.setLevel(logging.INFO) + + env_vars = { + "SM_MODEL_DIR": "/opt/ml/model", + "SM_INPUT_DIR": "/opt/ml/input", + "SM_HPS": {"batch_size": 32, "learning_rate": 0.001, 
"access_token": "123456789"}, + "SM_HP_BATCH_SIZE": 32, + "SM_HP_LEARNING_RATE": 0.001, + "SM_HP_ACCESS_TOKEN": "123456789", + } + log_env_variables(env_vars_dict=env_vars) + + log_output = log_stream.getvalue() + + assert f"SECRET_TOKEN={HIDDEN_VALUE}" in log_output + assert "CLEAR_DATA=123456789" in log_output + assert "SM_MODEL_DIR=/opt/ml/model" in log_output + assert ( + f'SM_HPS={{"batch_size": 32, "learning_rate": 0.001, "access_token": "{HIDDEN_VALUE}"}}' + in log_output + ) + assert "SM_HP_BATCH_SIZE=32" in log_output + assert "SM_HP_LEARNING_RATE=0.001" in log_output + assert f"SM_HP_ACCESS_TOKEN={HIDDEN_VALUE}" in log_output + + +def _remove_extra_lines(string): + """Removes extra blank lines from a string.""" + return "\n".join([line for line in string.splitlines() if line.strip()]) diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py new file mode 100644 index 0000000000..a1a84da1ab --- /dev/null +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py @@ -0,0 +1,151 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""MPI Driver Unit Tests.""" +from __future__ import absolute_import + +import os +import sys + +from unittest.mock import patch, MagicMock + +sys.modules["utils"] = MagicMock() +sys.modules["mpi_utils"] = MagicMock() + +from sagemaker.modules.train.container_drivers import mpi_driver # noqa: E402 + + +DUMMY_MPI_COMMAND = [ + "mpirun", + "--host", + "algo-1,algo-2", + "-np", + "2", + "--verbose", + "-x", + "ENV_VAR1", + "python", + "-m", + "mpi4py", + "-m", + "script.py", +] + +DUMMY_SOURCE_CODE = { + "source_code": "source_code", + "entry_script": "script.py", +} +DUMMY_DISTRIBUTED = { + "_type": "mpi", + "process_count_per_node": 2, + "mpi_additional_options": [ + "--verbose", + "-x", + "ENV_VAR1", + ], +} + + +@patch.dict( + os.environ, + { + "SM_CURRENT_HOST": "algo-2", + "SM_HOSTS": '["algo-1", "algo-2"]', + "SM_MASTER_ADDR": "algo-1", + "SM_HOST_COUNT": "2", + }, +) +@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_distributed_json") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_source_code_json") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.write_env_vars_to_file") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.start_sshd_daemon") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_master_node") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_worker_node") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.hyperparameters_to_cli_args") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.get_mpirun_command") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.execute_commands") +def test_mpi_driver_worker( + mock_execute_commands, + mock_get_mpirun_command, + mock_hyperparameters_to_cli_args, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, + mock_start_sshd_daemon, + 
mock_write_env_vars_to_file, + mock_read_source_code_json, + mock_read_distributed_json, +): + mock_hyperparameters_to_cli_args.return_value = [] + mock_read_source_code_json.return_value = DUMMY_SOURCE_CODE + mock_read_distributed_json.return_value = DUMMY_DISTRIBUTED + + mpi_driver.main() + + mock_write_env_vars_to_file.assert_called_once() + mock_start_sshd_daemon.assert_called_once() + mock_bootstrap_worker_node.assert_called_once() + + mock_bootstrap_master_node.assert_not_called() + mock_get_mpirun_command.assert_not_called() + mock_execute_commands.assert_not_called() + + +@patch.dict( + os.environ, + { + "SM_CURRENT_HOST": "algo-1", + "SM_HOSTS": '["algo-1", "algo-2"]', + "SM_MASTER_ADDR": "algo-1", + "SM_HOST_COUNT": "2", + }, +) +@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_distributed_json") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_source_code_json") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.write_env_vars_to_file") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.start_sshd_daemon") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_master_node") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_worker_node") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.get_process_count") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.hyperparameters_to_cli_args") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.get_mpirun_command") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.execute_commands") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.write_status_file_to_workers") +def test_mpi_driver_master( + mock_write_status_file_to_workers, + mock_execute_commands, + mock_get_mpirun_command, + mock_hyperparameters_to_cli_args, + mock_get_process_count, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, + mock_start_sshd_daemon, + mock_write_env_vars_to_file, + mock_read_source_code_config_json, + mock_read_distributed_json, +): + mock_hyperparameters_to_cli_args.return_value = [] + mock_read_source_code_config_json.return_value = DUMMY_SOURCE_CODE + mock_read_distributed_json.return_value = DUMMY_DISTRIBUTED + mock_get_mpirun_command.return_value = DUMMY_MPI_COMMAND + mock_get_process_count.return_value = 2 + mock_execute_commands.return_value = (0, "") + + mpi_driver.main() + + mock_write_env_vars_to_file.assert_called_once() + mock_start_sshd_daemon.assert_called_once() + mock_bootstrap_master_node.assert_called_once() + mock_get_mpirun_command.assert_called_once() + mock_execute_commands.assert_called_once_with(DUMMY_MPI_COMMAND) + mock_write_status_file_to_workers.assert_called_once() + + mock_bootstrap_worker_node.assert_not_called() diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py b/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py new file mode 100644 index 0000000000..4cff07a0c0 --- /dev/null +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py @@ -0,0 +1,168 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Torchrun Driver Unit Tests."""
+from __future__ import absolute_import
+
+import os
+import sys
+
+from unittest.mock import patch, MagicMock
+
+sys.modules["utils"] = MagicMock()
+
+from sagemaker.modules.train.container_drivers import torchrun_driver  # noqa: E402
+
+DUMMY_SOURCE_CODE = {
+    "source_code": "source_code",
+    "entry_script": "script.py",
+}
+
+DUMMY_DISTRIBUTED = {"_type": "torchrun", "process_count_per_node": 2}
+
+
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.get_python_executable",
+    return_value="python3",
+)
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0)
+)
+def test_get_base_pytorch_command_torchrun(mock_pytorch_version, mock_get_python_executable):
+    assert torchrun_driver.get_base_pytorch_command() == ["torchrun"]
+
+
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.get_python_executable",
+    return_value="python3",
+)
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(1, 8)
+)
+def test_get_base_pytorch_command_torch_distributed_launch(
+    mock_pytorch_version, mock_get_python_executable
+):
+    assert torchrun_driver.get_base_pytorch_command() == (
+        ["python3", "-m", "torch.distributed.launch"]
+    )
+
+
+@patch.dict(
+    os.environ,
+    {
+        "SM_CURRENT_INSTANCE_TYPE": "ml.p4d.24xlarge",
+        "SM_NETWORK_INTERFACE_NAME": "eth0",
+        "SM_HOST_COUNT": "1",
+    },
+)
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.USER_CODE_PATH",
+    "/opt/ml/input/data/code",
+)
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.get_process_count", return_value=2
+)
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0)
+)
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.get_base_pytorch_command",
+    return_value=["torchrun"],
+)
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.read_source_code_json",
+    return_value=DUMMY_SOURCE_CODE,
+)
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.read_distributed_json",
+    return_value=DUMMY_DISTRIBUTED,
+)
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.hyperparameters_to_cli_args",
+    return_value=[],
+)
+def test_create_commands_single_node(
+    mock_hyperparameters_to_cli_args,
+    mock_read_distributed_json,
+    mock_read_source_code_json,
+    mock_get_base_pytorch_command,
+    mock_pytorch_version,
+    mock_get_process_count,
+):
+    expected_command = [
+        "torchrun",
+        "--nnodes=1",
+        "--nproc_per_node=2",
+        "/opt/ml/input/data/code/script.py",
+    ]
+
+    command = torchrun_driver.create_commands()
+    assert command == expected_command
+
+
+@patch.dict(
+    os.environ,
+    {
+        "SM_CURRENT_INSTANCE_TYPE": "ml.p4d.24xlarge",
+        "SM_NETWORK_INTERFACE_NAME": "eth0",
+        "SM_HOST_COUNT": "2",
+        "SM_MASTER_ADDR": "algo-1",
+        "SM_MASTER_PORT": "7777",
+        "SM_CURRENT_HOST_RANK": "0",
+    },
+)
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.USER_CODE_PATH",
+    "/opt/ml/input/data/code",
+)
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.get_process_count", return_value=2
+)
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0)
+)
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.get_base_pytorch_command",
+    return_value=["torchrun"],
+)
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.read_source_code_json",
+    return_value=DUMMY_SOURCE_CODE,
+)
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.read_distributed_json",
+    return_value=DUMMY_DISTRIBUTED,
+)
+@patch(
+    "sagemaker.modules.train.container_drivers.torchrun_driver.hyperparameters_to_cli_args",
+    return_value=[],
+)
+def test_create_commands_multi_node(
+    mock_hyperparameters_to_cli_args,
+    mock_read_distributed_json,
+    mock_read_source_code_json,
+    mock_get_base_pytorch_command,
+    mock_pytorch_version,
+    mock_get_process_count,
+):
+    expected_command = [
+        "torchrun",
+        "--nnodes=2",
+        "--nproc_per_node=2",
+        "--master_addr=algo-1",
+        "--master_port=7777",
+        "--node_rank=0",
+        "/opt/ml/input/data/code/script.py",
+    ]
+
+    command = torchrun_driver.create_commands()
+    assert command == expected_command
diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py b/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py
new file mode 100644
index 0000000000..aba97996b0
--- /dev/null
+++ b/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py
@@ -0,0 +1,121 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Container Utils Unit Tests.""" +from __future__ import absolute_import + +from sagemaker.modules.train.container_drivers.utils import ( + safe_deserialize, + safe_serialize, + hyperparameters_to_cli_args, +) + +SM_HPS = { + "boolean": "true", + "dict": '{"string":"value","integer":3,"list":[1,2,3],"dict":{"key":"value"},"boolean":true}', + "float": "3.14", + "integer": "1", + "list": "[1,2,3]", + "string": "Hello World", +} + + +def test_hyperparameters_to_cli_args(): + args = hyperparameters_to_cli_args(SM_HPS) + + assert args == [ + "--boolean", + "true", + "--dict", + '{"string": "value", "integer": 3, "list": [1, 2, 3], "dict": {"key": "value"}, "boolean": true}', + "--float", + "3.14", + "--integer", + "1", + "--list", + "[1, 2, 3]", + "--string", + "Hello World", + ] + + +def test_safe_deserialize_not_a_string(): + assert safe_deserialize(123) == 123 + assert safe_deserialize([1, 2, 3]) == [1, 2, 3] + assert safe_deserialize({"key": "value"}) == {"key": "value"} + + +def test_safe_deserialize_boolean_strings(): + assert safe_deserialize("true") is True + assert safe_deserialize("false") is False + assert safe_deserialize("True") is True + assert safe_deserialize("False") is False + + +def test_safe_deserialize_valid_json_string(): + json_data = '{"key": "value", "number": 123, "boolean": true}' + expected_output = {"key": "value", "number": 123, "boolean": True} + assert safe_deserialize(json_data) == expected_output + + assert safe_deserialize("Hello World") == "Hello World" + assert safe_deserialize("12345") == 12345 + + assert safe_deserialize("3.14") == 3.14 + assert safe_deserialize("[1,2,3]") == [1, 2, 3] + + +def test_safe_deserialize_invalid_json_string(): + invalid_json = '{"key": value}' # Missing quotes around value so not valid json + assert safe_deserialize(invalid_json) == invalid_json + + +def test_safe_deserialize_null_string(): + assert safe_deserialize("null") == None # noqa: E711 + assert safe_deserialize("None") == "None" + + +def test_safe_serialize_string(): + assert safe_serialize("Hello World") == "Hello World" + assert safe_serialize("12345") == "12345" + assert safe_serialize("true") == "true" + + +def test_safe_serialize_serializable_data(): + assert safe_serialize({"key": "value", "number": 123, "boolean": True}) == ( + '{"key": "value", "number": 123, "boolean": true}' + ) + assert safe_serialize([1, 2, 3]) == "[1, 2, 3]" + assert safe_serialize(123) == "123" + assert safe_serialize(3.14) == "3.14" + assert safe_serialize(True) == "true" + assert safe_serialize(False) == "false" + assert safe_serialize(None) == "null" + + +def test_safe_serialize_custom_object(): + class CustomObject: + def __str__(self): + return "CustomObject" + + obj = CustomObject() + assert safe_serialize(obj) == "CustomObject" + + +def test_safe_serialize_invalid_data(): + invalid_data = {"key": set([1, 2, 3])} # Sets are not JSON serializable + assert safe_serialize(invalid_data) == str(invalid_data) + + +def test_safe_serialize_empty_data(): + assert safe_serialize("") == "" + assert safe_serialize([]) == "[]" + assert safe_serialize({}) == "{}" diff --git a/tests/unit/sagemaker/modules/train/sm_recipes/__init__.py b/tests/unit/sagemaker/modules/train/sm_recipes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py new file mode 100644 index 0000000000..66eafab4f0 --- /dev/null +++ 
b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py @@ -0,0 +1,180 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Utility functions for SageMaker training recipes Tests.""" +from __future__ import absolute_import + +import pytest +from unittest.mock import patch + +import yaml +from urllib.request import urlretrieve +from tempfile import NamedTemporaryFile + +from sagemaker.modules.train.sm_recipes.utils import ( + _load_base_recipe, + _get_args_from_recipe, + _load_recipes_cfg, + _configure_gpu_args, + _configure_trainium_args, +) +from sagemaker.modules.utils import _run_clone_command_silent +from sagemaker.modules.configs import Compute + + +@pytest.fixture(scope="module") +def training_recipes_cfg(): + return _load_recipes_cfg() + + +@pytest.fixture(scope="module") +def temporary_recipe(): + data = { + "trainer": {"num_nodes": 2, "max_epochs": 10}, + "model": {"model_type": "llama_v3", "num_classes": 10, "num_layers": 10}, + } + with NamedTemporaryFile(suffix=".yaml", delete=False) as f: + with open(f.name, "w") as file: + yaml.dump(data, file) + yield f.name + + +def test_load_base_recipe_with_overrides(temporary_recipe, training_recipes_cfg): + expected_epochs = 20 + expected_layers = 15 + + recipe_overrides = { + "trainer": {"max_epochs": expected_epochs}, + "model": {"num_layers": expected_layers}, + } + + load_recipe = _load_base_recipe( + training_recipe=temporary_recipe, + recipe_overrides=recipe_overrides, + training_recipes_cfg=training_recipes_cfg, + ) + + assert ( + load_recipe["trainer"]["max_epochs"] == expected_epochs + and load_recipe["model"]["num_layers"] == expected_layers + ) + + +@pytest.mark.parametrize( + "test_case", + [ + {"recipe_type": "local"}, + {"recipe_type": "sagemaker"}, + {"recipe_type": "url"}, + {"recipe_type": "not_found"}, + ], +) +@patch("sagemaker.modules.train.sm_recipes.utils.urlretrieve") +@patch("sagemaker.modules.train.sm_recipes.utils._run_clone_command_silent") +def test_load_base_recipe_types( + mock_clone, mock_retrieve, temporary_recipe, training_recipes_cfg, test_case +): + recipe_type = test_case["recipe_type"] + + if recipe_type == "not_found": + with pytest.raises(ValueError): + _load_base_recipe( + training_recipe="not_found", + recipe_overrides=None, + training_recipes_cfg=training_recipes_cfg, + ) + + if recipe_type == "local": + load_recipe = _load_base_recipe( + training_recipe=temporary_recipe, + recipe_overrides=None, + training_recipes_cfg=training_recipes_cfg, + ) + assert load_recipe is not None + assert "trainer" in load_recipe + + if recipe_type == "sagemaker": + mock_clone.side_effect = _run_clone_command_silent + load_recipe = _load_base_recipe( + training_recipe="training/llama/p4_hf_llama3_70b_seq8k_gpu", + recipe_overrides=None, + training_recipes_cfg=training_recipes_cfg, + ) + assert load_recipe is not None + assert "trainer" in load_recipe + assert mock_clone.call_args.args[0] == training_recipes_cfg.get("launcher_repo") + + if recipe_type == "url": + url = 
"https://raw.githubusercontent.com/aws-neuron/neuronx-distributed-training/refs/heads/main/examples/conf/hf_llama3_8B_config.yaml" # noqa + mock_retrieve.side_effect = urlretrieve + load_recipe = _load_base_recipe( + training_recipe=url, + recipe_overrides=None, + training_recipes_cfg=training_recipes_cfg, + ) + assert load_recipe is not None + assert "trainer" in load_recipe + assert mock_retrieve.call_args.args[0] == url + + +@pytest.mark.parametrize( + "test_case", + [ + {"type": "gpu", "instance_type": "ml.p4d.24xlarge"}, + {"type": "trn", "instance_type": "ml.trn1.32xlarge"}, + {"type": "cpu", "instance_type": "ml.c5.4xlarge"}, + ], +) +@patch("sagemaker.modules.train.sm_recipes.utils._configure_gpu_args") +@patch("sagemaker.modules.train.sm_recipes.utils._configure_trainium_args") +def test_get_args_from_recipe_compute( + mock_trainium_args, mock_gpu_args, temporary_recipe, test_case +): + compute = Compute(instance_type=test_case["instance_type"]) + if test_case["type"] == "gpu": + mock_gpu_args.side_effect = _configure_gpu_args + + args = _get_args_from_recipe( + training_recipe=temporary_recipe, + compute=compute, + region_name="us-west-2", + recipe_overrides=None, + requirements=None, + ) + assert mock_gpu_args.call_count == 1 + assert mock_trainium_args.call_count == 0 + + if test_case["type"] == "trn": + mock_trainium_args.side_effect = _configure_trainium_args + + args = _get_args_from_recipe( + training_recipe=temporary_recipe, + compute=compute, + region_name="us-west-2", + recipe_overrides=None, + requirements=None, + ) + assert mock_gpu_args.call_count == 0 + assert mock_trainium_args.call_count == 1 + + if test_case["type"] == "cpu": + with pytest.raises(ValueError): + args = _get_args_from_recipe( + training_recipe=temporary_recipe, + compute=compute, + region_name="us-west-2", + recipe_overrides=None, + requirements=None, + ) + assert mock_gpu_args.call_count == 0 + assert mock_trainium_args.call_count == 0 + assert args is None diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py new file mode 100644 index 0000000000..049ebaa9c4 --- /dev/null +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -0,0 +1,1059 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""ModelTrainer Tests.""" +from __future__ import absolute_import + +import shutil +import tempfile +import json +import os +import pytest +from unittest.mock import patch, MagicMock, ANY + +from sagemaker import image_uris +from sagemaker_core.main.resources import TrainingJob +from sagemaker_core.main.shapes import ( + ResourceConfig, + VpcConfig, + AlgorithmSpecification, +) + +from sagemaker.config import SAGEMAKER, PYTHON_SDK, MODULES +from sagemaker.config.config_schema import ( + MODEL_TRAINER, + _simple_path, + TRAINING_JOB_RESOURCE_CONFIG_PATH, +) +from sagemaker.modules import Session +from sagemaker.modules.train.model_trainer import ModelTrainer, Mode +from sagemaker.modules.constants import ( + DEFAULT_INSTANCE_TYPE, + DISTRIBUTED_JSON, + SOURCE_CODE_JSON, + TRAIN_SCRIPT, +) +from sagemaker.modules.configs import ( + Compute, + StoppingCondition, + RetryStrategy, + OutputDataConfig, + SourceCode, + RemoteDebugConfig, + TensorBoardOutputConfig, + InfraCheckConfig, + SessionChainingConfig, + InputData, + Networking, + TrainingImageConfig, + TrainingRepositoryAuthConfig, + CheckpointConfig, + Tag, + S3DataSource, + FileSystemDataSource, + Channel, + DataSource, +) +from sagemaker.modules.distributed import Torchrun, SMP, MPI +from sagemaker.modules.train.sm_recipes.utils import _load_recipes_cfg +from sagemaker.modules.templates import EXEUCTE_TORCHRUN_DRIVER, EXECUTE_MPI_DRIVER +from tests.unit import DATA_DIR + +DEFAULT_BASE_NAME = "dummy-image-job" +DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest" +DEFAULT_BUCKET = "sagemaker-us-west-2-000000000000" +DEFAULT_ROLE = "arn:aws:iam::000000000000:role/test-role" +DEFAULT_BUCKET_PREFIX = "sample-prefix" +DEFAULT_REGION = "us-west-2" +DEFAULT_SOURCE_DIR = f"{DATA_DIR}/modules/script_mode" +DEFAULT_COMPUTE_CONFIG = Compute(instance_type=DEFAULT_INSTANCE_TYPE, instance_count=1) +DEFAULT_OUTPUT_DATA_CONFIG = OutputDataConfig( + s3_output_path=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/{DEFAULT_BASE_NAME}", + compression_type="GZIP", + kms_key_id=None, +) +DEFAULT_STOPPING_CONDITION = StoppingCondition( + max_runtime_in_seconds=3600, + max_pending_time_in_seconds=None, + max_wait_time_in_seconds=None, +) +DEFAULT_SOURCE_CODE = SourceCode( + source_dir=DEFAULT_SOURCE_DIR, + entry_script="custom_script.py", +) +UNSUPPORTED_SOURCE_CODE = SourceCode( + entry_script="train.py", +) +DEFAULT_ENTRYPOINT = ["/bin/bash"] +DEFAULT_ARGUMENTS = [ + "-c", + ( + "chmod +x /opt/ml/input/data/sm_drivers/sm_train.sh " + + "&& /opt/ml/input/data/sm_drivers/sm_train.sh" + ), +] + + +@pytest.fixture(scope="module", autouse=True) +def modules_session(): + with patch("sagemaker.modules.Session", spec=Session) as session_mock: + session_instance = session_mock.return_value + session_instance.default_bucket.return_value = DEFAULT_BUCKET + session_instance.get_caller_identity_arn.return_value = DEFAULT_ROLE + session_instance.default_bucket_prefix = DEFAULT_BUCKET_PREFIX + session_instance.boto_session = MagicMock(spec="boto3.session.Session") + session_instance.boto_region_name = DEFAULT_REGION + yield session_instance + + +@pytest.fixture +def model_trainer(): + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE_CONFIG, + stopping_condition=DEFAULT_STOPPING_CONDITION, + output_data_config=DEFAULT_OUTPUT_DATA_CONFIG, + ) + return trainer + + +@pytest.mark.parametrize( + "test_case", + [ + { + "init_params": {}, + "should_throw": True, + }, + { + "init_params": { + 
"training_image": DEFAULT_IMAGE, + }, + "should_throw": False, + }, + { + "init_params": { + "training_image": DEFAULT_IMAGE, + "algorithm_name": "dummy-arn", + }, + "should_throw": True, + }, + { + "init_params": { + "training_image": DEFAULT_IMAGE, + "source_code": UNSUPPORTED_SOURCE_CODE, + }, + "should_throw": True, + }, + { + "init_params": { + "training_image": DEFAULT_IMAGE, + "source_code": DEFAULT_SOURCE_CODE, + }, + "should_throw": False, + }, + ], + ids=[ + "no_params", + "training_image_and_algorithm_name", + "only_training_image", + "unsupported_source_code", + "supported_source_code", + ], +) +def test_model_trainer_param_validation(test_case, modules_session): + if test_case["should_throw"]: + with pytest.raises(ValueError): + ModelTrainer(**test_case["init_params"], sagemaker_session=modules_session) + else: + trainer = ModelTrainer(**test_case["init_params"], sagemaker_session=modules_session) + assert trainer is not None + assert trainer.training_image == DEFAULT_IMAGE + assert trainer.compute == DEFAULT_COMPUTE_CONFIG + assert trainer.output_data_config == DEFAULT_OUTPUT_DATA_CONFIG + assert trainer.stopping_condition == DEFAULT_STOPPING_CONDITION + assert trainer.base_job_name == DEFAULT_BASE_NAME + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_train_with_default_params(mock_training_job, model_trainer): + model_trainer.train() + + mock_training_job.create.assert_called_once() + + training_job_instance = mock_training_job.create.return_value + training_job_instance.wait.assert_called_once_with(logs=True) + + +@pytest.mark.parametrize( + "default_config", + [ + { + "path_name": "sourceCode", + "config_value": {"command": "echo 'Hello World' && env"}, + }, + { + "path_name": "compute", + "config_value": {"volume_size_in_gb": 45}, + }, + { + "path_name": "networking", + "config_value": { + "enable_network_isolation": True, + "security_group_ids": ["sg-1"], + "subnets": ["subnet-1"], + }, + }, + { + "path_name": "stoppingCondition", + "config_value": {"max_runtime_in_seconds": 15}, + }, + { + "path_name": "trainingImageConfig", + "config_value": {"training_repository_access_mode": "private"}, + }, + { + "path_name": "outputDataConfig", + "config_value": {"s3_output_path": "Sample S3 path"}, + }, + { + "path_name": "checkpointConfig", + "config_value": {"s3_uri": "sample uri"}, + }, + ], +) +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +@patch("sagemaker.modules.train.model_trainer.resolve_value_from_config") +@patch("sagemaker.modules.train.model_trainer.ModelTrainer.create_input_data_channel") +def test_train_with_intelligent_defaults( + mock_create_input_data_channel, + mock_resolve_value_from_config, + mock_training_job, + default_config, + model_trainer, +): + mock_resolve_value_from_config.side_effect = lambda **kwargs: ( + default_config["config_value"] + if kwargs["config_path"] + == _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, MODEL_TRAINER, default_config["path_name"]) + else None + ) + + model_trainer.train() + + mock_training_job.create.assert_called_once() + + training_job_instance = mock_training_job.create.return_value + training_job_instance.wait.assert_called_once_with(logs=True) + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +@patch("sagemaker.modules.train.model_trainer.resolve_value_from_config") +def test_train_with_intelligent_defaults_training_job_space( + mock_resolve_value_from_config, mock_training_job, model_trainer +): + mock_resolve_value_from_config.side_effect = lambda **kwargs: ( + { 
+ "instanceType": DEFAULT_INSTANCE_TYPE, + "instanceCount": 1, + "volumeSizeInGB": 30, + } + if kwargs["config_path"] == TRAINING_JOB_RESOURCE_CONFIG_PATH + else None + ) + + model_trainer.train() + + mock_training_job.create.assert_called_once_with( + training_job_name=ANY, + algorithm_specification=ANY, + hyper_parameters={}, + input_data_config=[], + resource_config=ResourceConfig( + volume_size_in_gb=30, + instance_type="ml.m5.xlarge", + instance_count=1, + volume_kms_key_id=None, + keep_alive_period_in_seconds=None, + instance_groups=None, + ), + vpc_config=None, + session=ANY, + role_arn="arn:aws:iam::000000000000:" "role/test-role", + tags=None, + stopping_condition=StoppingCondition( + max_runtime_in_seconds=3600, + max_wait_time_in_seconds=None, + max_pending_time_in_seconds=None, + ), + output_data_config=OutputDataConfig( + s3_output_path="s3://" + "sagemaker-us-west-2" + "-000000000000/" + "sample-prefix/" + "dummy-image-job", + kms_key_id=None, + compression_type="GZIP", + ), + checkpoint_config=None, + environment=None, + enable_managed_spot_training=None, + enable_inter_container_traffic_encryption=None, + enable_network_isolation=None, + remote_debug_config=None, + tensor_board_output_config=None, + retry_strategy=None, + infra_check_config=None, + session_chaining_config=None, + ) + + training_job_instance = mock_training_job.create.return_value + training_job_instance.wait.assert_called_once_with(logs=True) + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +@patch.object(ModelTrainer, "_get_input_data_config") +def test_train_with_input_data_channels(mock_get_input_config, mock_training_job, model_trainer): + train_data = InputData(channel_name="train", data_source="train/dir") + test_data = InputData(channel_name="test", data_source="test/dir") + mock_input_data_config = [train_data, test_data] + + model_trainer.train(input_data_config=mock_input_data_config) + + mock_get_input_config.assert_called_once_with(mock_input_data_config, ANY) + mock_training_job.create.assert_called_once() + + +@pytest.mark.parametrize( + "test_case", + [ + { + "channel_name": "test", + "data_source": DATA_DIR, + "valid": True, + }, + { + "channel_name": "test", + "data_source": f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}-job/input/test", + "valid": True, + }, + { + "channel_name": "test", + "data_source": S3DataSource( + s3_data_type="S3Prefix", + s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}-job/input/test", + s3_data_distribution_type="FullyReplicated", + ), + "valid": True, + }, + { + "channel_name": "test", + "data_source": FileSystemDataSource( + file_system_id="fs-000000000000", + file_system_access_mode="ro", + file_system_type="EFS", + directory_path="/data/test", + ), + "valid": True, + }, + { + "channel_name": "test", + "data_source": "fake/path", + "valid": False, + }, + ], + ids=[ + "valid_local_path", + "valid_s3_path", + "valid_s3_data_source", + "valid_file_system_data_source", + "invalid_path", + ], +) +@patch("sagemaker.modules.train.model_trainer.Session.upload_data") +@patch("sagemaker.modules.train.model_trainer.Session.default_bucket") +def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_trainer, test_case): + expected_s3_uri = f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}-job/input/test" + mock_upload_data.return_value = expected_s3_uri + mock_default_bucket.return_value = DEFAULT_BUCKET + if not test_case["valid"]: + with pytest.raises(ValueError): + model_trainer.create_input_data_channel( + test_case["channel_name"], 
test_case["data_source"] + ) + else: + channel = model_trainer.create_input_data_channel( + test_case["channel_name"], test_case["data_source"] + ) + assert channel.channel_name == test_case["channel_name"] + if isinstance(test_case["data_source"], S3DataSource): + assert channel.data_source.s3_data_source == test_case["data_source"] + elif isinstance(test_case["data_source"], FileSystemDataSource): + assert channel.data_source.file_system_data_source == test_case["data_source"] + else: + assert channel.data_source.s3_data_source.s3_uri == expected_s3_uri + + +@pytest.mark.parametrize( + "test_case", + [ + { + "source_code": DEFAULT_SOURCE_CODE, + "distributed": Torchrun(), + "expected_template": EXEUCTE_TORCHRUN_DRIVER, + "expected_hyperparameters": {}, + }, + { + "source_code": DEFAULT_SOURCE_CODE, + "distributed": Torchrun( + smp=SMP( + hybrid_shard_degree=3, + sm_activation_offloading=True, + allow_empty_shards=True, + tensor_parallel_degree=5, + ) + ), + "expected_template": EXEUCTE_TORCHRUN_DRIVER, + "expected_hyperparameters": { + "mp_parameters": json.dumps( + { + "hybrid_shard_degree": 3, + "sm_activation_offloading": True, + "allow_empty_shards": True, + "tensor_parallel_degree": 5, + } + ), + }, + }, + { + "source_code": DEFAULT_SOURCE_CODE, + "distributed": MPI( + custom_mpi_options=["-x", "VAR1", "-x", "VAR2"], + ), + "expected_template": EXECUTE_MPI_DRIVER, + "expected_hyperparameters": {}, + }, + ], + ids=[ + "torchrun", + "torchrun_smp", + "mpi", + ], +) +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +@patch("sagemaker.modules.train.model_trainer.TemporaryDirectory") +@patch("sagemaker.modules.train.model_trainer.resolve_value_from_config") +def test_train_with_distributed_config( + mock_resolve_value_from_config, + mock_tmp_dir, + mock_training_job, + test_case, + request, + modules_session, +): + mock_resolve_value_from_config.return_value = None + modules_session.upload_data.return_value = ( + f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}-job/input/test" + ) + + tmp_dir = tempfile.TemporaryDirectory() + tmp_dir._cleanup = False + tmp_dir.cleanup = lambda: None + mock_tmp_dir.return_value = tmp_dir + + expected_train_script_path = os.path.join(tmp_dir.name, TRAIN_SCRIPT) + expected_runner_json_path = os.path.join(tmp_dir.name, DISTRIBUTED_JSON) + expected_source_code_json_path = os.path.join(tmp_dir.name, SOURCE_CODE_JSON) + + try: + model_trainer = ModelTrainer( + sagemaker_session=modules_session, + training_image=DEFAULT_IMAGE, + source_code=test_case["source_code"], + distributed=test_case["distributed"], + ) + + model_trainer.train() + mock_training_job.create.assert_called_once() + assert mock_training_job.create.call_args.kwargs["hyper_parameters"] == ( + test_case["expected_hyperparameters"] + ) + + assert os.path.exists(expected_train_script_path) + with open(expected_train_script_path, "r") as f: + train_script_content = f.read() + assert test_case["expected_template"] in train_script_content + + assert os.path.exists(expected_runner_json_path) + with open(expected_runner_json_path, "r") as f: + runner_json_content = f.read() + assert test_case["distributed"].model_dump(exclude_none=True) == ( + json.loads(runner_json_content) + ) + assert os.path.exists(expected_source_code_json_path) + with open(expected_source_code_json_path, "r") as f: + source_code_json_content = f.read() + assert test_case["source_code"].model_dump(exclude_none=True) == ( + json.loads(source_code_json_content) + ) + assert os.path.exists(expected_source_code_json_path) + with 
open(expected_source_code_json_path, "r") as f: + source_code_json_content = f.read() + assert test_case["source_code"].model_dump(exclude_none=True) == ( + json.loads(source_code_json_content) + ) + finally: + shutil.rmtree(tmp_dir.name) + assert not os.path.exists(tmp_dir.name) + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_train_stores_created_training_job(mock_training_job, model_trainer): + mock_training_job.create.return_value = TrainingJob(training_job_name="Created-job") + model_trainer.train(wait=False) + assert model_trainer._latest_training_job is not None + assert model_trainer._latest_training_job == TrainingJob(training_job_name="Created-job") + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_tensorboard_output_config(mock_training_job, modules_session): + image_uri = DEFAULT_IMAGE + role = DEFAULT_ROLE + tensorboard_output_config = TensorBoardOutputConfig( + s3_output_path=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}", + local_path="/opt/ml/output/tensorboard", + ) + + model_trainer = ModelTrainer( + training_image=image_uri, + sagemaker_session=modules_session, + role=role, + ).with_tensorboard_output_config(tensorboard_output_config) + + assert model_trainer._tensorboard_output_config == tensorboard_output_config + + with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert ( + mock_training_job.create.call_args.kwargs["tensor_board_output_config"] + == tensorboard_output_config + ) + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_retry_strategy(mock_training_job, modules_session): + image_uri = DEFAULT_IMAGE + role = DEFAULT_ROLE + retry_strategy = RetryStrategy( + maximum_retry_attempts=3, + ) + + model_trainer = ModelTrainer( + training_image=image_uri, + sagemaker_session=modules_session, + role=role, + ).with_retry_strategy(retry_strategy) + + assert model_trainer._retry_strategy == retry_strategy + + with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert mock_training_job.create.call_args.kwargs["retry_strategy"] == retry_strategy + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_infra_check_config(mock_training_job, modules_session): + image_uri = DEFAULT_IMAGE + role = DEFAULT_ROLE + infra_check_config = InfraCheckConfig( + enable_infra_check=True, + ) + + model_trainer = ModelTrainer( + training_image=image_uri, + sagemaker_session=modules_session, + role=role, + ).with_infra_check_config(infra_check_config) + + assert model_trainer._infra_check_config == infra_check_config + + with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert mock_training_job.create.call_args.kwargs["infra_check_config"] == infra_check_config + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_session_chaining_config(mock_training_job, modules_session): + image_uri = DEFAULT_IMAGE + role = DEFAULT_ROLE + session_chaining_config = SessionChainingConfig( + enable_session_tag_chaining=True, + ) + + model_trainer = 
ModelTrainer( + training_image=image_uri, + sagemaker_session=modules_session, + role=role, + ).with_session_chaining_config(session_chaining_config) + + assert model_trainer._session_chaining_config == session_chaining_config + + with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert ( + mock_training_job.create.call_args.kwargs["session_chaining_config"] + == session_chaining_config + ) + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_remote_debug_config(mock_training_job, modules_session): + image_uri = DEFAULT_IMAGE + role = DEFAULT_ROLE + remote_debug_config = RemoteDebugConfig( + enable_remote_debug=True, + ) + + model_trainer = ModelTrainer( + training_image=image_uri, + sagemaker_session=modules_session, + role=role, + ).with_remote_debug_config(remote_debug_config) + + assert model_trainer._remote_debug_config == remote_debug_config + + with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert ( + mock_training_job.create.call_args.kwargs["remote_debug_config"] == remote_debug_config + ) + + +@patch("sagemaker.modules.train.model_trainer._get_unique_name") +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_model_trainer_full_init(mock_training_job, mock_unique_name, modules_session): + def mock_upload_data(path, bucket, key_prefix): + return f"s3://{bucket}/{key_prefix}" + + modules_session.upload_data.side_effect = mock_upload_data + + training_mode = Mode.SAGEMAKER_TRAINING_JOB + role = DEFAULT_ROLE + source_code = DEFAULT_SOURCE_CODE + distributed = Torchrun() + compute = Compute( + instance_type=DEFAULT_INSTANCE_TYPE, + instance_count=1, + volume_size_in_gb=30, + volume_kms_key_id="key-id", + keep_alive_period_in_seconds=3600, + enable_managed_spot_training=True, + ) + networking = Networking( + security_group_ids=["sg-000000000000"], + subnets=["subnet-000000000000"], + enable_network_isolation=True, + enable_inter_container_traffic_encryption=True, + ) + stopping_condition = DEFAULT_STOPPING_CONDITION + training_image = DEFAULT_IMAGE + training_image_config = TrainingImageConfig( + training_repository_access_mode="Platform", + training_repository_auth_config=TrainingRepositoryAuthConfig( + training_repository_credentials_provider_arn="arn:aws:lambda:us-west-2:000000000000:function:dummy-function" + ), + ) + output_data_config = DEFAULT_OUTPUT_DATA_CONFIG + + local_input_data = InputData( + channel_name="train", data_source=f"{DEFAULT_SOURCE_DIR}/data/train" + ) + s3_data_source_input = InputData( + channel_name="test", + data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}/data/test", + s3_data_distribution_type="FullyReplicated", + attribute_names=["label"], + instance_group_names=["instance-group"], + ), + ) + file_system_input = InputData( + channel_name="validation", + data_source=FileSystemDataSource( + file_system_id="fs-000000000000", + file_system_access_mode="ro", + file_system_type="EFS", + directory_path="/data/validation", + ), + ) + input_data_config = [local_input_data, s3_data_source_input, file_system_input] + checkpoint_config = CheckpointConfig( + local_path="/opt/ml/checkpoints", + 
s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}/checkpoints", + ) + training_input_mode = "File" + environment = {"ENV_VAR": "value"} + hyperparameters = {"key": "value"} + tags = [Tag(key="key", value="value")] + + model_trainer = ModelTrainer( + training_mode=training_mode, + sagemaker_session=modules_session, + role=role, + source_code=source_code, + distributed=distributed, + compute=compute, + networking=networking, + stopping_condition=stopping_condition, + training_image=training_image, + training_image_config=training_image_config, + output_data_config=output_data_config, + input_data_config=input_data_config, + checkpoint_config=checkpoint_config, + training_input_mode=training_input_mode, + environment=environment, + hyperparameters=hyperparameters, + tags=tags, + ) + + assert model_trainer.training_mode == training_mode + assert model_trainer.sagemaker_session == modules_session + assert model_trainer.role == role + assert model_trainer.source_code == source_code + assert model_trainer.distributed == distributed + assert model_trainer.compute == compute + assert model_trainer.networking == networking + assert model_trainer.stopping_condition == stopping_condition + assert model_trainer.training_image == training_image + assert model_trainer.training_image_config == training_image_config + assert model_trainer.output_data_config == output_data_config + assert model_trainer.input_data_config == input_data_config + assert model_trainer.checkpoint_config == checkpoint_config + assert model_trainer.training_input_mode == training_input_mode + assert model_trainer.environment == environment + assert model_trainer.hyperparameters == hyperparameters + assert model_trainer.tags == tags + + unique_name = "training-job" + mock_unique_name.return_value = unique_name + + model_trainer.train() + + mock_training_job.create.assert_called_once_with( + training_job_name=unique_name, + algorithm_specification=AlgorithmSpecification( + training_input_mode=training_input_mode, + training_image=training_image, + algorithm_name=None, + container_entrypoint=DEFAULT_ENTRYPOINT, + container_arguments=DEFAULT_ARGUMENTS, + training_image_config=training_image_config, + ), + hyper_parameters=hyperparameters, + input_data_config=[ + Channel( + channel_name=local_input_data.channel_name, + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/{DEFAULT_BASE_NAME}/{unique_name}/input/train", # noqa: E501 + s3_data_distribution_type="FullyReplicated", + ) + ), + input_mode="File", + ), + Channel( + channel_name=s3_data_source_input.channel_name, + data_source=DataSource(s3_data_source=s3_data_source_input.data_source), + ), + Channel( + channel_name=file_system_input.channel_name, + data_source=DataSource(file_system_data_source=file_system_input.data_source), + ), + Channel( + channel_name="code", + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/{DEFAULT_BASE_NAME}/{unique_name}/input/code", # noqa: E501 + s3_data_distribution_type="FullyReplicated", + ) + ), + input_mode="File", + ), + Channel( + channel_name="sm_drivers", + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/{DEFAULT_BASE_NAME}/{unique_name}/input/sm_drivers", # noqa: E501 + s3_data_distribution_type="FullyReplicated", + ), + ), + input_mode="File", + ), + ], + 
resource_config=ResourceConfig( + instance_type=compute.instance_type, + instance_count=compute.instance_count, + volume_size_in_gb=compute.volume_size_in_gb, + volume_kms_key_id=compute.volume_kms_key_id, + keep_alive_period_in_seconds=compute.keep_alive_period_in_seconds, + instance_groups=None, + ), + vpc_config=VpcConfig( + security_group_ids=networking.security_group_ids, + subnets=networking.subnets, + ), + session=ANY, + role_arn=role, + tags=tags, + stopping_condition=stopping_condition, + output_data_config=output_data_config, + checkpoint_config=checkpoint_config, + environment=environment, + enable_managed_spot_training=compute.enable_managed_spot_training, + enable_inter_container_traffic_encryption=( + networking.enable_inter_container_traffic_encryption + ), + enable_network_isolation=networking.enable_network_isolation, + remote_debug_config=None, + tensor_board_output_config=None, + retry_strategy=None, + infra_check_config=None, + session_chaining_config=None, + ) + + +def test_model_trainer_gpu_recipe_full_init(modules_session): + training_recipe = "training/llama/p4_hf_llama3_70b_seq8k_gpu" + recipe_overrides = {"run": {"results_dir": "/opt/ml/model"}} + compute = Compute(instance_type="ml.p4d.24xlarge", instance_count="2") + + gpu_image_cfg = _load_recipes_cfg().get("gpu_image") + if isinstance(gpu_image_cfg, str): + expected_training_image = gpu_image_cfg + else: + expected_training_image = image_uris.retrieve( + gpu_image_cfg.get("framework"), + region=modules_session.boto_region_name, + version=gpu_image_cfg.get("version"), + image_scope="training", + **gpu_image_cfg.get("additional_args"), + ) + + expected_distributed = Torchrun(smp=SMP(random_seed=123456)) + expected_hyperparameters = {"config-path": ".", "config-name": "recipe.yaml"} + + networking = Networking( + security_group_ids=["sg-000000000000"], + subnets=["subnet-000000000000"], + enable_network_isolation=True, + enable_inter_container_traffic_encryption=True, + ) + stopping_condition = DEFAULT_STOPPING_CONDITION + output_data_config = DEFAULT_OUTPUT_DATA_CONFIG + local_input_data = InputData( + channel_name="train", data_source=f"{DEFAULT_SOURCE_DIR}/data/train" + ) + input_data_config = [local_input_data] + checkpoint_config = CheckpointConfig( + local_path="/opt/ml/checkpoints", + s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}/checkpoints", + ) + training_input_mode = "File" + environment = {"ENV_VAR": "value"} + tags = [Tag(key="key", value="value")] + requirements = f"{DEFAULT_SOURCE_DIR}/requirements.txt" + + model_trainer = ModelTrainer.from_recipe( + training_recipe=training_recipe, + recipe_overrides=recipe_overrides, + compute=compute, + networking=networking, + stopping_condition=stopping_condition, + requirements=requirements, + output_data_config=output_data_config, + input_data_config=input_data_config, + checkpoint_config=checkpoint_config, + training_input_mode=training_input_mode, + environment=environment, + tags=tags, + sagemaker_session=modules_session, + role=DEFAULT_ROLE, + base_job_name=DEFAULT_BASE_NAME, + ) + + assert model_trainer.training_image == expected_training_image + assert model_trainer.distributed == expected_distributed + assert model_trainer.hyperparameters == expected_hyperparameters + assert model_trainer.source_code is not None + assert model_trainer.source_code.requirements == "requirements.txt" + + assert model_trainer.compute == compute + assert model_trainer.networking == networking + assert model_trainer.stopping_condition == stopping_condition + assert 
model_trainer.output_data_config == output_data_config + assert model_trainer.input_data_config == input_data_config + assert model_trainer.checkpoint_config == checkpoint_config + assert model_trainer.training_input_mode == training_input_mode + assert model_trainer.environment == environment + assert model_trainer.tags == tags + + +@patch("sagemaker.modules.train.model_trainer._LocalContainer") +@patch("sagemaker.modules.train.model_trainer._get_unique_name") +@patch("sagemaker.modules.local_core.local_container.download_folder") +def test_model_trainer_local_full_init( + mock_download_folder, mock_unique_name, mock_local_container, modules_session +): + def mock_upload_data(path, bucket, key_prefix): + return f"s3://{bucket}/{key_prefix}" + + modules_session.upload_data.side_effect = mock_upload_data + mock_download_folder.return_value = f"{DEFAULT_SOURCE_DIR}/data/test" + mock_local_container.train.return_value = None + + training_mode = Mode.LOCAL_CONTAINER + role = DEFAULT_ROLE + source_code = DEFAULT_SOURCE_CODE + distributed = Torchrun() + compute = Compute( + instance_type=DEFAULT_INSTANCE_TYPE, + instance_count=1, + volume_size_in_gb=30, + volume_kms_key_id="key-id", + keep_alive_period_in_seconds=3600, + enable_managed_spot_training=True, + ) + networking = Networking( + security_group_ids=["sg-000000000000"], + subnets=["subnet-000000000000"], + enable_network_isolation=True, + enable_inter_container_traffic_encryption=True, + ) + stopping_condition = DEFAULT_STOPPING_CONDITION + training_image = DEFAULT_IMAGE + training_image_config = TrainingImageConfig( + training_repository_access_mode="Platform", + training_repository_auth_config=TrainingRepositoryAuthConfig( + training_repository_credentials_provider_arn="arn:aws:lambda:us-west-2:000000000000:function:dummy-function" + ), + ) + output_data_config = DEFAULT_OUTPUT_DATA_CONFIG + + local_input_data = InputData( + channel_name="train", data_source=f"{DEFAULT_SOURCE_DIR}/data/train" + ) + s3_data_source_input = InputData( + channel_name="test", + data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}/data/test", + s3_data_distribution_type="FullyReplicated", + attribute_names=["label"], + instance_group_names=["instance-group"], + ), + ) + file_system_input = InputData( + channel_name="validation", + data_source=FileSystemDataSource( + file_system_id="fs-000000000000", + file_system_access_mode="ro", + file_system_type="EFS", + directory_path="/data/validation", + ), + ) + input_data_config = [local_input_data, s3_data_source_input, file_system_input] + checkpoint_config = CheckpointConfig( + local_path="/opt/ml/checkpoints", + s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}/checkpoints", + ) + training_input_mode = "File" + environment = {"ENV_VAR": "value"} + hyperparameters = {"key": "value"} + tags = [Tag(key="key", value="value")] + + local_container_root = os.getcwd() + + model_trainer = ModelTrainer( + training_mode=training_mode, + sagemaker_session=modules_session, + role=role, + source_code=source_code, + distributed=distributed, + compute=compute, + networking=networking, + stopping_condition=stopping_condition, + training_image=training_image, + training_image_config=training_image_config, + output_data_config=output_data_config, + input_data_config=input_data_config, + checkpoint_config=checkpoint_config, + training_input_mode=training_input_mode, + environment=environment, + hyperparameters=hyperparameters, + tags=tags, + 
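# LOCAL_CONTAINER mode runs the job in a container on this machine, rooted at local_container_root + 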
local_container_root=local_container_root, + ) + + assert model_trainer.training_mode == training_mode + assert model_trainer.sagemaker_session == modules_session + assert model_trainer.role == role + assert model_trainer.source_code == source_code + assert model_trainer.distributed == distributed + assert model_trainer.compute == compute + assert model_trainer.networking == networking + assert model_trainer.stopping_condition == stopping_condition + assert model_trainer.training_image == training_image + assert model_trainer.training_image_config == training_image_config + assert model_trainer.output_data_config == output_data_config + assert model_trainer.input_data_config == input_data_config + assert model_trainer.checkpoint_config == checkpoint_config + assert model_trainer.training_input_mode == training_input_mode + assert model_trainer.environment == environment + assert model_trainer.hyperparameters == hyperparameters + assert model_trainer.tags == tags + + unique_name = "training-job" + mock_unique_name.return_value = unique_name + + model_trainer.train() + + # verify the local container ran with the expected arguments + mock_local_container.train.assert_called_once_with( + training_job_name=unique_name, + instance_type=compute.instance_type, + instance_count=compute.instance_count, + image=training_image, + container_root=local_container_root, + sagemaker_session=modules_session, + container_entry_point=DEFAULT_ENTRYPOINT, + container_arguments=DEFAULT_ARGUMENTS, + hyper_parameters=hyperparameters, + environment=environment, + ) diff --git a/tests/unit/sagemaker/partner_app/__init__.py b/tests/unit/sagemaker/partner_app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/partner_app/test_auth_provider.py b/tests/unit/sagemaker/partner_app/test_auth_provider.py new file mode 100644 index 0000000000..c5a27cff3a --- /dev/null +++ b/tests/unit/sagemaker/partner_app/test_auth_provider.py @@ -0,0 +1,152 @@ +from __future__ import absolute_import + +import os +import unittest +from unittest.mock import patch, MagicMock +from requests import PreparedRequest +from sagemaker.partner_app.auth_provider import RequestsAuth, PartnerAppAuthProvider + + +class TestRequestsAuth(unittest.TestCase): + + @patch("sagemaker.partner_app.auth_provider.PartnerAppAuthUtils.get_signed_request") + @patch("sagemaker.partner_app.auth_provider.SigV4Auth") + def test_requests_auth_call(self, mock_sigv4_auth, mock_get_signed_request): + # Prepare mock data + mock_signed_url = "https://returned-url.test.com/" + mock_signed_headers = {"Authorization": "SigV4", "x-amz-date": "20241016T120000Z"} + mock_get_signed_request.return_value = (mock_signed_url, mock_signed_headers) + + # Create the objects needed for testing + app_arn = "arn:aws:lambda:us-west-2:123456789012:sagemaker:test" + under_test = RequestsAuth(sigv4=mock_sigv4_auth, app_arn=app_arn) + + # Create a prepared request object to simulate an actual request + request = PreparedRequest() + request.method = "GET" + request_url = "https://test.com" + request.url = request_url + request_headers = {} + request.headers = request_headers + request.body = "{}" + + # Call the method under test + updated_request = under_test(request) + + # Assertions to verify the behavior + mock_get_signed_request.assert_called_once_with( + sigv4=mock_sigv4_auth, + app_arn=app_arn, + url=request_url, + method="GET", + headers=request_headers, + body=request.body, + ) + + self.assertEqual(updated_request.url, mock_signed_url) + self.assertIn("Authorization", updated_request.headers) + 
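# end-to-end sketch (hypothetical usage, not part of this test): requests.get(url, auth=PartnerAppAuthProvider().get_auth()) would attach the same signed URL and headers + 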
self.assertIn("x-amz-date", updated_request.headers) + self.assertEqual(updated_request.headers["Authorization"], "SigV4") + self.assertEqual(updated_request.headers["x-amz-date"], "20241016T120000Z") + + +class TestPartnerAppAuthProvider(unittest.TestCase): + + @patch("sagemaker.partner_app.auth_provider.boto3.Session") + @patch("sagemaker.partner_app.auth_provider.SigV4Auth") + @patch("sagemaker.partner_app.auth_provider.PartnerAppAuthUtils.get_signed_request") + def test_get_signed_request( + self, mock_get_signed_request, mock_sigv4auth_class, mock_boto3_session + ): + # Set up environment variable + test_app_arn = "arn:aws-us-gov:sagemaker:us-west-2:123456789012:partner-app/my-app" + os.environ["AWS_PARTNER_APP_ARN"] = test_app_arn + + # Mock the return value of boto3.Session().get_credentials() + mock_credentials = MagicMock() + mock_boto3_session.return_value.get_credentials.return_value = mock_credentials + + # Mock the SigV4Auth instance + mock_sigv4auth_instance = MagicMock() + mock_sigv4auth_class.return_value = mock_sigv4auth_instance + + # Initialize the PartnerAppAuthProvider class + provider = PartnerAppAuthProvider() + + # Mock return value for get_signed_request + mock_get_signed_request.return_value = { + "Authorization": "SigV4", + "x-amz-date": "20241016T120000Z", + } + + # Call get_signed_request method + signed_request = provider.get_signed_request( + url="https://example.com", + method="GET", + headers={"Content-Type": "application/json"}, + body=None, + ) + + # Assert that the get_signed_request method was called with correct parameters + mock_get_signed_request.assert_called_once_with( + sigv4=mock_sigv4auth_instance, + app_arn=test_app_arn, + url="https://example.com", + method="GET", + headers={"Content-Type": "application/json"}, + body=None, + ) + + # Assert the response matches the mocked return value + self.assertEqual(signed_request["Authorization"], "SigV4") + self.assertEqual(signed_request["x-amz-date"], "20241016T120000Z") + + @patch("sagemaker.partner_app.auth_provider.SigV4Auth") + def test_get_auth(self, mock_sigv4auth_class): + # Set up environment variable + os.environ["AWS_PARTNER_APP_ARN"] = ( + "arn:aws:sagemaker:us-west-2:123456789012:partner-app/app-abc" + ) + + # Mock the SigV4Auth instance + mock_sigv4auth_instance = MagicMock() + mock_sigv4auth_class.return_value = mock_sigv4auth_instance + + # Initialize the PartnerAppAuthProvider class + provider = PartnerAppAuthProvider() + + # Call get_auth method + auth_instance = provider.get_auth() + + # Assert that the returned object is a RequestsAuth instance + self.assertIsInstance(auth_instance, RequestsAuth) + + # Assert that RequestsAuth was initialized with correct arguments + self.assertEqual(auth_instance.sigv4, mock_sigv4auth_instance) + self.assertEqual(auth_instance.app_arn, os.environ["AWS_PARTNER_APP_ARN"]) + + def test_init_raises_value_error_with_missing_app_arn(self): + # Remove the environment variable + if "AWS_PARTNER_APP_ARN" in os.environ: + del os.environ["AWS_PARTNER_APP_ARN"] + + # Ensure ValueError is raised when AWS_PARTNER_APP_ARN is not set + with self.assertRaises(ValueError) as context: + PartnerAppAuthProvider() + + self.assertIn( + "Must specify the AWS_PARTNER_APP_ARN environment variable", str(context.exception) + ) + + def test_init_raises_value_error_with_invalid_app_arn(self): + os.environ["AWS_PARTNER_APP_ARN"] = ( + "arn:aws:lambda:us-west-2:123456789012:function:my-function" + ) + + # Ensure ValueError is raised when AWS_PARTNER_APP_ARN is not set + with 
self.assertRaises(ValueError) as context: + PartnerAppAuthProvider() + + self.assertIn( + "Must specify a valid AWS_PARTNER_APP_ARN environment variable", str(context.exception) + ) diff --git a/tests/unit/sagemaker/partner_app/test_auth_utils.py b/tests/unit/sagemaker/partner_app/test_auth_utils.py new file mode 100644 index 0000000000..b75dc9a30e --- /dev/null +++ b/tests/unit/sagemaker/partner_app/test_auth_utils.py @@ -0,0 +1,111 @@ +from __future__ import absolute_import + +import unittest +from unittest.mock import Mock, patch +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest +from hashlib import sha256 + +from sagemaker.partner_app.auth_utils import ( + PartnerAppAuthUtils, + EMPTY_SHA256_HASH, + UNSIGNED_PAYLOAD, +) + + +class TestPartnerAppAuthUtils(unittest.TestCase): + def setUp(self): + self.sigv4_mock = Mock(spec=SigV4Auth) + self.app_arn = "arn:aws:sagemaker:us-west-2:123456789012:partner-app/abc123" + self.url = "https://partner-app-abc123.us-west-2.amazonaws.com?fileName=Jupyter+interactive" + self.method = "POST" + self.headers = {"Authorization": "API_KEY", "Connection": "conn"} + self.body = b'{"key": "value"}' # Byte type body for hashing + + @patch("sagemaker.partner_app.auth_utils.AWSRequest") + def test_get_signed_request_with_body(self, AWSRequestMock): + aws_request_mock = Mock(spec=AWSRequest) + AWSRequestMock.return_value = aws_request_mock + + expected_hash = sha256(self.body).hexdigest() + # Authorization still has the original value as the sigv4 mock does not add this header + expected_sign_headers = { + "Authorization": "API_KEY", + "X-Amz-Partner-App-Authorization": "API_KEY", + "X-SageMaker-Partner-App-Server-Arn": self.app_arn, + "X-Amz-Target": "SageMaker.CallPartnerAppApi", + "X-Amz-Content-SHA256": expected_hash, + } + aws_request_mock.headers = expected_sign_headers + + # Mock the add_auth method on the SigV4Auth + self.sigv4_mock.add_auth = Mock() + + url, signed_headers = PartnerAppAuthUtils.get_signed_request( + self.sigv4_mock, self.app_arn, self.url, self.method, self.headers, self.body + ) + + # Assert X-SageMaker-Partner-App-Server-Arn header is correct + self.assertEqual(signed_headers["X-SageMaker-Partner-App-Server-Arn"], self.app_arn) + + # Assert the Authorization header was moved to X-Amz-Partner-App-Authorization + self.assertIn("X-Amz-Partner-App-Authorization", signed_headers) + + # Assert X-Amz-Content-SHA256 is set + self.assertEqual(signed_headers["X-Amz-Content-SHA256"], expected_hash) + + # Assert the Connection header is preserved + self.assertEqual(signed_headers["Connection"], "conn") + + expected_canonical_url = self.url.replace("+", "%20") + # Assert AWSRequestMock was called + AWSRequestMock.assert_called_once_with( + method=self.method, + url=expected_canonical_url, + headers=expected_sign_headers, + data=self.body, + ) + + def test_get_signed_request_with_no_body(self): + body = None + url, signed_headers = PartnerAppAuthUtils.get_signed_request( + self.sigv4_mock, self.app_arn, self.url, self.method, self.headers, body + ) + + # Assert X-Amz-Content-SHA256 is EMPTY_SHA256_HASH + self.assertEqual(signed_headers["X-Amz-Content-SHA256"], EMPTY_SHA256_HASH) + + def test_get_signed_request_with_file_like_body(self): + # the body here is a seekable file-like mock, read in chunks for hashing + body = Mock() + body.seek = Mock() + body.tell = Mock(return_value=0) + body.read = Mock(side_effect=[b"test", b""]) + + url, signed_headers = PartnerAppAuthUtils.get_signed_request( + self.sigv4_mock, self.app_arn, self.url, self.method, self.headers, body + ) + + # Verify the seek 
method was called + body.seek.assert_called() + + # Calculate the expected checksum for the body + checksum = sha256(b"test").hexdigest() + + # Assert X-Amz-Content-SHA256 is the calculated checksum + self.assertEqual(signed_headers["X-Amz-Content-SHA256"], checksum) + + def test_get_body_header_unsigned_payload(self): + body = {"key": "value"} + + result = PartnerAppAuthUtils.get_body_header(body) + + # Assert the result is UNSIGNED_PAYLOAD for unrecognized body type + self.assertEqual(result, UNSIGNED_PAYLOAD) + + def test_get_body_header_empty_body(self): + body = None + + result = PartnerAppAuthUtils.get_body_header(body) + + # Assert the result is EMPTY_SHA256_HASH for empty body + self.assertEqual(result, EMPTY_SHA256_HASH) diff --git a/tests/unit/sagemaker/remote_function/core/test_serialization.py b/tests/unit/sagemaker/remote_function/core/test_serialization.py index a0742240ea..e87dc39b59 100644 --- a/tests/unit/sagemaker/remote_function/core/test_serialization.py +++ b/tests/unit/sagemaker/remote_function/core/test_serialization.py @@ -99,6 +99,7 @@ def test_serialize_deserialize_lambda(): assert deserialized(3) == 9 +@pytest.mark.flaky(reruns=3, reruns_delay=5) @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) @patch("sagemaker.s3.S3Downloader.read_bytes", new=read) @patch("sagemaker.experiments.run.Experiment") @@ -106,9 +107,11 @@ def test_serialize_deserialize_lambda(): @patch("sagemaker.experiments.run._TrialComponent._load_or_create", return_value=(Mock(), False)) @patch("sagemaker.experiments.run._MetricsManager") @patch("sagemaker.remote_function.job.Session") -def test_serialize_func_referencing_to_run(*args, **kwargs): +def test_serialize_func_referencing_to_run(sagemaker_session, *args, **kwargs): - with Run(experiment_name="exp_name", run_name="run_name") as run: + with Run( + sagemaker_session=sagemaker_session, experiment_name="exp_name", run_name="run_name" + ) as run: def train(x): return run.log_metric() @@ -302,6 +305,7 @@ def test_serialize_deserialize_none(): assert deserialized is None +@pytest.mark.flaky(reruns=3, reruns_delay=5) @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) @patch("sagemaker.s3.S3Downloader.read_bytes", new=read) @patch("sagemaker.experiments.run.Experiment") @@ -309,8 +313,10 @@ def test_serialize_deserialize_none(): @patch("sagemaker.experiments.run._TrialComponent._load_or_create", return_value=(Mock(), False)) @patch("sagemaker.experiments.run._MetricsManager") @patch("sagemaker.remote_function.job.Session") -def test_serialize_run(*args, **kwargs): - with Run(experiment_name="exp_name", run_name="run_name") as run: +def test_serialize_run(sagemaker_session, *args, **kwargs): + with Run( + sagemaker_session=sagemaker_session, experiment_name="exp_name", run_name="run_name" + ) as run: s3_uri = random_s3_uri() with pytest.raises( SerializationError, diff --git a/tests/unit/sagemaker/serve/builder/test_djl_builder.py b/tests/unit/sagemaker/serve/builder/test_djl_builder.py index 69f8c7a8d5..9c4488fa3e 100644 --- a/tests/unit/sagemaker/serve/builder/test_djl_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_djl_builder.py @@ -24,7 +24,7 @@ LocalModelOutOfMemoryException, LocalModelInvocationException, ) -from sagemaker.serve.utils.predictors import DjlLocalModePredictor +from sagemaker.serve.utils.predictors import DjlLocalModePredictor, InProcessModePredictor from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG mock_model_id = "TheBloke/Llama-2-7b-chat-fp16" @@ -46,6 +46,8 @@ 
"OPTION_DTYPE": "bf16", "MODEL_LOADING_TIMEOUT": "1800", } +mock_inference_spec = MagicMock() +mock_inference_spec.get_model = "TheBloke/Llama-2-7b-chat-fp16" mock_schema_builder = MagicMock() mock_schema_builder.sample_input = mock_sample_input @@ -115,6 +117,59 @@ def test_build_deploy_for_djl_local_container( with self.assertRaises(ValueError) as _: model.deploy(mode=Mode.IN_PROCESS) + @patch("sagemaker.serve.builder.djl_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=False, + ) + @patch("sagemaker.serve.builder.djl_builder._get_ram_usage_mb", return_value=1024) + @patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") + @patch( + "sagemaker.serve.builder.djl_builder._get_default_djl_configurations", + return_value=(mock_default_configs, 128), + ) + def test_build_deploy_for_djl_in_process( + self, + mock_default_djl_config, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model=mock_model_id, + name="mock_model_name", + schema_builder=mock_schema_builder, + mode=Mode.IN_PROCESS, + model_server=ModelServer.DJL_SERVING, + image_config=MOCK_IMAGE_CONFIG, + vpc_config=MOCK_VPC_CONFIG, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + + model = builder.build() + assert model.name == "mock_model_name" + + builder.serve_settings.telemetry_opt_out = True + + assert isinstance(model, DJLModel) + assert builder.schema_builder.sample_input["parameters"]["max_new_tokens"] == 128 + assert builder.nb_instance_type == "ml.g5.24xlarge" + assert model.image_config == MOCK_IMAGE_CONFIG + assert model.vpc_config == MOCK_VPC_CONFIG + assert "lmi" in builder.image_uri + + builder.modes[str(Mode.IN_PROCESS)] = MagicMock() + predictor = model.deploy(model_data_download_timeout=1800) + + assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800" + assert isinstance(predictor, InProcessModePredictor) + + builder._original_deploy = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + @patch("sagemaker.serve.builder.djl_builder._capture_telemetry", side_effect=None) @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 4e34c5f864..7355fe4f38 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -18,7 +18,19 @@ from pathlib import Path from copy import deepcopy +import deepdiff +import pytest +from sagemaker.enums import EndpointType + +from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig +from sagemaker.batch_inference.batch_transform_inference_config import BatchTransformInferenceConfig + +from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements + +from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig + from sagemaker.model import Model + from sagemaker.serve import SchemaBuilder from sagemaker.serve.builder.model_builder import ModelBuilder from sagemaker.serve.mode.function_pointers import Mode @@ -66,20 +78,17 @@ mock_session = MagicMock() +RESOURCE_REQUIREMENTS = ResourceRequirements( + requests={ + "num_cpus": 0.5, + "memory": 512, + "copies": 2, + }, + limits={}, +) -class 
TestModelBuilder(unittest.TestCase): - @patch("sagemaker.serve.builder.model_builder._ServeSettings") - def test_validation_in_progress_mode_supported(self, mock_serveSettings): - builder = ModelBuilder(model_server=ModelServer.TORCHSERVE) - self.assertRaisesRegex( - Exception, - "IN_PROCESS mode is only supported for MMS/Transformers server in beta release.", - builder.build, - Mode.IN_PROCESS, - mock_role_arn, - mock_session, - ) +class TestModelBuilder(unittest.TestCase): @patch("sagemaker.serve.builder.model_builder._ServeSettings") def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSettings): builder = ModelBuilder(inference_spec="some value", model=Mock(spec=object)) @@ -2425,7 +2434,7 @@ def test_optimize( self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token") self.assertEqual(builder.model_server, ModelServer.DJL_SERVING) - mock_send_telemetry.assert_called_once() + assert mock_send_telemetry.call_count == 2 mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( OptimizationJobName="my-optimization-job", DeploymentInstanceType="ml.g5.24xlarge", @@ -2881,6 +2890,366 @@ def test_optimize_for_hf_without_custom_s3_path( }, ) + + def test_deploy_invalid_inputs(self): + model_builder = ModelBuilder( + model="meta-llama/Meta-Llama-3-8B-Instruct", + env_vars={"HUGGING_FACE_HUB_TOKEN": "token"}, + role_arn="role-arn", + instance_type="ml.g5.2xlarge", + ) + inputs = {"endpoint_name": "endpoint-001"} + + # deploy() before build() must raise; assertRaises also fails the test if nothing is raised + with self.assertRaises(ValueError) as context: + model_builder.deploy(**inputs) + assert "Model Needs to be built before deploying" in str(context.exception) + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + def test_display_benchmark_metrics_non_string_model(self, mock_is_jumpstart): + """Test that ValueError is raised when model is not a string""" + builder = ModelBuilder(model=Mock()) # Non-string model + + self.assertRaisesRegex( + ValueError, + "Benchmarking is only supported for JumpStart or HuggingFace models", + builder.display_benchmark_metrics, + ) + mock_is_jumpstart.assert_not_called() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart.display_benchmark_metrics") + def test_display_benchmark_metrics_jumpstart_model( + self, mock_display_benchmark_metrics, mock_is_jumpstart + ): + """Test successful execution for jumpstart model""" + mock_is_jumpstart.return_value = True + + builder = ModelBuilder(model="jumpstart-model-id") + builder.display_benchmark_metrics() + + mock_is_jumpstart.assert_called_once() + mock_display_benchmark_metrics.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._use_jumpstart_equivalent") + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart.display_benchmark_metrics") + def test_display_benchmark_metrics_with_jumpstart_equivalent( + self, mock_display_benchmark_metrics, mock_has_equivalent, mock_is_jumpstart + ): + """Test successful execution for model with jumpstart equivalent""" + mock_is_jumpstart.return_value = False + mock_has_equivalent.return_value = True + + builder = ModelBuilder(model="hf-model-id") + builder.display_benchmark_metrics() + + mock_is_jumpstart.assert_called_once() + mock_has_equivalent.assert_called_once() + mock_display_benchmark_metrics.assert_called_once() + + 
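# dispatch shared by these helpers: JumpStart model id -> JumpStart path; HF id with a JumpStart equivalent -> same path; otherwise ValueError + 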
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._use_jumpstart_equivalent") + def test_display_benchmark_metrics_unsupported_model( + self, mock_has_equivalent, mock_is_jumpstart + ): + """Test that ValueError is raised for unsupported models""" + mock_is_jumpstart.return_value = False + mock_has_equivalent.return_value = False + + builder = ModelBuilder(model="huggingface-model-id") + + self.assertRaisesRegex( + ValueError, + "This model does not have benchmark metrics yet", + builder.display_benchmark_metrics, + ) + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + def test_get_deployment_config_non_string_model(self, mock_is_jumpstart): + """Test that ValueError is raised when model is not a string""" + builder = ModelBuilder(model=Mock()) # Non-string model + + self.assertRaisesRegex( + ValueError, + "Deployment config is only supported for JumpStart or HuggingFace models", + builder.get_deployment_config, + ) + mock_is_jumpstart.assert_not_called() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart.get_deployment_config") + def test_get_deployment_config_jumpstart_model( + self, mock_get_deployment_config, mock_is_jumpstart + ): + """Test successful execution for jumpstart model""" + mock_is_jumpstart.return_value = True + + builder = ModelBuilder(model="jumpstart-model-id") + builder.get_deployment_config() + + mock_is_jumpstart.assert_called_once() + mock_get_deployment_config.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._use_jumpstart_equivalent") + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart.get_deployment_config") + def test_get_deployment_config_with_jumpstart_equivalent( + self, mock_get_deployment_config, mock_has_equivalent, mock_is_jumpstart + ): + """Test successful execution for model with jumpstart equivalent""" + mock_is_jumpstart.return_value = False + mock_has_equivalent.return_value = True + + builder = ModelBuilder(model="hf-model-id") + builder.get_deployment_config() + + mock_is_jumpstart.assert_called_once() + mock_has_equivalent.assert_called_once() + mock_get_deployment_config.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._use_jumpstart_equivalent") + def test_get_deployment_config_unsupported_model(self, mock_has_equivalent, mock_is_jumpstart): + """Test that ValueError is raised for unsupported models""" + mock_is_jumpstart.return_value = False + mock_has_equivalent.return_value = False + + builder = ModelBuilder(model="huggingface-model-id") + + self.assertRaisesRegex( + ValueError, + "This model does not have any deployment config yet", + builder.get_deployment_config, + ) + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + def test_list_deployment_configs_non_string_model(self, mock_is_jumpstart): + """Test that ValueError is raised when model is not a string""" + builder = ModelBuilder(model=Mock()) # Non-string model + + self.assertRaisesRegex( + ValueError, + "Deployment config is only supported for JumpStart or HuggingFace models", + builder.list_deployment_configs, + ) + mock_is_jumpstart.assert_not_called() + + 
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart.list_deployment_configs") + def test_list_deployment_configs_jumpstart_model( + self, mock_list_deployment_configs, mock_is_jumpstart + ): + """Test successful execution for jumpstart model""" + mock_is_jumpstart.return_value = True + + builder = ModelBuilder(model="jumpstart-model-id") + builder.list_deployment_configs() + + mock_is_jumpstart.assert_called_once() + mock_list_deployment_configs.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._use_jumpstart_equivalent") + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart.list_deployment_configs") + def test_list_deployment_configs_with_jumpstart_equivalent( + self, mock_list_deployment_configs, mock_has_equivalent, mock_is_jumpstart + ): + """Test successful execution for model with jumpstart equivalent""" + mock_is_jumpstart.return_value = False + mock_has_equivalent.return_value = True + + builder = ModelBuilder(model="hf-model-id") + builder.list_deployment_configs() + + mock_is_jumpstart.assert_called_once() + mock_has_equivalent.assert_called_once() + mock_list_deployment_configs.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._use_jumpstart_equivalent") + def test_list_deployment_configs_unsupported_model( + self, mock_has_equivalent, mock_is_jumpstart + ): + """Test that ValueError is raised for unsupported models""" + mock_is_jumpstart.return_value = False + mock_has_equivalent.return_value = False + + builder = ModelBuilder(model="huggingface-model-id") + + self.assertRaisesRegex( + ValueError, + "This model does not have any deployment config yet", + builder.list_deployment_configs, + ) + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + def test_set_deployment_config_non_string_model(self, mock_is_jumpstart): + """Test that ValueError is raised when model is not a string""" + builder = ModelBuilder(model=Mock()) # Non-string model + instance_type = "ml.g5.xlarge" + config_name = "config-name" + self.assertRaisesRegex( + ValueError, + "Deployment config is only supported for JumpStart or HuggingFace models", + builder.set_deployment_config, + config_name, + instance_type, + ) + mock_is_jumpstart.assert_not_called() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart.set_deployment_config") + def test_set_deployment_config_jumpstart_model( + self, mock_set_deployment_config, mock_is_jumpstart + ): + """Test successful execution for jumpstart model""" + mock_is_jumpstart.return_value = True + instance_type = "ml.g5.xlarge" + config_name = "config-name" + + builder = ModelBuilder(model="jumpstart-model-id") + builder.set_deployment_config(config_name, instance_type) + + mock_is_jumpstart.assert_called_once() + mock_set_deployment_config.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._use_jumpstart_equivalent") + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart.set_deployment_config") + def test_set_deployment_config_with_jumpstart_equivalent( + self, 
mock_set_deployment_config, mock_has_equivalent, mock_is_jumpstart + ): + """Test successful execution for model with jumpstart equivalent""" + mock_is_jumpstart.return_value = False + mock_has_equivalent.return_value = True + instance_type = "ml.g5.xlarge" + config_name = "config-name" + + builder = ModelBuilder(model="hf-model-id") + builder.set_deployment_config(config_name, instance_type) + + mock_is_jumpstart.assert_called_once() + mock_has_equivalent.assert_called_once() + mock_set_deployment_config.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._use_jumpstart_equivalent") + def test_set_deployment_config_unsupported_model(self, mock_has_equivalent, mock_is_jumpstart): + """Test that ValueError is raised for unsupported models""" + mock_is_jumpstart.return_value = False + mock_has_equivalent.return_value = False + instance_type = "ml.g5.xlarge" + config_name = "config-name" + + builder = ModelBuilder(model="huggingface-model-id") + + self.assertRaisesRegex( + ValueError, + f"The deployment config {config_name} cannot be set on this model", + builder.set_deployment_config, + config_name, + instance_type, + ) + + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._retrieve_hugging_face_model_mapping" + ) + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_gated_model") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_jumpstart") + def test_use_jumpstart_equivalent_return_true( + self, mock_build_for_jumpstart, mock_is_gated_model, mock_retrieve_mapping + ): + """Test that _use_jumpstart_equivalent returns True when equivalent exists""" + mock_retrieve_mapping.return_value = { + "HuggingFaceH4/zephyr-7b-beta": { + "jumpstart-model-id": "js-model", + "jumpstart-model-version": "1.0.0", + "hf-model-repo-sha": None, + } + } + mock_is_gated_model.return_value = False + + builder = ModelBuilder(model="HuggingFaceH4/zephyr-7b-beta") + + self.assertTrue(builder._use_jumpstart_equivalent()) + + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._retrieve_hugging_face_model_mapping" + ) + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_gated_model") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_jumpstart") + def test_use_jumpstart_equivalent_return_true_with_schema_builder( + self, mock_build_for_jumpstart, mock_is_gated_model, mock_retrieve_mapping + ): + """Test that _use_jumpstart_equivalent returns True when equivalent exists""" + mock_retrieve_mapping.return_value = { + "HuggingFaceH4/zephyr-7b-beta": { + "jumpstart-model-id": "js-model", + "jumpstart-model-version": "1.0.0", + "hf-model-repo-sha": None, + } + } + mock_is_gated_model.return_value = False + + builder = ModelBuilder(model="HuggingFaceH4/zephyr-7b-beta", sagemaker_session=mock_session) + + self.assertTrue(builder._use_jumpstart_equivalent()) + self.assertIsNotNone(builder.schema_builder) + inputs, outputs = task.retrieve_local_schemas("text-generation") + self.assertEqual(builder.schema_builder.sample_input["inputs"], inputs["inputs"]) + self.assertEqual(builder.schema_builder.sample_output, outputs) + + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._retrieve_hugging_face_model_mapping" + ) + def test_use_jumpstart_equivalent_return_false(self, mock_retrieve_mapping): + """Test that _use_jumpstart_equivalent returns false when equivalent doesn't exist""" + mock_retrieve_mapping.return_value = 
{ + "hf-model-id": { + "jumpstart-model-id": "js-model", + "jumpstart-model-version": "1.0.0", + "hf-model-repo-sha": None, + } + } + + builder = ModelBuilder(model="model-id") + + self.assertFalse(builder._use_jumpstart_equivalent()) + + def test_use_jumpstart_equivalent_return_false_with_env_vars(self): + """Test that _use_jumpstart_equivalent returns false when env_vars is provided""" + builder = ModelBuilder(model="model-id", env_vars={"mock-key": "mock-value"}) + + self.assertFalse(builder._use_jumpstart_equivalent()) + + def test_use_jumpstart_equivalent_return_false_with_image_uri(self): + """Test that _use_jumpstart_equivalent returns false when image_uri is provided""" + builder = ModelBuilder(model="model-id", image_uri="mock-uri") + + self.assertFalse(builder._use_jumpstart_equivalent()) + + @patch("sagemaker.serve.builder.model_builder.JumpStartS3PayloadAccessor") + @patch("sagemaker.serve.builder.model_builder.get_jumpstart_content_bucket") + def test_retrieve_hugging_face_model_mapping(self, mock_content_bucket, mock_payload_accessor): + """Test that _retrieve_hugging_face_model_mapping returns the correct mapping""" + mock_get_object = Mock() + mock_get_object.return_value = ( + '{"js-model-id": {"hf-model-id": "hf-model", "jumpstart-model-version": "1.0.0"}}' + ) + mock_payload_accessor.get_object_cached = mock_get_object + expected_mapping = { + "hf-model": { + "jumpstart-model-id": "js-model-id", + "jumpstart-model-version": "1.0.0", + "hf-model-repo-sha": None, + "merged-at": None, + } + } + + builder = ModelBuilder(model="hf-model", sagemaker_session=mock_session) + + self.assertEqual(builder._retrieve_hugging_face_model_mapping(), expected_mapping) + @patch.object(ModelBuilder, "_prepare_for_mode") @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation( @@ -3076,7 +3445,7 @@ def test_optimize_sharding_with_env_vars( self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token") self.assertEqual(builder.model_server, ModelServer.DJL_SERVING) - mock_send_telemetry.assert_called_once() + assert mock_send_telemetry.call_count == 2 mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( OptimizationJobName="my-optimization-job", DeploymentInstanceType="ml.g5.24xlarge", @@ -3184,7 +3553,7 @@ def test_optimize_sharding_with_override_and_env_var( self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token") self.assertEqual(builder.model_server, ModelServer.DJL_SERVING) - mock_send_telemetry.assert_called_once() + assert mock_send_telemetry.call_count == 2 mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( OptimizationJobName="my-optimization-job", DeploymentInstanceType="ml.g5.24xlarge", @@ -3298,7 +3667,7 @@ def test_optimize_sharding_with_override( self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token") self.assertEqual(builder.model_server, ModelServer.DJL_SERVING) - mock_send_telemetry.assert_called_once() + assert mock_send_telemetry.call_count == 2 mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( OptimizationJobName="my-optimization-job", DeploymentInstanceType="ml.g5.24xlarge", @@ -3426,7 +3795,7 @@ def test_optimize_sharding_with_override_for_js( self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token") - mock_send_telemetry.assert_called_once() + assert mock_send_telemetry.call_count == 2 
mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( OptimizationJobName="my-optimization-job", ModelSource={"S3": {"S3Uri": ANY}}, @@ -3663,3 +4032,86 @@ def test_neuron_configurations_rule_set(self): speculative_decoding_config=None, compilation_config={}, ) + + +@pytest.mark.parametrize( + "test_case", + [ + { + "input_args": {"endpoint_name": "test"}, + "call_params": { + "instance_type": "ml.g5.2xlarge", + "initial_instance_count": 1, + "endpoint_name": "test", + }, + }, + { + "input_args": { + "endpoint_name": "test", + "inference_config": ServerlessInferenceConfig(), + }, + "call_params": { + "serverless_inference_config": ServerlessInferenceConfig(), + "endpoint_name": "test", + }, + }, + { + "input_args": { + "endpoint_name": "test", + "inference_config": AsyncInferenceConfig(output_path="op-path"), + }, + "call_params": { + "async_inference_config": AsyncInferenceConfig(output_path="op-path"), + "instance_type": "ml.g5.2xlarge", + "initial_instance_count": 1, + "endpoint_name": "test", + }, + }, + { + "input_args": {"endpoint_name": "test", "inference_config": RESOURCE_REQUIREMENTS}, + "call_params": { + "resources": RESOURCE_REQUIREMENTS, + "role": "role-arn", + "initial_instance_count": 1, + "instance_type": "ml.g5.2xlarge", + "mode": Mode.SAGEMAKER_ENDPOINT, + "endpoint_type": EndpointType.INFERENCE_COMPONENT_BASED, + }, + }, + { + "input_args": { + "inference_config": BatchTransformInferenceConfig( + instance_count=1, instance_type="ml.m5.large", output_path="op-path" + ) + }, + "call_params": { + "instance_count": 1, + "instance_type": "ml.m5.large", + "output_path": "op-path", + }, + "id": "Batch", + }, + ], + ids=["Real Time", "Serverless", "Async", "Multi-Model", "Batch"], +) +@patch("sagemaker.serve.builder.model_builder.unique_name_from_base") +def test_deploy(mock_unique_name_from_base, test_case): + mock_unique_name_from_base.return_value = "test" + model: Model = MagicMock() + model_builder = ModelBuilder( + model="meta-llama/Meta-Llama-3-8B-Instruct", + env_vars={"HUGGING_FACE_HUB_TOKEN": "token"}, + role_arn="role-arn", + instance_type="ml.g5.2xlarge", + ) + setattr(model_builder, "built_model", model) + + model_builder.deploy(**test_case["input_args"]) + + if "id" in test_case and test_case["id"] == "Batch": + args, kwargs = model.transformer.call_args_list[0] + else: + args, kwargs = model.deploy.call_args_list[0] + + diff = deepdiff.DeepDiff(kwargs, test_case["call_params"]) + assert diff == {} diff --git a/tests/unit/sagemaker/serve/builder/test_tgi_builder.py b/tests/unit/sagemaker/serve/builder/test_tgi_builder.py index 0fa227f5d4..22109c93e2 100644 --- a/tests/unit/sagemaker/serve/builder/test_tgi_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_tgi_builder.py @@ -18,7 +18,8 @@ from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.utils.predictors import TgiLocalModePredictor -MOCK_MODEL_ID = "meta-llama/Meta-Llama-3-8B" +MOCK_MODEL_ID_GATED = "meta-llama/Meta-Llama-3-8B" +MOCK_MODEL_ID_NON_GATED = "openai-community/gpt2.0" MOCK_PROMPT = "The man worked as a [MASK]." 
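+# assumption: the gated id (meta-llama) needs a HuggingFace access token, so the token-free tests below use the non-gated id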
MOCK_SAMPLE_INPUT = {"inputs": "Hello, I'm a language model", "parameters": {"max_new_tokens": 128}} MOCK_SAMPLE_OUTPUT = [{"generated_text": "Hello, I'm a language modeler."}] @@ -60,7 +61,7 @@ def test_tgi_builder_sagemaker_endpoint_mode_no_s3_upload_success( ): # verify SAGEMAKER_ENDPOINT deploy builder = ModelBuilder( - model=MOCK_MODEL_ID, + model=MOCK_MODEL_ID_NON_GATED, name="mock_model_name", schema_builder=MOCK_SCHEMA_BUILDER, mode=Mode.SAGEMAKER_ENDPOINT, @@ -109,7 +110,7 @@ def test_tgi_builder_overwritten_deploy_from_local_container_to_sagemaker_endpoi ): # verify LOCAL_CONTAINER deploy builder = ModelBuilder( - model=MOCK_MODEL_ID, + model=MOCK_MODEL_ID_NON_GATED, schema_builder=MOCK_SCHEMA_BUILDER, mode=Mode.LOCAL_CONTAINER, model_path=MOCK_MODEL_PATH, @@ -169,7 +170,7 @@ def test_tgi_builder_optimized_sagemaker_endpoint_mode_no_s3_upload_success( ): # verify LOCAL_CONTAINER deploy builder = ModelBuilder( - model=MOCK_MODEL_ID, + model=MOCK_MODEL_ID_NON_GATED, schema_builder=MOCK_SCHEMA_BUILDER, mode=Mode.LOCAL_CONTAINER, model_path=MOCK_MODEL_PATH, @@ -192,6 +193,50 @@ def test_tgi_builder_optimized_sagemaker_endpoint_mode_no_s3_upload_success( # verify that if optimized, no s3 upload occurs builder._prepare_for_mode.assert_called_with() + @patch( + "sagemaker.serve.builder.tgi_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.tgi_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.model_builder.get_huggingface_model_metadata", + return_value={"pipeline_tag": "text-generation"}, + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_model_config_properties_from_hf", + return_value=({}, None), + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_default_tgi_configurations", + return_value=({}, None), + ) + def test_tgi_builder_in_process_mode( + self, + mock_default_tgi_configurations, + mock_hf_model_config, + mock_hf_model_md, + mock_get_nb_instance, + mock_telemetry, + ): + # verify IN_PROCESS deploy + builder = ModelBuilder( + model=MOCK_MODEL_ID_GATED, schema_builder=MOCK_SCHEMA_BUILDER, mode=Mode.IN_PROCESS + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + builder.modes[str(Mode.IN_PROCESS)] = MagicMock() + + model.deploy() + + # verify SAGEMAKER_ENDPOINT overwritten deploy + builder._original_deploy = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + # verify that if optimized, no s3 upload occurs + builder._prepare_for_mode.assert_called_with() + @patch( "sagemaker.serve.builder.tgi_builder._get_nb_instance", return_value="ml.g5.24xlarge", @@ -238,7 +283,7 @@ def test_tgi_builder_tune_success( mock_concurrent_benchmark.side_effect = [(10, 10), (50, 5)] builder = ModelBuilder( - model=MOCK_MODEL_ID, + model=MOCK_MODEL_ID_NON_GATED, schema_builder=MOCK_SCHEMA_BUILDER, mode=Mode.LOCAL_CONTAINER, model_path=MOCK_MODEL_PATH, diff --git a/tests/unit/sagemaker/serve/mode/test_in_process_mode.py b/tests/unit/sagemaker/serve/mode/test_in_process_mode.py index 0b1747029a..29d625dbbc 100644 --- a/tests/unit/sagemaker/serve/mode/test_in_process_mode.py +++ b/tests/unit/sagemaker/serve/mode/test_in_process_mode.py @@ -17,7 +17,6 @@ from sagemaker.serve.mode.in_process_mode import InProcessMode from sagemaker.serve import SchemaBuilder -from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.utils.exceptions import 
InProcessDeepPingException @@ -25,6 +24,7 @@ mock_response = "Hello, I'm a language model, and I'm here to help you with your English." mock_sample_input = {"inputs": mock_prompt, "parameters": {}} mock_sample_output = [{"generated_text": mock_response}] +mock_model = "gpt2" class TestInProcessMode(unittest.TestCase): @@ -32,7 +32,7 @@ class TestInProcessMode(unittest.TestCase): @patch("sagemaker.serve.mode.in_process_mode.Path") @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") @patch("sagemaker.session.Session") - def test_load_happy(self, mock_session, mock_inference_spec, mock_path): + def test_load_happy_transformers(self, mock_session, mock_inference_spec, mock_path): mock_path.return_value.exists.side_effect = lambda *args, **kwargs: True mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: True @@ -40,8 +40,35 @@ def test_load_happy(self, mock_session, mock_inference_spec, mock_path): mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output) in_process_mode = InProcessMode( - model_server=ModelServer.MMS, inference_spec=mock_inference_spec, + model=mock_model, + schema_builder=mock_schema_builder, + session=mock_session, + model_path="model_path", + env_vars={"key": "val"}, + ) + + res = in_process_mode.load(model_path="/tmp/model-builder/code/") + + self.assertEqual(res, "Dummy load") + self.assertEqual(in_process_mode.inference_spec, mock_inference_spec) + self.assertEqual(in_process_mode.schema_builder, mock_schema_builder) + self.assertEqual(in_process_mode.model_path, "model_path") + self.assertEqual(in_process_mode.env_vars, {"key": "val"}) + + @patch("sagemaker.serve.mode.in_process_mode.Path") + @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") + @patch("sagemaker.session.Session") + def test_load_happy_djl_serving(self, mock_session, mock_inference_spec, mock_path): + mock_path.return_value.exists.side_effect = lambda *args, **kwargs: True + mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: True + + mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load" + + mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output) + in_process_mode = InProcessMode( + inference_spec=mock_inference_spec, + model=mock_model, schema_builder=mock_schema_builder, session=mock_session, model_path="model_path", @@ -67,8 +94,8 @@ def test_load_ex(self, mock_session, mock_inference_spec, mock_path): mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output) in_process_mode = InProcessMode( - model_server=ModelServer.MMS, inference_spec=mock_inference_spec, + model=mock_model, schema_builder=mock_schema_builder, session=mock_session, model_path="model_path", @@ -82,8 +109,8 @@ def test_load_ex(self, mock_session, mock_inference_spec, mock_path): mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load" mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output) in_process_mode = InProcessMode( - model_server=ModelServer.MMS, inference_spec=mock_inference_spec, + model=mock_model, schema_builder=mock_schema_builder, session=mock_session, model_path="model_path", @@ -112,21 +139,19 @@ def test_create_server_happy( ) in_process_mode = InProcessMode( - model_server=ModelServer.MMS, inference_spec=mock_inference_spec, + model=mock_model, schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output), session=mock_session, model_path="model_path", ) - in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping + 
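# the ping hook is server-agnostic now: _deep_ping supersedes _multi_model_server_deep_ping + 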
in_process_mode._deep_ping = mock_multi_model_server_deep_ping in_process_mode._start_serving = mock_start_serving in_process_mode.create_server(predictor=mock_predictor) - mock_logger.info.assert_called_once_with( - "Waiting for model server %s to start up...", ModelServer.MMS - ) + mock_logger.info.assert_called_once_with("Waiting for fastapi server to start up...") mock_logger.debug.assert_called_once_with( "Ping health check has passed. Returned %s", str(mock_response) ) @@ -153,20 +178,20 @@ def test_create_server_ex( ) in_process_mode = InProcessMode( - model_server=ModelServer.MMS, inference_spec=mock_inference_spec, + model=mock_model, schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output), session=mock_session, model_path="model_path", ) - in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping + in_process_mode._deep_ping = mock_multi_model_server_deep_ping in_process_mode._start_serving = mock_start_serving self.assertRaises(InProcessDeepPingException, in_process_mode.create_server, mock_predictor) @patch( - "sagemaker.serve.model_server.multi_model_server.server.InProcessMultiModelServer._stop_serving" + "sagemaker.serve.model_server.in_process_model_server.in_process_server.InProcessServing._stop_serving" ) @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") @patch("sagemaker.session.Session") @@ -177,8 +202,8 @@ def test_destroy_server( mock_stop_serving, ): in_process_mode = InProcessMode( - model_server=ModelServer.MMS, inference_spec=mock_inference_spec, + model=mock_model, schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output), session=mock_session, model_path="model_path", diff --git a/tests/unit/sagemaker/serve/test_app.py b/tests/unit/sagemaker/serve/model_server/in_process_model_server/test_app.py similarity index 60% rename from tests/unit/sagemaker/serve/test_app.py rename to tests/unit/sagemaker/serve/model_server/in_process_model_server/test_app.py index 5b9a2218bb..65ba80c370 100644 --- a/tests/unit/sagemaker/serve/test_app.py +++ b/tests/unit/sagemaker/serve/model_server/in_process_model_server/test_app.py @@ -16,7 +16,7 @@ import pytest from unittest.mock import patch, Mock -from sagemaker.serve.app import InProcessServer +from sagemaker.serve.model_server.in_process_model_server.app import InProcessServer from tests.integ.sagemaker.serve.constants import ( PYTHON_VERSION_IS_NOT_310, ) @@ -29,29 +29,29 @@ class TestAppInProcessServer(unittest.TestCase): PYTHON_VERSION_IS_NOT_310, reason="The goal of these tests are to test the serving components of our feature", ) - @patch("sagemaker.serve.app.threading") - @patch("sagemaker.serve.app.pipeline") - def test_in_process_server_init(self, mock_pipeline, mock_threading): + @patch("sagemaker.serve.model_server.in_process_model_server.app.threading") + @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") + def test_in_process_server_init(self, mock_inference_spec, mock_threading): mock_generator = Mock() mock_generator.side_effect = None - in_process_server = InProcessServer(model_id=mock_model_id) + in_process_server = InProcessServer(inference_spec=mock_inference_spec) in_process_server._generator = mock_generator @pytest.mark.skipif( PYTHON_VERSION_IS_NOT_310, reason="The goal of these test are to test the serving components of our feature", ) - @patch("sagemaker.serve.app.logger") - @patch("sagemaker.serve.app.threading") - @patch("sagemaker.serve.app.pipeline") - def test_start_server(self, mock_pipeline, mock_threading, mock_logger): + 
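# patch targets track the module move: sagemaker.serve.app -> sagemaker.serve.model_server.in_process_model_server.app + 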
@patch("sagemaker.serve.model_server.in_process_model_server.app.logger") + @patch("sagemaker.serve.model_server.in_process_model_server.app.threading") + @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") + def test_start_server(self, mock_inference_spec, mock_threading, mock_logger): mock_generator = Mock() mock_generator.side_effect = None mock_thread = Mock() mock_threading.Thread.return_value = mock_thread - in_process_server = InProcessServer(model_id=mock_model_id) + in_process_server = InProcessServer(inference_spec=mock_inference_spec) in_process_server._generator = mock_generator in_process_server.start_server() @@ -63,27 +63,27 @@ def test_start_server(self, mock_pipeline, mock_threading, mock_logger): PYTHON_VERSION_IS_NOT_310, reason="The goal of these test are to test the serving components of our feature", ) - @patch("sagemaker.serve.app.asyncio") - @patch("sagemaker.serve.app.pipeline") - def test_start_run_async_in_thread(self, mock_pipeline, mock_asyncio): - mock_pipeline.side_effect = None + @patch("sagemaker.serve.model_server.in_process_model_server.app.asyncio") + @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") + def test_start_run_async_in_thread(self, mock_inference_spec, mock_asyncio): + mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load" mock_loop = Mock() mock_asyncio.new_event_loop.side_effect = lambda: mock_loop - in_process_server = InProcessServer(model_id=mock_model_id) + in_process_server = InProcessServer(inference_spec=mock_inference_spec) in_process_server._start_run_async_in_thread() mock_asyncio.set_event_loop.assert_called_once_with(mock_loop) mock_loop.run_until_complete.assert_called() - @patch("sagemaker.serve.app.pipeline") - async def test_serve(self, mock_pipeline): - mock_pipeline.side_effect = None + @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") + async def test_serve(self, mock_inference_spec): + mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load" mock_server = Mock() - in_process_server = InProcessServer(model_id=mock_model_id) + in_process_server = InProcessServer(inference_spec=mock_inference_spec) in_process_server.server = mock_server await in_process_server._serve() diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index e1c21cf662..c127d4b5ef 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -402,7 +402,7 @@ def test_pytorchxla_distribution( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( huggingface_training_compiler_version, @@ -463,7 +463,7 @@ def test_default_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( huggingface_training_compiler_version, @@ -519,7 +519,7 @@ def test_debug_compiler_config( sagemaker_call_names = [c[0] for c in 
sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( huggingface_training_compiler_version, @@ -575,7 +575,7 @@ def test_disable_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( huggingface_training_compiler_version, diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py index e0d172f6e0..b7802f5a6b 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py @@ -349,7 +349,7 @@ def test_default_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( huggingface_training_compiler_version, @@ -407,7 +407,7 @@ def test_debug_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( huggingface_training_compiler_version, @@ -465,7 +465,7 @@ def test_disable_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( huggingface_training_compiler_version, diff --git a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py index 34a1236a7f..56c6e9966f 100644 --- a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py @@ -344,7 +344,7 @@ def test_pytorchxla_distribution( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( pytorch_training_compiler_version, @@ -403,7 +403,7 @@ def test_default_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( pytorch_training_compiler_version, @@ -458,7 +458,7 @@ def 
test_debug_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( pytorch_training_compiler_version, @@ -513,7 +513,7 @@ def test_disable_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( pytorch_training_compiler_version, diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index ac42bb53ab..54d701ad4e 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -289,7 +289,7 @@ def test_default( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( tensorflow_training_version, @@ -348,7 +348,7 @@ def test_byoc( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( tensorflow_training_version, @@ -399,7 +399,7 @@ def test_debug_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( tensorflow_training_version, @@ -450,7 +450,7 @@ def test_disable_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( tensorflow_training_version, diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index 5c273460ee..0ac8cb0888 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -361,7 +361,7 @@ def test_chainer(strftime, time, sagemaker_session, chainer_version, chainer_py_ sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job(chainer_version, chainer_py_version) expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 4a584dfae4..2c47356921 
100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -360,7 +360,7 @@ def test_mxnet( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names actual_train_args = sagemaker_session.method_calls[0][2] job_name = actual_train_args["job_name"] diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 618d0d7ea8..6076d44e90 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -18,10 +18,12 @@ import pytest from mock import ANY, MagicMock, Mock, patch from packaging.version import Version +import tempfile from sagemaker import image_uris from sagemaker.pytorch import defaults from sagemaker.pytorch import PyTorch, PyTorchPredictor, PyTorchModel +from sagemaker.pytorch.estimator import _get_training_recipe_image_uri from sagemaker.instance_group import InstanceGroup from sagemaker.session_settings import SessionSettings @@ -35,6 +37,8 @@ BUCKET_NAME = "mybucket" INSTANCE_COUNT = 1 INSTANCE_TYPE = "ml.c4.4xlarge" +INSTANCE_TYPE_GPU = "ml.p4d.24xlarge" +INSTANCE_TYPE_TRAINIUM = "ml.trn1.32xlarge" ACCELERATOR_TYPE = "ml.eia.medium" IMAGE_URI = "sagemaker-pytorch" JOB_NAME = "{}-{}".format(IMAGE_URI, TIMESTAMP) @@ -59,6 +63,18 @@ } DISTRIBUTION_PYTORCH_DDP_ENABLED = {"pytorchddp": {"enabled": True}} +NEURON_RECIPE = ( + "https://raw.githubusercontent.com/aws-neuron/" + "neuronx-distributed-training/refs/heads/main/examples/" + "conf/hf_llama3_8B_config.yaml" +) +RECIPE_GPU_IMAGE = ( + "658645717510.dkr.ecr.us-west-2.amazonaws.com/smdistributed-modelparallel:2.4.1-gpu-py311" +) +RECIPE_NEURON_IMAGE = ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-training-neuronx:2.1.2-neuronx-py310-sdk2.20.2-ubuntu20.04" +) @pytest.fixture(name="sagemaker_session") @@ -337,7 +353,7 @@ def test_pytorch( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job(pytorch_inference_version, pytorch_inference_py_version) expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs @@ -826,3 +842,263 @@ def test_predictor_with_component_name(sagemaker_session, component_name): predictor = PyTorchPredictor("endpoint", sagemaker_session, component_name=component_name) assert predictor._get_component_name() == component_name + + +def test_training_recipe_for_cpu(sagemaker_session): + container_log_level = '"logging.INFO"' + + recipe_overrides = { + "run": { + "results_dir": "/opt/ml/model", + }, + "exp_manager": { + "explicit_log_dir": "/opt/ml/output/tensorboard", + "checkpoint_dir": "/opt/ml/checkpoints", + }, + "model": { + "data": { + "train_dir": "/opt/ml/input/data/train", + "val_dir": "/opt/ml/input/data/val", + }, + }, + } + + with pytest.raises(ValueError): + PyTorch( + output_path="s3://output_path", + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + base_job_name="job", + container_log_level=container_log_level, + training_recipe="training/llama/hf_llama3_8b_seq8k_gpu_p5x16_pretrain", + recipe_overrides=recipe_overrides, + ) + + +@pytest.mark.parametrize( + 
"recipe, model", + [ + ("hf_llama3_8b_seq8k_gpu_p5x16_pretrain", "llama"), + ("hf_mistral_7b_seq8k_gpu_p5x16_pretrain", "mistral"), + ("hf_mixtral_8x7b_seq8k_gpu_p5x16_pretrain", "mixtral"), + ], +) +def test_training_recipe_for_gpu(sagemaker_session, recipe, model): + container_log_level = '"logging.INFO"' + + recipe_overrides = { + "run": { + "results_dir": "/opt/ml/model", + }, + "exp_manager": { + "explicit_log_dir": "/opt/ml/output", + "checkpoint_dir": "/opt/ml/checkpoints", + }, + "model": { + "data": { + "train_dir": "/opt/ml/input/data/train", + "val_dir": "/opt/ml/input/data/val", + }, + }, + } + pytorch = PyTorch( + output_path="s3://output_path", + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + base_job_name="job", + container_log_level=container_log_level, + training_recipe=f"training/{model}/{recipe}", + recipe_overrides=recipe_overrides, + ) + + assert pytorch.source_dir == "." + assert pytorch.entry_point == f"{model}_pretrain.py" + expected_distribution = { + "torch_distributed": { + "enabled": True, + }, + "smdistributed": { + "modelparallel": { + "enabled": True, + "parameters": { + "placement_strategy": "cluster", + }, + }, + }, + } + assert pytorch.distribution.items() == expected_distribution.items() + + +def test_training_recipe_with_override(sagemaker_session): + container_log_level = '"logging.INFO"' + + recipe_overrides = { + "run": { + "results_dir": "/opt/ml/model", + }, + "exp_manager": { + "explicit_log_dir": "/opt/ml/output", + "checkpoint_dir": "/opt/ml/checkpoints", + }, + "model": { + "data": { + "train_dir": "/opt/ml/input/data/train", + "val_dir": "/opt/ml/input/data/val", + }, + "model_type": "mistral", + }, + } + pytorch = PyTorch( + output_path="s3://output_path", + role=ROLE, + image_uri=IMAGE_URI, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + base_job_name="job", + container_log_level=container_log_level, + training_recipe="training/llama/hf_llama3_8b_seq8k_gpu_p5x16_pretrain", + recipe_overrides=recipe_overrides, + ) + + assert pytorch.source_dir == "." 
+ assert pytorch.entry_point == "mistral_pretrain.py" + assert pytorch.image_uri == IMAGE_URI + + +def test_training_recipe_gpu_custom_source_dir(sagemaker_session): + container_log_level = '"logging.INFO"' + + recipe_overrides = { + "run": { + "results_dir": "/opt/ml/model", + }, + "exp_manager": { + "explicit_log_dir": "/opt/ml/output", + "checkpoint_dir": "/opt/ml/checkpoints", + }, + "model": { + "data": { + "train_dir": "/opt/ml/input/data/train", + "val_dir": "/opt/ml/input/data/val", + }, + "model_type": "mistral", + }, + } + source_dir = tempfile.TemporaryDirectory(prefix="source_") + pytorch = PyTorch( + output_path="s3://output_path", + role=ROLE, + image_uri=IMAGE_URI, + source_dir=source_dir.name, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + base_job_name="job", + container_log_level=container_log_level, + training_recipe="training/llama/hf_llama3_8b_seq8k_gpu_p5x16_pretrain", + recipe_overrides=recipe_overrides, + ) + + assert pytorch.source_dir == source_dir.name + assert pytorch.entry_point == "mistral_pretrain.py" + assert pytorch.image_uri == IMAGE_URI + + +def test_training_recipe_for_trainium(sagemaker_session): + container_log_level = '"logging.INFO"' + + recipe_overrides = { + "run": { + "results_dir": "/opt/ml/model", + }, + "exp_manager": { + "explicit_log_dir": "/opt/ml/output", + }, + "data": { + "train_dir": "/opt/ml/input/data/train", + }, + "model": { + "model_config": "/opt/ml/input/data/train/config.json", + }, + "compiler_cache_url": "s3://output_path/neuron-cache", + } + pytorch = PyTorch( + output_path="s3://output_path", + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_TRAINIUM, + base_job_name="job", + container_log_level=container_log_level, + training_recipe=NEURON_RECIPE, + recipe_overrides=recipe_overrides, + ) + + assert pytorch.source_dir == "."
+ assert pytorch.entry_point == "training_orchestrator.py" + expected_distribution = { + "torch_distributed": { + "enabled": True, + }, + } + assert pytorch.distribution == expected_distribution + + +def test_training_recipe_for_trainium_custom_source_dir(sagemaker_session): + container_log_level = '"logging.INFO"' + + recipe_overrides = { + "run": { + "results_dir": "/opt/ml/model", + }, + "exp_manager": { + "explicit_log_dir": "/opt/ml/output", + }, + "data": { + "train_dir": "/opt/ml/input/data/train", + }, + "model": { + "model_config": "/opt/ml/input/data/train/config.json", + }, + "compiler_cache_url": "s3://output_path/neuron-cache", + } + source_dir = tempfile.TemporaryDirectory(prefix="source_") + pytorch = PyTorch( + output_path="s3://output_path", + role=ROLE, + source_dir=source_dir.name, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_TRAINIUM, + base_job_name="job", + container_log_level=container_log_level, + training_recipe=NEURON_RECIPE, + recipe_overrides=recipe_overrides, + ) + + assert pytorch.source_dir == source_dir.name + assert pytorch.entry_point == "training_orchestrator.py" + expected_distribution = { + "torch_distributed": { + "enabled": True, + }, + } + assert pytorch.distribution == expected_distribution + + +def test_training_recipe_images_uri(): + gpu_image_cfg = {"framework": "pytorch-smp", "version": "2.4.1", "additional_args": {}} + gpu_image_uri = _get_training_recipe_image_uri(gpu_image_cfg, "us-west-2") + assert gpu_image_uri == RECIPE_GPU_IMAGE + neuron_image_cfg = { + "framework": "hyperpod-recipes-neuron", + "version": "2.1.2", + "additional_args": {}, + } + neuron_image_uri = _get_training_recipe_image_uri(neuron_image_cfg, "us-west-2") + assert neuron_image_uri == RECIPE_NEURON_IMAGE diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index 49b145afca..27ab48d025 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -335,7 +335,7 @@ def test_rl(time, strftime, sagemaker_session, coach_mxnet_version): sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( RLToolkit.COACH.value, coach_mxnet_version, RLFramework.MXNET.value diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index b0df31fee1..c418be4646 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -332,7 +332,7 @@ def test_sklearn(time, strftime, sagemaker_session, sklearn_version): sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job(sklearn_version) expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index 18eab98149..b694e63fe1 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -330,7 +330,7 @@ def test_xgboost_cpu(time, strftime, sagemaker_session, xgboost_framework_versio sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names 
= [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job(xgboost_framework_version) expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs @@ -377,7 +377,7 @@ def test_xgboost_gpu(time, strftime, sagemaker_session, xgboost_gpu_framework_ve sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( xgboost_gpu_framework_version, instance_type=GPU_INSTANCE_TYPE @@ -427,7 +427,7 @@ def test_distributed_training(time, strftime, sagemaker_session, xgboost_framewo sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job(xgboost_framework_version, DIST_INSTANCE_COUNT) expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
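The recurring assertion change in these hunks, replacing the strict assert boto_call_names == ["resource"] with assert "resource" in boto_call_names, makes the tests tolerant of any additional boto_session calls the SDK may issue during estimator setup. Below is a minimal, self-contained sketch of the difference using only unittest.mock; the client("sagemaker") call is a hypothetical extra call added for illustration, not one taken from this PR:

    from unittest.mock import Mock

    # Stand-in for the sagemaker_session.boto_session fixture used in these tests.
    boto_session = Mock()
    boto_session.resource("s3")       # the call the tests care about
    boto_session.client("sagemaker")  # hypothetical extra call recorded by a newer code path

    # Mock records every call; c[0] is the called method's name.
    boto_call_names = [c[0] for c in boto_session.method_calls]

    # Strict equality must enumerate the full call sequence, so it breaks
    # as soon as anything else is recorded:
    assert boto_call_names == ["resource", "client"]
    # Membership only checks for the expected call, so it stays stable:
    assert "resource" in boto_call_names

Membership is the weaker but more maintainable contract: it pins the behavior under test (an S3 resource is requested) without freezing the entire call sequence of the mocked session.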