diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index 60904c51b0..56805ebc7a 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -38,3 +38,4 @@ accelerate>=0.24.1,<=0.27.0 schema==0.7.5 tensorflow>=2.1,<=2.16 mlflow>=2.12.2,<2.13 +huggingface_hub>=0.23.4 diff --git a/setup.py b/setup.py index 9242e69cfd..b9486bbf18 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ def read_requirements(filename): # Declare minimal set for installation required_packages = [ "attrs>=23.1.0,<24", - "boto3>=1.33.3,<2.0", + "boto3>=1.34.142,<2.0", "cloudpickle==2.2.1", "google-pasta", "numpy>=1.9.0,<2.0", diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 0327ef3845..b48adda44c 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -82,6 +82,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default accept type for the model matching the given arguments. @@ -105,6 +106,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default accept type to use for the model. 
@@ -125,4 +127,5 @@ def retrieve_default( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index 3154c1e4fe..16c81d6d77 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -82,6 +82,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default content type for the model matching the given arguments. @@ -105,6 +106,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default content type to use for the model. @@ -125,6 +127,7 @@ def retrieve_default( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 3081daea23..957a9dfb0c 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -102,6 +102,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> BaseDeserializer: """Retrieves the default deserializer for the model matching the given arguments. @@ -125,6 +126,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. 
(Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: BaseDeserializer: The default deserializer to use for the model. @@ -146,4 +148,5 @@ def retrieve_default( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/enums.py b/src/sagemaker/enums.py index f02b275cbe..f8c618620b 100644 --- a/src/sagemaker/enums.py +++ b/src/sagemaker/enums.py @@ -40,3 +40,12 @@ class RoutingStrategy(Enum): """The endpoint routes requests to the specific instances that have more capacity to process them. """ + + +class Tag(str, Enum): + """Enum class for tag keys to apply to models.""" + + OPTIMIZATION_JOB_NAME = "sagemaker-sdk:optimization-job-name" + SPECULATIVE_DRAFT_MODEL_PROVIDER = "sagemaker-sdk:speculative-draft-model-provider" + FINE_TUNING_MODEL_PATH = "sagemaker-sdk:fine-tuning-model-path" + FINE_TUNING_JOB_NAME = "sagemaker-sdk:fine-tuning-job-name" diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index 57851d112a..173266e1a1 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -37,6 +37,7 @@ def retrieve_default( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE, + config_name: Optional[str] = None, ) -> Dict[str, str]: """Retrieves the default container environment variables for the model matching the arguments. @@ -68,6 +69,7 @@ def retrieve_default( variables specific for the instance type. script (JumpStartScriptScope): The JumpStart script for which to retrieve environment variables. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: The variables to use for the model. 
@@ -91,4 +93,5 @@ def retrieve_default( sagemaker_session=sagemaker_session, instance_type=instance_type, script=script, + config_name=config_name, ) diff --git a/src/sagemaker/huggingface/llm_utils.py b/src/sagemaker/huggingface/llm_utils.py index 9927d1d293..c7a1316760 100644 --- a/src/sagemaker/huggingface/llm_utils.py +++ b/src/sagemaker/huggingface/llm_utils.py @@ -13,7 +13,9 @@ """Functions for generating ECR image URIs for pre-built SageMaker Docker images.""" from __future__ import absolute_import +import os from typing import Optional +import importlib.util import urllib.request from urllib.error import HTTPError, URLError @@ -123,3 +125,26 @@ def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] = "Did not find model metadata for the following HuggingFace Model ID %s" % model_id ) return hf_model_metadata_json + + +def download_huggingface_model_metadata( + model_id: str, model_local_path: str, hf_hub_token: Optional[str] = None +) -> None: + """Downloads the HuggingFace Model snapshot via HuggingFace API. + + Args: + model_id (str): The HuggingFace Model ID + model_local_path (str): The local path to save the HuggingFace Model snapshot. + hf_hub_token (str): The HuggingFace Hub Token + + Raises: + ImportError: If huggingface_hub is not installed. 
+ """ + if not importlib.util.find_spec("huggingface_hub"): + raise ImportError("Unable to import huggingface_hub, check if huggingface_hub is installed") + + from huggingface_hub import snapshot_download + + os.makedirs(model_local_path, exist_ok=True) + logger.info("Downloading model %s from Hugging Face Hub to %s", model_id, model_local_path) + snapshot_download(repo_id=model_id, local_dir=model_local_path, token=hf_hub_token) diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 49ced478dd..f1353cc8ff 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -37,6 +37,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> Dict[str, str]: """Retrieves the default training hyperparameters for the model matching the given arguments. @@ -69,6 +70,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: The hyperparameters to use for the model. @@ -90,6 +92,7 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 22328c4183..65497927e9 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -71,6 +71,7 @@ def retrieve( inference_tool=None, serverless_inference_config=None, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name=None, ) -> str: """Retrieves the ECR URI for the Docker image matching the given arguments. 
@@ -126,6 +127,7 @@ def retrieve( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The ECR URI for the corresponding SageMaker Docker image. @@ -166,6 +168,7 @@ def retrieve( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]): diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 66e8e5127f..1b664fc9ae 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -37,6 +37,7 @@ def retrieve_default( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default instance type for the model matching the given arguments. @@ -67,6 +68,7 @@ def retrieve_default( Optionally supply this to get a inference instance type conditioned on the training instance, to ensure compatability of training artifact to inference instance. (Default: None). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default instance type to use for the model. 
@@ -92,6 +94,7 @@ def retrieve_default( sagemaker_session=sagemaker_session, training_instance_type=training_instance_type, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index 5f00783ed3..f10bfe4a5d 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -40,6 +40,7 @@ def _retrieve_default_environment_variables( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE, + config_name: Optional[str] = None, ) -> Dict[str, str]: """Retrieves the inference environment variables for the model matching the given arguments. @@ -71,6 +72,7 @@ def _retrieve_default_environment_variables( environment variables specific for the instance type. script (JumpStartScriptScope): The JumpStart script for which to retrieve environment variables. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the inference environment variables to use for the model. 
""" @@ -88,6 +90,7 @@ def _retrieve_default_environment_variables( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) default_environment_variables: Dict[str, str] = {} @@ -126,6 +129,7 @@ def _retrieve_default_environment_variables( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, instance_type=instance_type, + config_name=config_name, ) ) @@ -173,6 +177,7 @@ def _retrieve_gated_model_uri_env_var_value( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> Optional[str]: """Retrieves the gated model env var URI matching the given arguments. @@ -198,6 +203,7 @@ def _retrieve_gated_model_uri_env_var_value( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get environment variables specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). 
Returns: Optional[str]: the s3 URI to use for the environment variable, or None if the model does not @@ -220,6 +226,7 @@ def _retrieve_gated_model_uri_env_var_value( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) s3_key: Optional[str] = ( diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index 308c3a5386..4383a17bf9 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -37,6 +37,7 @@ def _retrieve_default_hyperparameters( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ): """Retrieves the training hyperparameters for the model matching the given arguments. @@ -69,6 +70,7 @@ def _retrieve_default_hyperparameters( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get hyperparameters specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the hyperparameters to use for the model. 
""" @@ -86,6 +88,7 @@ def _retrieve_default_hyperparameters( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) default_hyperparameters: Dict[str, str] = {} diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 4f34cdd1e2..0d4a61d112 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -47,6 +47,7 @@ def _retrieve_image_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ): """Retrieves the container image URI for JumpStart models. @@ -98,6 +99,7 @@ def _retrieve_image_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the ECR URI for the corresponding SageMaker Docker image. 
@@ -120,6 +122,7 @@ def _retrieve_image_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if image_scope == JumpStartScriptScope.INFERENCE: @@ -213,4 +216,5 @@ def _retrieve_image_uri( distribution=distribution, base_framework_version=base_framework_version_override or base_framework_version, training_compiler_config=training_compiler_config, + config_name=config_name, ) diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 17328c44e0..80b5aa8ef5 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -34,6 +34,7 @@ def _model_supports_incremental_training( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> bool: """Returns True if the model supports incremental training. @@ -57,6 +58,7 @@ def _model_supports_incremental_training( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: bool: the support status for incremental training. 
""" @@ -74,6 +76,7 @@ def _model_supports_incremental_training( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.supports_incremental_training() diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 4c9e8075c5..25119266cf 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -41,6 +41,7 @@ def _retrieve_default_instance_type( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default instance type for the model. @@ -71,6 +72,7 @@ def _retrieve_default_instance_type( Optionally supply this to get a inference instance type conditioned on the training instance, to ensure compatability of training artifact to inference instance. (Default: None). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the default instance type to use for the model or None. @@ -93,6 +95,7 @@ def _retrieve_default_instance_type( tolerate_deprecated_model=tolerate_deprecated_model, model_type=model_type, sagemaker_session=sagemaker_session, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: @@ -133,6 +136,7 @@ def _retrieve_instance_types( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> List[str]: """Retrieves the supported instance types for the model. 
@@ -163,6 +167,7 @@ def _retrieve_instance_types( Optionally supply this to get a inference instance type conditioned on the training instance, to ensure compatability of training artifact to inference instance. (Default: None). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the supported instance types to use for the model or None. @@ -184,6 +189,7 @@ def _retrieve_instance_types( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 84c26bdda2..eb7980b88f 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -38,6 +38,7 @@ def _retrieve_model_init_kwargs( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Model`. @@ -61,6 +62,7 @@ def _retrieve_model_init_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. 
""" @@ -79,6 +81,7 @@ def _retrieve_model_init_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) kwargs = deepcopy(model_specs.model_kwargs) @@ -99,6 +102,7 @@ def _retrieve_model_deploy_kwargs( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Model.deploy`. @@ -124,6 +128,7 @@ def _retrieve_model_deploy_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. @@ -143,6 +148,7 @@ def _retrieve_model_deploy_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) if volume_size_supported(instance_type) and model_specs.inference_volume_size is not None: @@ -160,6 +166,7 @@ def _retrieve_estimator_init_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Estimator`. @@ -185,6 +192,7 @@ def _retrieve_estimator_init_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. 
""" @@ -202,6 +210,7 @@ def _retrieve_estimator_init_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) kwargs = deepcopy(model_specs.estimator_kwargs) @@ -223,6 +232,7 @@ def _retrieve_estimator_fit_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Estimator.fit`. @@ -246,6 +256,7 @@ def _retrieve_estimator_fit_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. @@ -264,6 +275,7 @@ def _retrieve_estimator_fit_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.fit_kwargs diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index 901f5cc455..5e5c0d79a0 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -36,6 +36,7 @@ def _retrieve_default_training_metric_definitions( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model. @@ -61,6 +62,7 @@ def _retrieve_default_training_metric_definitions( chain. 
(Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get metric definitions specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the default training metric definitions to use for the model or None. """ @@ -78,6 +80,7 @@ def _retrieve_default_training_metric_definitions( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) default_metric_definitions = ( diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index b1f931eac4..7aa5be7507 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -38,6 +38,7 @@ def _retrieve_model_package_arn( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Optional[str]: """Retrieves associated model pacakge arn for the model. @@ -63,6 +64,7 @@ def _retrieve_model_package_arn( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the model package arn to use for the model or None. 
@@ -82,6 +84,7 @@ def _retrieve_model_package_arn( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: @@ -97,7 +100,10 @@ def _retrieve_model_package_arn( if instance_specific_arn is not None: return instance_specific_arn - if model_specs.hosting_model_package_arns is None: + if ( + model_specs.hosting_model_package_arns is None + or model_specs.hosting_model_package_arns == {} + ): return None regional_arn = model_specs.hosting_model_package_arns.get(region) @@ -123,6 +129,7 @@ def _retrieve_model_package_model_artifact_s3_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> Optional[str]: """Retrieves s3 artifact uri associated with model package. @@ -148,6 +155,7 @@ def _retrieve_model_package_model_artifact_s3_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the model package artifact uri to use for the model or None. 
@@ -170,6 +178,7 @@ def _retrieve_model_package_model_artifact_s3_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if model_specs.training_model_package_artifact_uris is None: diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 2cebacb9c0..5fac979b14 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -96,6 +96,7 @@ def _retrieve_model_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ): """Retrieves the model artifact S3 URI for the model matching the given arguments. @@ -123,6 +124,8 @@ def _retrieve_model_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + Returns: str: the model artifact S3 URI for the corresponding model. @@ -145,6 +148,7 @@ def _retrieve_model_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) model_artifact_key: str @@ -190,6 +194,7 @@ def _model_supports_training_model_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> bool: """Returns True if the model supports training with model uri field. @@ -213,6 +218,7 @@ def _model_supports_training_model_uri( object, used for SageMaker interactions. 
If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: bool: the support status for model uri with training. """ @@ -230,6 +236,7 @@ def _model_supports_training_model_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.use_training_model_artifact() diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 41c9c93ad2..c217495ede 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -38,6 +38,7 @@ def _retrieve_example_payloads( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Optional[Dict[str, JumpStartSerializablePayload]]: """Returns example payloads. @@ -61,6 +62,7 @@ def _retrieve_example_payloads( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: Optional[Dict[str, JumpStartSerializablePayload]]: dictionary mapping payload aliases to the serializable payload object. 
@@ -80,6 +82,7 @@ def _retrieve_example_payloads( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) default_payloads = model_specs.default_payloads diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 96d1c1f7fb..352a4384f8 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -79,6 +79,7 @@ def _retrieve_default_deserializer( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> BaseDeserializer: """Retrieves the default deserializer for the model. @@ -101,6 +102,7 @@ def _retrieve_default_deserializer( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: BaseDeserializer: the default deserializer to use for the model. @@ -115,6 +117,7 @@ def _retrieve_default_deserializer( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) return _retrieve_deserializer_from_accept_type(MIMEType.from_suffixed_type(default_accept_type)) @@ -129,6 +132,7 @@ def _retrieve_default_serializer( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> BaseSerializer: """Retrieves the default serializer for the model. @@ -151,6 +155,7 @@ def _retrieve_default_serializer( object, used for SageMaker interactions. 
If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: BaseSerializer: the default serializer to use for the model. """ @@ -164,6 +169,7 @@ def _retrieve_default_serializer( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) return _retrieve_serializer_from_content_type(MIMEType.from_suffixed_type(default_content_type)) @@ -178,6 +184,7 @@ def _retrieve_deserializer_options( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> List[BaseDeserializer]: """Retrieves the supported deserializers for the model. @@ -200,6 +207,7 @@ def _retrieve_deserializer_options( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: List[BaseDeserializer]: the supported deserializers to use for the model. """ @@ -213,6 +221,7 @@ def _retrieve_deserializer_options( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) seen_classes: Set[Type] = set() @@ -240,6 +249,7 @@ def _retrieve_serializer_options( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> List[BaseSerializer]: """Retrieves the supported serializers for the model. 
@@ -262,6 +272,7 @@ def _retrieve_serializer_options( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: List[BaseSerializer]: the supported serializers to use for the model. """ @@ -274,6 +285,7 @@ def _retrieve_serializer_options( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) seen_classes: Set[Type] = set() @@ -302,6 +314,7 @@ def _retrieve_default_content_type( tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieves the default content type for the model. @@ -324,6 +337,7 @@ def _retrieve_default_content_type( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the default content type to use for the model. 
""" @@ -342,6 +356,7 @@ def _retrieve_default_content_type( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) default_content_type = model_specs.predictor_specs.default_content_type @@ -357,6 +372,7 @@ def _retrieve_default_accept_type( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default accept type for the model. @@ -379,6 +395,7 @@ def _retrieve_default_accept_type( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the default accept type to use for the model. """ @@ -397,6 +414,7 @@ def _retrieve_default_accept_type( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) default_accept_type = model_specs.predictor_specs.default_accept_type @@ -413,6 +431,7 @@ def _retrieve_supported_accept_types( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> List[str]: """Retrieves the supported accept types for the model. @@ -435,6 +454,7 @@ def _retrieve_supported_accept_types( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). 
Returns: list: the supported accept types to use for the model. """ @@ -453,6 +473,7 @@ def _retrieve_supported_accept_types( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) supported_accept_types = model_specs.predictor_specs.supported_accept_types @@ -469,6 +490,7 @@ def _retrieve_supported_content_types( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> List[str]: """Retrieves the supported content types for the model. @@ -491,6 +513,7 @@ def _retrieve_supported_content_types( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the supported content types to use for the model. 
""" @@ -509,6 +532,7 @@ def _retrieve_supported_content_types( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) supported_content_types = model_specs.predictor_specs.supported_content_types diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index 0b92d46a23..8c47750061 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -36,6 +36,8 @@ def _retrieve_resource_name_base( tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + scope: JumpStartScriptScope = JumpStartScriptScope.INFERENCE, + config_name: Optional[str] = None, ) -> bool: """Returns default resource name. @@ -59,6 +61,7 @@ def _retrieve_resource_name_base( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config. (Default: None). Returns: str: the default resource name. 
""" @@ -71,12 +74,13 @@ def _retrieve_resource_name_base( model_id=model_id, version=model_version, hub_arn=hub_arn, - scope=JumpStartScriptScope.INFERENCE, + scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, model_type=model_type, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.resource_name_base diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 8936a3f824..74523be1de 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -55,6 +55,7 @@ def _retrieve_default_resources( model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> ResourceRequirements: """Retrieves the default resource requirements for the model. @@ -82,6 +83,7 @@ def _retrieve_default_resources( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get host requirements specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default resource requirements to use for the model or None. 
@@ -106,6 +108,7 @@ def _retrieve_default_resources( tolerate_deprecated_model=tolerate_deprecated_model, model_type=model_type, sagemaker_session=sagemaker_session, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index 3c79f93985..5029f53cfb 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -38,6 +38,7 @@ def _retrieve_script_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ): """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -65,6 +66,7 @@ def _retrieve_script_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the model script URI for the corresponding model. @@ -87,6 +89,7 @@ def _retrieve_script_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if script_scope == JumpStartScriptScope.INFERENCE: @@ -113,6 +116,7 @@ def _model_supports_inference_script_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> bool: """Returns True if the model supports inference with script uri field. 
@@ -153,6 +157,7 @@ def _model_supports_inference_script_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.use_inference_script_uri() diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 076e1d2fa1..02b3f1836d 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -44,31 +44,37 @@ region_name="us-west-2", content_bucket="jumpstart-cache-prod-us-west-2", gated_content_bucket="jumpstart-private-cache-prod-us-west-2", + neo_content_bucket="sagemaker-sd-models-prod-us-west-2", ), JumpStartLaunchedRegionInfo( region_name="us-east-1", content_bucket="jumpstart-cache-prod-us-east-1", gated_content_bucket="jumpstart-private-cache-prod-us-east-1", + neo_content_bucket="sagemaker-sd-models-prod-us-east-1", ), JumpStartLaunchedRegionInfo( region_name="us-east-2", content_bucket="jumpstart-cache-prod-us-east-2", gated_content_bucket="jumpstart-private-cache-prod-us-east-2", + neo_content_bucket="sagemaker-sd-models-prod-us-east-2", ), JumpStartLaunchedRegionInfo( region_name="eu-west-1", content_bucket="jumpstart-cache-prod-eu-west-1", gated_content_bucket="jumpstart-private-cache-prod-eu-west-1", + neo_content_bucket="sagemaker-sd-models-prod-eu-west-1", ), JumpStartLaunchedRegionInfo( region_name="eu-central-1", content_bucket="jumpstart-cache-prod-eu-central-1", gated_content_bucket="jumpstart-private-cache-prod-eu-central-1", + neo_content_bucket="sagemaker-sd-models-prod-eu-central-1", ), JumpStartLaunchedRegionInfo( region_name="eu-north-1", content_bucket="jumpstart-cache-prod-eu-north-1", gated_content_bucket="jumpstart-private-cache-prod-eu-north-1", + neo_content_bucket="sagemaker-sd-models-prod-eu-north-1", ), JumpStartLaunchedRegionInfo( region_name="me-south-1", @@ -84,11 +90,13 @@ region_name="ap-south-1", 
content_bucket="jumpstart-cache-prod-ap-south-1", gated_content_bucket="jumpstart-private-cache-prod-ap-south-1", + neo_content_bucket="sagemaker-sd-models-prod-ap-south-1", ), JumpStartLaunchedRegionInfo( region_name="eu-west-3", content_bucket="jumpstart-cache-prod-eu-west-3", gated_content_bucket="jumpstart-private-cache-prod-eu-west-3", + neo_content_bucket="sagemaker-sd-models-prod-eu-west-3", ), JumpStartLaunchedRegionInfo( region_name="af-south-1", @@ -99,6 +107,7 @@ region_name="sa-east-1", content_bucket="jumpstart-cache-prod-sa-east-1", gated_content_bucket="jumpstart-private-cache-prod-sa-east-1", + neo_content_bucket="sagemaker-sd-models-prod-sa-east-1", ), JumpStartLaunchedRegionInfo( region_name="ap-east-1", @@ -109,21 +118,25 @@ region_name="ap-northeast-2", content_bucket="jumpstart-cache-prod-ap-northeast-2", gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-2", + neo_content_bucket="sagemaker-sd-models-prod-ap-northeast-2", ), JumpStartLaunchedRegionInfo( region_name="ap-northeast-3", content_bucket="jumpstart-cache-prod-ap-northeast-3", gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-3", + neo_content_bucket="sagemaker-sd-models-prod-ap-northeast-3", ), JumpStartLaunchedRegionInfo( region_name="ap-southeast-3", content_bucket="jumpstart-cache-prod-ap-southeast-3", gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-3", + neo_content_bucket="sagemaker-sd-models-prod-ap-southeast-3", ), JumpStartLaunchedRegionInfo( region_name="eu-west-2", content_bucket="jumpstart-cache-prod-eu-west-2", gated_content_bucket="jumpstart-private-cache-prod-eu-west-2", + neo_content_bucket="sagemaker-sd-models-prod-eu-west-2", ), JumpStartLaunchedRegionInfo( region_name="eu-south-1", @@ -134,26 +147,31 @@ region_name="ap-northeast-1", content_bucket="jumpstart-cache-prod-ap-northeast-1", gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-1", + neo_content_bucket="sagemaker-sd-models-prod-ap-northeast-1", ), 
JumpStartLaunchedRegionInfo( region_name="us-west-1", content_bucket="jumpstart-cache-prod-us-west-1", gated_content_bucket="jumpstart-private-cache-prod-us-west-1", + neo_content_bucket="sagemaker-sd-models-prod-us-west-1", ), JumpStartLaunchedRegionInfo( region_name="ap-southeast-1", content_bucket="jumpstart-cache-prod-ap-southeast-1", gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-1", + neo_content_bucket="sagemaker-sd-models-prod-ap-southeast-1", ), JumpStartLaunchedRegionInfo( region_name="ap-southeast-2", content_bucket="jumpstart-cache-prod-ap-southeast-2", gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-2", + neo_content_bucket="sagemaker-sd-models-prod-ap-southeast-2", ), JumpStartLaunchedRegionInfo( region_name="ca-central-1", content_bucket="jumpstart-cache-prod-ca-central-1", gated_content_bucket="jumpstart-private-cache-prod-ca-central-1", + neo_content_bucket="sagemaker-sd-models-prod-ca-central-1", ), JumpStartLaunchedRegionInfo( region_name="cn-north-1", @@ -184,6 +202,7 @@ ) JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2" +NEO_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2" JUMPSTART_MODEL_HUB_NAME = "SageMakerPublicHub" @@ -206,6 +225,7 @@ "AWS_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE" ) ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE = "AWS_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE" +ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE = "AWS_NEO_CONTENT_BUCKET_OVERRIDE" JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart" diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index 6c77e72b9b..a83964e394 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -94,6 +94,9 @@ class JumpStartTag(str, Enum): MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type" + INFERENCE_CONFIG_NAME =
"sagemaker-sdk:jumpstart-inference-config-name" + TRAINING_CONFIG_NAME = "sagemaker-sdk:jumpstart-training-config-name" + HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn" diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 7d600ddfbc..8b30317a52 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -34,8 +34,10 @@ from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs from sagemaker.jumpstart.factory.model import get_default_predictor -from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job +from sagemaker.jumpstart.session_utils import get_model_info_from_training_job +from sagemaker.jumpstart.types import JumpStartMetadataConfig from sagemaker.jumpstart.utils import ( + get_jumpstart_configs, validate_model_id_and_get_type, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, @@ -111,6 +113,7 @@ def __init__( container_arguments: Optional[List[str]] = None, disable_output_compression: Optional[bool] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + config_name: Optional[str] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, ): """Initializes a ``JumpStartEstimator``. @@ -504,6 +507,8 @@ def __init__( to Amazon S3 without compression after training finishes. enable_remote_debug (bool or PipelineVariable): Optional. Specifies whether RemoteDebug is enabled for the training job + config_name (Optional[str]): + Name of the training configuration to apply to the Estimator. (Default: None). enable_session_tag_chaining (bool or PipelineVariable): Optional. 
Specifies whether SessionTagChaining is enabled for the training job @@ -592,6 +597,7 @@ def _validate_model_id_and_get_type_hook(): disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, + config_name=config_name, enable_session_tag_chaining=enable_session_tag_chaining, ) @@ -607,6 +613,8 @@ def _validate_model_id_and_get_type_hook(): self.role = estimator_init_kwargs.role self.sagemaker_session = estimator_init_kwargs.sagemaker_session self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation + self.config_name = estimator_init_kwargs.config_name + self.init_kwargs = estimator_init_kwargs.to_kwargs_dict(False) super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict()) @@ -682,6 +690,7 @@ def fit( tolerate_vulnerable_model=self.tolerate_vulnerable_model, tolerate_deprecated_model=self.tolerate_deprecated_model, sagemaker_session=self.sagemaker_session, + config_name=self.config_name, ) return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict()) @@ -695,6 +704,7 @@ def attach( hub_arn: Optional[str] = None, sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_channel_name: str = "model", + config_name: Optional[str] = None, ) -> "JumpStartEstimator": """Attach to an existing training job. @@ -730,6 +740,8 @@ def attach( model data will be downloaded (default: 'model'). If no channel with the same name exists in the training job, this option will be ignored. + config_name (str): Optional. Name of the training configuration to use + when attaching to the training job. (Default: None). Returns: Instance of the calling ``JumpStartEstimator`` Class with the attached @@ -739,10 +751,9 @@ def attach( ValueError: if the model ID or version cannot be inferred from the training job. 
""" if model_id is None: - - model_id, model_version = get_model_id_version_from_training_job( + model_id, model_version, _, config_name = get_model_info_from_training_job( training_job_name=training_job_name, sagemaker_session=sagemaker_session ) @@ -755,6 +766,9 @@ "tolerate_deprecated_model": True, # model is already trained } + if config_name: + additional_kwargs.update({"config_name": config_name}) + model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, @@ -764,6 +778,7 @@ tolerate_deprecated_model=True, # model is already trained, so tolerate if deprecated tolerate_vulnerable_model=True, # model is already trained, so tolerate if vulnerable sagemaker_session=sagemaker_session, + config_name=config_name, ) # eula was already accepted if the model was successfully trained @@ -813,6 +828,7 @@ def deploy( dependencies: Optional[List[str]] = None, git_config: Optional[Dict[str, str]] = None, use_compiled_model: bool = False, + inference_config_name: Optional[str] = None, ) -> PredictorBase: """Creates endpoint from training job. @@ -1048,6 +1064,8 @@ (Default: None). use_compiled_model (bool): Flag to select whether to use compiled (optimized) model. (Default: False). + inference_config_name (Optional[str]): Name of the inference configuration to + be used in the model. (Default: None).
""" self.orig_predictor_cls = predictor_cls @@ -1101,6 +1119,8 @@ def deploy( git_config=git_config, use_compiled_model=use_compiled_model, training_instance_type=self.instance_type, + training_config_name=self.config_name, + inference_config_name=inference_config_name, ) predictor = super(JumpStartEstimator, self).deploy( @@ -1118,11 +1138,43 @@ def deploy( tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, + config_name=estimator_deploy_kwargs.config_name, ) # If a predictor class was passed, do not mutate predictor return predictor + def list_training_configs(self) -> List[JumpStartMetadataConfig]: + """Returns a list of configs associated with the estimator. + + Raises: + ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. + """ + configs_dict = get_jumpstart_configs( + model_id=self.model_id, + model_version=self.model_version, + model_type=self.model_type, + region=self.region, + scope=JumpStartScriptScope.TRAINING, + sagemaker_session=self.sagemaker_session, + ) + return list(configs_dict.values()) + + def set_training_config(self, config_name: str) -> None: + """Sets the config to apply to the model. + + Args: + config_name (str): The name of the config. 
+ """ + self.__init__( + model_id=self.model_id, + model_version=self.model_version, + config_name=config_name, + ) + def __str__(self) -> str: """Overriding str(*) method to make more human-readable.""" return stringify_object(self) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index d3e597c395..8540f53ca4 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -61,7 +61,7 @@ ) from sagemaker.jumpstart.utils import ( add_hub_content_arn_tags, - add_jumpstart_model_id_version_tags, + add_jumpstart_model_info_tags, get_eula_message, get_default_jumpstart_session_with_user_agent_suffix, update_dict_if_key_not_present, @@ -132,6 +132,7 @@ def get_init_kwargs( disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + config_name: Optional[str] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, ) -> JumpStartEstimatorInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object.""" @@ -192,6 +193,7 @@ def get_init_kwargs( disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, + config_name=config_name, enable_session_tag_chaining=enable_session_tag_chaining, ) @@ -210,6 +212,7 @@ def get_init_kwargs( estimator_init_kwargs = _add_role_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_env_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_tags_to_kwargs(estimator_init_kwargs) + estimator_init_kwargs = _add_config_name_to_kwargs(estimator_init_kwargs) return estimator_init_kwargs @@ -227,6 +230,7 @@ def get_fit_kwargs( tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, + 
config_name: Optional[str] = None, ) -> JumpStartEstimatorFitKwargs: """Returns kwargs required call `fit` on `sagemaker.estimator.Estimator` object.""" @@ -243,6 +247,7 @@ def get_fit_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) estimator_fit_kwargs = _add_model_version_to_kwargs(estimator_fit_kwargs) @@ -296,6 +301,8 @@ def get_deploy_kwargs( use_compiled_model: Optional[bool] = None, model_name: Optional[str] = None, training_instance_type: Optional[str] = None, + training_config_name: Optional[str] = None, + inference_config_name: Optional[str] = None, ) -> JumpStartEstimatorDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object.""" @@ -325,6 +332,8 @@ def get_deploy_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + training_config_name=training_config_name, + config_name=inference_config_name, ) model_init_kwargs: JumpStartModelInitKwargs = model.get_init_kwargs( @@ -359,6 +368,7 @@ def get_deploy_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, training_instance_type=training_instance_type, disable_instance_type_logging=True, + config_name=model_deploy_kwargs.config_name, ) estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs( @@ -404,6 +414,7 @@ def get_deploy_kwargs( tolerate_vulnerable_model=model_init_kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=model_init_kwargs.tolerate_deprecated_model, use_compiled_model=use_compiled_model, + config_name=model_deploy_kwargs.config_name, ) return estimator_deploy_kwargs @@ -479,6 +490,7 @@ def _add_instance_type_and_count_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, 
sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) kwargs.instance_count = kwargs.instance_count or 1 @@ -503,11 +515,16 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ).version if kwargs.sagemaker_session.settings.include_jumpstart_tags: - kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version + kwargs.tags = add_jumpstart_model_info_tags( + kwargs.tags, + kwargs.model_id, + full_model_version, + config_name=kwargs.config_name, + scope=JumpStartScriptScope.TRAINING, ) if kwargs.hub_arn: @@ -530,6 +547,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) return kwargs @@ -556,6 +574,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE sagemaker_session=kwargs.sagemaker_session, region=kwargs.region, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) if ( @@ -569,6 +588,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) ): JUMPSTART_LOGGER.warning( @@ -605,6 +625,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, region=kwargs.region, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) return kwargs @@ -626,6 +647,7 @@ def 
_add_env_to_kwargs( sagemaker_session=kwargs.sagemaker_session, script=JumpStartScriptScope.TRAINING, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri( @@ -637,6 +659,7 @@ def _add_env_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) if model_package_artifact_uri: @@ -665,6 +688,7 @@ def _add_env_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) if model_specs.is_gated_model(): raise ValueError( @@ -695,9 +719,11 @@ def _add_training_job_name_to_kwargs( model_version=kwargs.model_version, hub_arn=kwargs.hub_arn, region=kwargs.region, + scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) kwargs.job_name = kwargs.job_name or ( @@ -725,6 +751,7 @@ def _add_hyperparameters_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) for key, value in default_hyperparameters.items(): @@ -759,6 +786,7 @@ def _add_metric_definitions_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) or [] ) @@ -789,6 +817,7 @@ def _add_estimator_extra_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + 
config_name=kwargs.config_name, ) for key, value in estimator_kwargs_to_add.items(): @@ -814,6 +843,7 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) for key, value in fit_kwargs_to_add.items(): @@ -821,3 +851,27 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim setattr(kwargs, key, value) return kwargs + + +def _add_config_name_to_kwargs( + kwargs: JumpStartEstimatorInitKwargs, +) -> JumpStartEstimatorInitKwargs: + """Sets tags in kwargs based on default or override, returns full kwargs.""" + + specs = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + scope=JumpStartScriptScope.TRAINING, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, + ) + + if specs.training_configs and specs.training_configs.get_top_config_from_ranking(): + kwargs.config_name = ( + kwargs.config_name or specs.training_configs.get_top_config_from_ranking().config_name + ) + + return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index a231ab917c..61fcff242f 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -44,11 +44,13 @@ JumpStartModelDeployKwargs, JumpStartModelInitKwargs, JumpStartModelRegisterKwargs, + JumpStartModelSpecs, ) from sagemaker.jumpstart.utils import ( add_hub_content_arn_tags, - add_jumpstart_model_id_version_tags, + add_jumpstart_model_info_tags, get_default_jumpstart_session_with_user_agent_suffix, + get_neo_content_bucket, update_dict_if_key_not_present, 
resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, @@ -60,7 +62,7 @@ from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig from sagemaker.session import Session -from sagemaker.utils import name_from_base, format_tags, Tags +from sagemaker.utils import camel_case_to_pascal_case, name_from_base, format_tags, Tags from sagemaker.workflow.entities import PipelineVariable from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker import resource_requirements @@ -77,6 +79,7 @@ def get_default_predictor( tolerate_deprecated_model: bool, sagemaker_session: Session, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Predictor: """Converts predictor returned from ``Model.deploy()`` into a JumpStart-specific one. @@ -100,6 +103,7 @@ def get_default_predictor( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) predictor.deserializer = deserializers.retrieve_default( model_id=model_id, @@ -110,6 +114,7 @@ def get_default_predictor( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) predictor.accept = accept_types.retrieve_default( model_id=model_id, @@ -120,6 +125,7 @@ def get_default_predictor( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) predictor.content_type = content_types.retrieve_default( model_id=model_id, @@ -130,6 +136,7 @@ def get_default_predictor( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) return predictor @@ -212,7 +219,6 @@ def _add_instance_type_to_kwargs( """Sets instance type based on default or override, returns full 
kwargs.""" orig_instance_type = kwargs.instance_type - kwargs.instance_type = kwargs.instance_type or instance_types.retrieve_default( region=kwargs.region, model_id=kwargs.model_id, @@ -224,6 +230,7 @@ def _add_instance_type_to_kwargs( sagemaker_session=kwargs.sagemaker_session, training_instance_type=kwargs.training_instance_type, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) if not disable_instance_type_logging and orig_instance_type is None: @@ -256,6 +263,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) return kwargs @@ -304,6 +312,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"): @@ -345,6 +354,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ): source_dir = source_dir or script_uris.retrieve( script_scope=JumpStartScriptScope.INFERENCE, @@ -355,6 +365,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) kwargs.source_dir = source_dir @@ -379,6 +390,7 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod 
tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ): entry_point = entry_point or INFERENCE_ENTRY_POINT_SCRIPT_NAME @@ -411,6 +423,7 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw sagemaker_session=kwargs.sagemaker_session, script=JumpStartScriptScope.INFERENCE, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) for key, value in extra_env_vars.items(): @@ -442,6 +455,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) kwargs.model_package_arn = model_package_arn @@ -460,6 +474,7 @@ def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelI tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) for key, value in model_kwargs_to_add.items(): @@ -497,6 +512,7 @@ def _add_endpoint_name_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) kwargs.endpoint_name = kwargs.endpoint_name or ( @@ -520,6 +536,7 @@ def _add_model_name_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) kwargs.name = kwargs.name or ( @@ -542,11 +559,17 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ).version 
if kwargs.sagemaker_session.settings.include_jumpstart_tags: - kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type + kwargs.tags = add_jumpstart_model_info_tags( + kwargs.tags, + kwargs.model_id, + full_model_version, + kwargs.model_type, + config_name=kwargs.config_name, + scope=JumpStartScriptScope.INFERENCE, ) if kwargs.hub_arn: @@ -568,6 +591,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any] tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) for key, value in deploy_kwargs_to_add.items(): @@ -591,11 +615,142 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, + ) + + return kwargs + + +def _select_inference_config_from_training_config( + specs: JumpStartModelSpecs, training_config_name: str +) -> Optional[str]: + """Selects the inference config from the training config. + + Args: + specs (JumpStartModelSpecs): The specs for the model. + training_config_name (str): The name of the training config. + + Returns: + str: The name of the inference config. + """ + if specs.training_configs: + resolved_training_config = specs.training_configs.configs.get(training_config_name) + if resolved_training_config: + return resolved_training_config.default_inference_config + + return None + + +def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: + """Sets default config name to the kwargs. Returns full kwargs. + + Raises: + ValueError: If the instance_type is not supported with the current config. 
+ """ + + specs = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + config_name=kwargs.config_name, + ) + if specs.inference_configs: + default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name + kwargs.config_name = kwargs.config_name or default_config_name + + if not kwargs.config_name: + return kwargs + + if kwargs.config_name not in set(specs.inference_configs.configs.keys()): + raise ValueError( + f"Config {kwargs.config_name} is not supported for model {kwargs.model_id}." + ) + + resolved_config = specs.inference_configs.configs[kwargs.config_name].resolved_config + supported_instance_types = resolved_config.get("supported_inference_instance_types", []) + if kwargs.instance_type not in supported_instance_types: + JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type) + return kwargs + + +def _add_additional_model_data_sources_to_kwargs( + kwargs: JumpStartModelInitKwargs, +) -> JumpStartModelInitKwargs: + """Sets default additional model data sources to init kwargs""" + + specs = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + config_name=kwargs.config_name, + ) + # Append speculative decoding data source from metadata + speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources() + for data_source in speculative_decoding_data_sources: + 
data_source.s3_data_source.set_bucket(get_neo_content_bucket(region=kwargs.region)) + api_shape_additional_model_data_sources = ( + [ + camel_case_to_pascal_case(data_source.to_json()) + for data_source in speculative_decoding_data_sources + ] + if specs.get_speculative_decoding_s3_data_sources() + else None + ) + + kwargs.additional_model_data_sources = ( + kwargs.additional_model_data_sources or api_shape_additional_model_data_sources ) return kwargs +def _add_config_name_to_deploy_kwargs( + kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None +) -> JumpStartModelInitKwargs: + """Sets default config name to the kwargs. Returns full kwargs. + + If a training_config_name is passed, then choose the inference config + based on the supported inference configs in that training config. + + Raises: + ValueError: If the instance_type is not supported with the current config. + """ + + specs = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + config_name=kwargs.config_name, + ) + + if training_config_name: + kwargs.config_name = _select_inference_config_from_training_config( + specs=specs, training_config_name=training_config_name + ) + return kwargs + + if specs.inference_configs: + default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name + kwargs.config_name = kwargs.config_name or default_config_name + + return kwargs + + def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, @@ -629,6 +784,8 @@ def get_deploy_kwargs( resources: Optional[ResourceRequirements] = None, managed_instance_scaling: Optional[str] = None, endpoint_type: Optional[EndpointType] = None, + training_config_name: 
Optional[str] = None, + config_name: Optional[str] = None, routing_config: Optional[Dict[str, Any]] = None, ) -> JumpStartModelDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.""" @@ -664,6 +821,7 @@ def get_deploy_kwargs( model_reference_arn=model_reference_arn, endpoint_logging=endpoint_logging, resources=resources, + config_name=config_name, routing_config=routing_config, ) @@ -673,6 +831,10 @@ def get_deploy_kwargs( deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs) + deploy_kwargs = _add_config_name_to_deploy_kwargs( + kwargs=deploy_kwargs, training_config_name=training_config_name + ) + deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs) deploy_kwargs.initial_instance_count = initial_instance_count or 1 @@ -720,6 +882,7 @@ def get_register_kwargs( data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, source_uri: Optional[str] = None, + config_name: Optional[str] = None, model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None, accept_eula: Optional[bool] = None, ) -> JumpStartModelRegisterKwargs: @@ -770,6 +933,7 @@ def get_register_kwargs( sagemaker_session=sagemaker_session, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + config_name=config_name, ) register_kwargs.content_types = ( @@ -813,6 +977,8 @@ def get_init_kwargs( training_instance_type: Optional[str] = None, disable_instance_type_logging: bool = False, resources: Optional[ResourceRequirements] = None, + config_name: Optional[str] = None, + additional_model_data_sources: Optional[Dict[str, Any]] = None, ) -> JumpStartModelInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Model` object.""" @@ -845,6 +1011,8 @@ def get_init_kwargs( model_package_arn=model_package_arn, training_instance_type=training_instance_type, resources=resources, + config_name=config_name, + 
additional_model_data_sources=additional_model_data_sources, ) model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs) @@ -880,4 +1048,8 @@ def get_init_kwargs( model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) + + model_init_kwargs = _add_additional_model_data_sources_to_kwargs(kwargs=model_init_kwargs) + return model_init_kwargs diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py index 2748409927..d987216872 100644 --- a/src/sagemaker/jumpstart/hub/interfaces.py +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -13,14 +13,20 @@ """This module stores types related to SageMaker JumpStart HubAPI requests and responses.""" from __future__ import absolute_import +from enum import Enum import re import json import datetime from typing import Any, Dict, List, Union, Optional +from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.types import ( HubContentType, HubArnExtractedInfo, + JumpStartConfigComponent, + JumpStartConfigRanking, + JumpStartMetadataConfig, + JumpStartMetadataConfigs, JumpStartPredictorSpecs, JumpStartHyperparameter, JumpStartDataHolderType, @@ -34,6 +40,13 @@ ) +class _ComponentType(str, Enum): + """Enum for different component types.""" + + INFERENCE = "Inference" + TRAINING = "Training" + + class HubDataHolderType(JumpStartDataHolderType): """Base class for many Hub API interfaces.""" @@ -456,6 +469,9 @@ class HubModelDocument(HubDataHolderType): "hosting_use_script_uri", "hosting_eula_uri", "hosting_model_package_arn", + "inference_configs", + "inference_config_components", + "inference_config_rankings", "training_artifact_s3_data_type", "training_artifact_compression_type", "training_model_package_artifact_uri", @@ -467,6 +483,9 @@ class HubModelDocument(HubDataHolderType): "training_ecr_uri", "training_metrics", 
"training_artifact_uri", + "training_configs", + "training_config_components", + "training_config_rankings", "inference_dependencies", "training_dependencies", "default_inference_instance_type", @@ -566,6 +585,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: ) self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri") self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn") + + self.inference_config_rankings = self._get_config_rankings(json_obj) + self.inference_config_components = self._get_config_components(json_obj) + self.inference_configs = self._get_configs(json_obj) + self.default_inference_instance_type: Optional[str] = json_obj.get( "DefaultInferenceInstanceType" ) @@ -667,6 +691,15 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: "TrainingMetrics", None ) self.training_artifact_uri: Optional[str] = json_obj.get("TrainingArtifactUri") + + self.training_config_rankings = self._get_config_rankings( + json_obj, _ComponentType.TRAINING + ) + self.training_config_components = self._get_config_components( + json_obj, _ComponentType.TRAINING + ) + self.training_configs = self._get_configs(json_obj, _ComponentType.TRAINING) + self.training_dependencies: Optional[str] = json_obj.get("TrainingDependencies") self.default_training_instance_type: Optional[str] = json_obj.get( "DefaultTrainingInstanceType" @@ -707,6 +740,64 @@ def get_region(self) -> str: """Returns hub region.""" return self._region + def _get_config_rankings( + self, json_obj: Dict[str, Any], component_type=_ComponentType.INFERENCE + ) -> Optional[Dict[str, JumpStartConfigRanking]]: + """Returns config rankings.""" + config_rankings = json_obj.get(f"{component_type.value}ConfigRankings") + return ( + { + alias: JumpStartConfigRanking(ranking, is_hub_content=True) + for alias, ranking in config_rankings.items() + } + if config_rankings + else None + ) + + def _get_config_components( + self, json_obj: Dict[str, Any], 
component_type=_ComponentType.INFERENCE + ) -> Optional[Dict[str, JumpStartConfigComponent]]: + """Returns config components.""" + config_components = json_obj.get(f"{component_type.value}ConfigComponents") + return ( + { + alias: JumpStartConfigComponent(alias, config, is_hub_content=True) + for alias, config in config_components.items() + } + if config_components + else None + ) + + def _get_configs( + self, json_obj: Dict[str, Any], component_type=_ComponentType.INFERENCE + ) -> Optional[JumpStartMetadataConfigs]: + """Returns configs.""" + if not (configs := json_obj.get(f"{component_type.value}Configs")): + return None + + configs_dict = {} + for alias, config in configs.items(): + config_components = None + if isinstance(config, dict) and (component_names := config.get("ComponentNames")): + config_components = { + name: getattr(self, f"{component_type.value.lower()}_config_components").get( + name + ) + for name in component_names + } + configs_dict[alias] = JumpStartMetadataConfig( + alias, config, json_obj, config_components, is_hub_content=True + ) + + if component_type == _ComponentType.INFERENCE: + config_rankings = self.inference_config_rankings + scope = JumpStartScriptScope.INFERENCE + else: + config_rankings = self.training_config_rankings + scope = JumpStartScriptScope.TRAINING + + return JumpStartMetadataConfigs(configs_dict, config_rankings, scope) + class HubNotebookDocument(HubDataHolderType): """Data class for notebook type HubContentDocument from session.describe_hub_content().""" diff --git a/src/sagemaker/jumpstart/hub/parsers.py b/src/sagemaker/jumpstart/hub/parsers.py index 8226a380fd..28c2d9b32d 100644 --- a/src/sagemaker/jumpstart/hub/parsers.py +++ b/src/sagemaker/jumpstart/hub/parsers.py @@ -142,6 +142,9 @@ def make_model_specs_from_describe_hub_content_response( hub_model_document.incremental_training_supported ) specs["hosting_ecr_uri"] = hub_model_document.hosting_ecr_uri + specs["inference_configs"] = 
hub_model_document.inference_configs + specs["inference_config_components"] = hub_model_document.inference_config_components + specs["inference_config_rankings"] = hub_model_document.inference_config_rankings hosting_artifact_bucket, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable hub_model_document.hosting_artifact_uri @@ -233,6 +236,11 @@ def make_model_specs_from_describe_hub_content_response( training_script_key, ) = parse_s3_url(hub_model_document.training_script_uri) specs["training_script_key"] = training_script_key + + specs["training_configs"] = hub_model_document.training_configs + specs["training_config_components"] = hub_model_document.training_config_components + specs["training_config_rankings"] = hub_model_document.training_config_rankings + specs["training_dependencies"] = hub_model_document.training_dependencies specs["default_training_instance_type"] = hub_model_document.default_training_instance_type specs["supported_training_instance_types"] = ( diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index b482d4fefd..15cfea5c86 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -14,7 +14,8 @@ from __future__ import absolute_import -from typing import Dict, List, Optional, Union, Any +from typing import Dict, List, Optional, Any, Union +import pandas as pd from botocore.exceptions import ClientError from sagemaker import payloads @@ -37,11 +38,19 @@ get_init_kwargs, get_register_kwargs, ) -from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint -from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.jumpstart.session_utils import get_model_info_from_endpoint +from sagemaker.jumpstart.types import ( + JumpStartSerializablePayload, + DeploymentConfigMetadata, +) from sagemaker.jumpstart.utils import ( validate_model_id_and_get_type, verify_model_region_and_return_specs, + get_jumpstart_configs, + 
get_metrics_from_deployment_configs, + add_instance_rate_stats_to_benchmark_metrics, + deployment_config_response_data, + _deployment_config_lru_cache, ) from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER from sagemaker.jumpstart.enums import JumpStartModelType @@ -99,6 +108,8 @@ def __init__( git_config: Optional[Dict[str, str]] = None, model_package_arn: Optional[str] = None, resources: Optional[ResourceRequirements] = None, + config_name: Optional[str] = None, + additional_model_data_sources: Optional[Dict[str, Any]] = None, ): """Initializes a ``JumpStartModel``. @@ -285,6 +296,10 @@ def __init__( for a model to be deployed to an endpoint. Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature. (Default: None). + config_name (Optional[str]): The name of the JumpStart config that can be + optionally applied to the model. + additional_model_data_sources (Optional[Dict[str, Any]]): Additional location + of SageMaker model data (default: None). Raises: ValueError: If the model ID is not recognized by JumpStart. 
""" @@ -342,6 +357,8 @@ def _validate_model_id_and_type(): git_config=git_config, model_package_arn=model_package_arn, resources=resources, + config_name=config_name, + additional_model_data_sources=additional_model_data_sources, ) self.orig_predictor_cls = predictor_cls @@ -355,6 +372,9 @@ def _validate_model_id_and_type(): self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model self.region = model_init_kwargs.region self.sagemaker_session = model_init_kwargs.sagemaker_session + self.role = role + self.config_name = model_init_kwargs.config_name + self.additional_model_data_sources = model_init_kwargs.additional_model_data_sources self.model_reference_arn = model_init_kwargs.model_reference_arn if self.model_type == JumpStartModelType.PROPRIETARY: @@ -365,6 +385,15 @@ def _validate_model_id_and_type(): super(JumpStartModel, self).__init__(**model_init_kwargs_dict) self.model_package_arn = model_init_kwargs.model_package_arn + self.init_kwargs = model_init_kwargs.to_kwargs_dict(False) + + self._metadata_configs = get_jumpstart_configs( + region=self.region, + model_id=self.model_id, + model_version=self.model_version, + sagemaker_session=self.sagemaker_session, + model_type=self.model_type, + ) def log_subscription_warning(self) -> None: """Log message prompting the customer to subscribe to the proprietary model.""" @@ -424,6 +453,72 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload: sagemaker_session=self.sagemaker_session, ) + def set_deployment_config(self, config_name: str, instance_type: str) -> None: + """Sets the deployment config to apply to the model. + + Args: + config_name (str): + The name of the deployment config to apply to the model. + Call list_deployment_configs to see the list of config names. + instance_type (str): + The instance_type that the model will use after setting + the config. 
+ """ + self.__init__( + model_id=self.model_id, + model_version=self.model_version, + instance_type=instance_type, + config_name=config_name, + sagemaker_session=self.sagemaker_session, + role=self.role, + ) + + @property + def deployment_config(self) -> Optional[Dict[str, Any]]: + """The deployment config that will be applied to ``This`` model. + + Returns: + Optional[Dict[str, Any]]: Deployment config. + """ + if self.config_name is None: + return None + for config in self.list_deployment_configs(): + if config.get("DeploymentConfigName") == self.config_name: + return config + return None + + @property + def benchmark_metrics(self) -> pd.DataFrame: + """Benchmark Metrics for deployment configs. + + Returns: + Benchmark Metrics: Pandas DataFrame object. + """ + df = pd.DataFrame(self._get_deployment_configs_benchmarks_data()) + blank_index = [""] * len(df) + df.index = blank_index + return df + + def display_benchmark_metrics(self, **kwargs) -> None: + """Display deployment configs benchmark metrics.""" + df = self.benchmark_metrics + + instance_type = kwargs.get("instance_type") + if instance_type: + df = df[df["Instance Type"].str.contains(instance_type)] + + print(df.to_markdown(index=False, floatfmt=".2f")) + + def list_deployment_configs(self) -> List[Dict[str, Any]]: + """List deployment configs for ``This`` model. + + Returns: + List[Dict[str, Any]]: A list of deployment configs. 
+ """ + return deployment_config_response_data( + self._get_deployment_configs(self.config_name, self.instance_type) + ) + @classmethod def attach( cls, @@ -441,8 +536,8 @@ def attach( inferred_model_id = inferred_model_version = inferred_inference_component_name = None if inference_component_name is None or model_id is None or model_version is None: - inferred_model_id, inferred_model_version, inferred_inference_component_name = ( - get_model_id_version_from_endpoint( + inferred_model_id, inferred_model_version, inferred_inference_component_name, _, _ = ( + get_model_info_from_endpoint( endpoint_name=endpoint_name, inference_component_name=inference_component_name, sagemaker_session=sagemaker_session, @@ -693,6 +788,7 @@ def deploy( managed_instance_scaling=managed_instance_scaling, endpoint_type=endpoint_type, model_type=self.model_type, + config_name=self.config_name, routing_config=routing_config, ) if ( @@ -713,6 +809,7 @@ def deploy( model_type=self.model_type, scope=JumpStartScriptScope.INFERENCE, sagemaker_session=self.sagemaker_session, + config_name=self.config_name, hub_arn=self.hub_arn, ).model_subscription_link get_proprietary_model_subscription_error(e, subscription_link) @@ -730,6 +827,7 @@ def deploy( tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, model_type=self.model_type, + config_name=self.config_name, ) # If a predictor class was passed, do not mutate predictor @@ -855,6 +953,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + config_name=self.config_name, model_card=model_card, accept_eula=accept_eula, ) @@ -874,6 +973,89 @@ def register_deploy_wrapper(*args, **kwargs): return model_package + @_deployment_config_lru_cache + def _get_deployment_configs_benchmarks_data(self) -> Dict[str, Any]: + """Deployment configs benchmark metrics. + + Returns: + Dict[str, List[str]]: Deployment config benchmark data. 
+ """ + return get_metrics_from_deployment_configs( + self._get_deployment_configs(None, None), + ) + + @_deployment_config_lru_cache + def _get_deployment_configs( + self, selected_config_name: Optional[str], selected_instance_type: Optional[str] + ) -> List[DeploymentConfigMetadata]: + """Retrieve deployment configs metadata. + + Args: + selected_config_name (Optional[str]): The name of the selected deployment config. + selected_instance_type (Optional[str]): The selected instance type. + """ + deployment_configs = [] + if not self._metadata_configs: + return deployment_configs + + err = None + for config_name, metadata_config in self._metadata_configs.items(): + if selected_config_name == config_name: + instance_type_to_use = selected_instance_type + else: + instance_type_to_use = metadata_config.resolved_config.get( + "default_inference_instance_type" + ) + + if metadata_config.benchmark_metrics: + err, metadata_config.benchmark_metrics = ( + add_instance_rate_stats_to_benchmark_metrics( + self.region, metadata_config.benchmark_metrics + ) + ) + + config_components = metadata_config.config_components.get(config_name) + image_uri = ( + ( + config_components.hosting_instance_type_variants.get("regional_aliases", {}) + .get(self.region, {}) + .get("alias_ecr_uri_1") + ) + if config_components + else self.image_uri + ) + + init_kwargs = get_init_kwargs( + config_name=config_name, + model_id=self.model_id, + instance_type=instance_type_to_use, + sagemaker_session=self.sagemaker_session, + image_uri=image_uri, + region=self.region, + model_version=self.model_version, + ) + deploy_kwargs = get_deploy_kwargs( + model_id=self.model_id, + instance_type=instance_type_to_use, + sagemaker_session=self.sagemaker_session, + region=self.region, + model_version=self.model_version, + ) + + deployment_config_metadata = DeploymentConfigMetadata( + config_name, + metadata_config, + init_kwargs, + deploy_kwargs, + ) + deployment_configs.append(deployment_config_metadata) + + if err 
and err["Code"] == "AccessDeniedException": + error_message = "Instance rate metrics will be omitted. Reason: %s" + JUMPSTART_LOGGER.warning(error_message, err["Message"]) + + return deployment_configs + def __str__(self) -> str: """Overriding str(*) method to make more human-readable.""" return stringify_object(self) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 83613cd71b..781548b42a 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -535,6 +535,7 @@ def get_model_url( model_version: str, region: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieve web url describing pretrained model. @@ -563,5 +564,6 @@ def get_model_url( sagemaker_session=sagemaker_session, scope=JumpStartScriptScope.INFERENCE, model_type=model_type, + config_name=config_name, ) return model_specs.url diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py index e511a052d1..0955ae9480 100644 --- a/src/sagemaker/jumpstart/session_utils.py +++ b/src/sagemaker/jumpstart/session_utils.py @@ -17,17 +17,17 @@ from typing import Optional, Tuple from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn +from sagemaker.jumpstart.utils import get_jumpstart_model_info_from_resource_arn from sagemaker.session import Session from sagemaker.utils import aws_partition -def get_model_id_version_from_endpoint( +def get_model_info_from_endpoint( endpoint_name: str, inference_component_name: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[str, str, Optional[str]]: - """Given an endpoint and optionally inference component names, return the model ID and version. 
+) -> Tuple[str, str, Optional[str], Optional[str], Optional[str]]: + """Optionally inference component names, return the model ID, version and config name. Infers the model ID and version based on the resource tags. Returns a tuple of the model ID and version. A third string element is included in the tuple for any inferred inference @@ -46,7 +46,9 @@ def get_model_id_version_from_endpoint( ( model_id, model_version, - ) = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301 + inference_config_name, + training_config_name, + ) = _get_model_info_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301 inference_component_name, sagemaker_session ) @@ -54,22 +56,35 @@ def get_model_id_version_from_endpoint( ( model_id, model_version, + inference_config_name, + training_config_name, inference_component_name, - ) = _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301 + ) = _get_model_info_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301 endpoint_name, sagemaker_session ) else: - model_id, model_version = _get_model_id_version_from_model_based_endpoint( + ( + model_id, + model_version, + inference_config_name, + training_config_name, + ) = _get_model_info_from_model_based_endpoint( endpoint_name, inference_component_name, sagemaker_session ) - return model_id, model_version, inference_component_name + return ( + model_id, + model_version, + inference_component_name, + inference_config_name, + training_config_name, + ) -def _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( +def _get_model_info_from_inference_component_endpoint_without_inference_component_name( endpoint_name: str, sagemaker_session: Session -) -> Tuple[str, str, str]: - """Given an endpoint name, derives the model ID, version, 
and inferred inference component name. +) -> Tuple[str, str, str, str]: + """Derives the model ID, version, config name and inferred inference component name. This function assumes the endpoint corresponds to an inference-component-based endpoint. An endpoint is inference-component-based if and only if the associated endpoint config @@ -98,14 +113,14 @@ def _get_model_id_version_from_inference_component_endpoint_without_inference_co ) inference_component_name = inference_component_names[0] return ( - *_get_model_id_version_from_inference_component_endpoint_with_inference_component_name( + *_get_model_info_from_inference_component_endpoint_with_inference_component_name( inference_component_name, sagemaker_session ), inference_component_name, ) -def _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( +def _get_model_info_from_inference_component_endpoint_with_inference_component_name( inference_component_name: str, sagemaker_session: Session ): """Returns the model ID and version inferred from a SageMaker inference component. @@ -123,9 +138,12 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo f"inference-component/{inference_component_name}" ) - model_id, model_version = get_jumpstart_model_id_version_from_resource_arn( - inference_component_arn, sagemaker_session - ) + ( + model_id, + model_version, + inference_config_name, + training_config_name, + ) = get_jumpstart_model_info_from_resource_arn(inference_component_arn, sagemaker_session) if not model_id: raise ValueError( @@ -134,15 +152,15 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo "when retrieving default predictor for this inference component." 
) - return model_id, model_version + return model_id, model_version, inference_config_name, training_config_name -def _get_model_id_version_from_model_based_endpoint( +def _get_model_info_from_model_based_endpoint( endpoint_name: str, inference_component_name: Optional[str], sagemaker_session: Session, -) -> Tuple[str, str]: - """Returns the model ID and version inferred from a model-based endpoint. +) -> Tuple[str, str, Optional[str], Optional[str]]: + """Returns the model ID, version and config name inferred from a model-based endpoint. Raises: ValueError: If an inference component name is supplied, or if the endpoint does @@ -161,9 +179,12 @@ def _get_model_id_version_from_model_based_endpoint( endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}" - model_id, model_version = get_jumpstart_model_id_version_from_resource_arn( - endpoint_arn, sagemaker_session - ) + ( + model_id, + model_version, + inference_config_name, + training_config_name, + ) = get_jumpstart_model_info_from_resource_arn(endpoint_arn, sagemaker_session) if not model_id: raise ValueError( @@ -172,14 +193,14 @@ def _get_model_id_version_from_model_based_endpoint( "predictor for this endpoint." ) - return model_id, model_version + return model_id, model_version, inference_config_name, training_config_name -def get_model_id_version_from_training_job( +def get_model_info_from_training_job( training_job_name: str, sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[str, str]: - """Returns the model ID and version inferred from a training job. +) -> Tuple[str, str, Optional[str], Optional[str]]: + """Returns the model ID and version and config name inferred from a training job. 
Raises: ValueError: If the training job does not have tags from which the model ID @@ -194,9 +215,12 @@ def get_model_id_version_from_training_job( f"arn:{partition}:sagemaker:{region}:{account_id}:training-job/{training_job_name}" ) - model_id, inferred_model_version = get_jumpstart_model_id_version_from_resource_arn( - training_job_arn, sagemaker_session - ) + ( + model_id, + inferred_model_version, + inference_config_name, + training_config_name, + ) = get_jumpstart_model_info_from_resource_arn(training_job_arn, sagemaker_session) model_version = inferred_model_version or None @@ -207,4 +231,4 @@ def get_model_id_version_from_training_job( "for this training job." ) - return model_id, model_version + return model_id, model_version, inference_config_name, training_config_name diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 171d9ce8a1..fb4c157a67 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -18,7 +18,13 @@ from enum import Enum from typing import Any, Dict, List, Optional, Set, Union from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard -from sagemaker.utils import get_instance_type_family, format_tags, Tags, deep_override_dict +from sagemaker.utils import ( + S3_PREFIX, + get_instance_type_family, + format_tags, + Tags, + deep_override_dict, +) from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines @@ -140,10 +146,14 @@ class HubContentType(str, Enum): class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): """Data class for launched region info.""" - __slots__ = ["content_bucket", "region_name", "gated_content_bucket"] + __slots__ = ["content_bucket", "region_name", "gated_content_bucket", "neo_content_bucket"] def __init__( - self, content_bucket: str, region_name: str, gated_content_bucket: Optional[str] = None + self, + content_bucket: 
str, + region_name: str, + gated_content_bucket: Optional[str] = None, + neo_content_bucket: Optional[str] = None, ): """Instantiates JumpStartLaunchedRegionInfo object. @@ -152,10 +162,13 @@ def __init__( region_name (str): Name of JumpStart launched region. gated_content_bucket (Optional[str[]): Name of JumpStart gated s3 content bucket optionally associated with region. + neo_content_bucket (Optional[str]): Name of Neo service s3 content bucket + optionally associated with region. """ self.content_bucket = content_bucket self.gated_content_bucket = gated_content_bucket self.region_name = region_name + self.neo_content_bucket = neo_content_bucket class JumpStartModelHeader(JumpStartDataHolderType): @@ -854,10 +867,8 @@ def _get_regional_property( if regional_property_alias is None and regional_property_value is None: instance_type_family = get_instance_type_family(instance_type) - if instance_type_family in {"", None}: return None - if self.regional_aliases: regional_property_alias = ( self.variants.get(instance_type_family, {}) @@ -894,10 +905,237 @@ def _get_regional_property( return regional_property_value +class JumpStartAdditionalDataSources(JumpStartDataHolderType): + """Data class of additional data sources.""" + + __slots__ = ["speculative_decoding", "scripts"] + + def __init__(self, spec: Dict[str, Any]): + """Initializes a AdditionalDataSources object. + + Args: + spec (Dict[str, Any]): Dictionary representation of data source. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of data source. 
+ """ + self.speculative_decoding: Optional[List[JumpStartModelDataSource]] = ( + [ + JumpStartModelDataSource(data_source) + for data_source in json_obj["speculative_decoding"] + ] + if json_obj.get("speculative_decoding") + else None + ) + self.scripts: Optional[List[JumpStartModelDataSource]] = ( + [JumpStartModelDataSource(data_source) for data_source in json_obj["scripts"]] + if json_obj.get("scripts") + else None + ) + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of AdditionalDataSources object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + if isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + else: + json_obj[att] = cur_val + return json_obj + + +class ModelAccessConfig(JumpStartDataHolderType): + """Data class of model access config that mirrors CreateModel API.""" + + __slots__ = ["accept_eula"] + + def __init__(self, spec: Dict[str, Any]): + """Initializes a ModelAccessConfig object. + + Args: + spec (Dict[str, Any]): Dictionary representation of data source. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of data source. + """ + self.accept_eula: bool = json_obj["accept_eula"] + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of ModelAccessConfig object.""" + json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + return json_obj + + +class HubAccessConfig(JumpStartDataHolderType): + """Data class of model access config that mirrors CreateModel API.""" + + __slots__ = ["hub_content_arn"] + + def __init__(self, spec: Dict[str, Any]): + """Initializes a HubAccessConfig object. 
+
+        Args:
+            spec (Dict[str, Any]): Dictionary representation of data source.
+        """
+        self.from_json(spec)
+
+    def from_json(self, json_obj: Dict[str, Any]) -> None:
+        """Sets fields in object based on json.
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of data source.
+        """
+        self.hub_content_arn: str = json_obj["hub_content_arn"]
+
+    def to_json(self) -> Dict[str, Any]:
+        """Returns json representation of HubAccessConfig object."""
+        json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
+        return json_obj
+
+
+class S3DataSource(JumpStartDataHolderType):
+    """Data class of S3 data source that mirrors CreateModel API."""
+
+    __slots__ = [
+        "compression_type",
+        "s3_data_type",
+        "s3_uri",
+        "model_access_config",
+        "hub_access_config",
+    ]
+
+    def __init__(self, spec: Dict[str, Any]):
+        """Initializes a S3DataSource object.
+
+        Args:
+            spec (Dict[str, Any]): Dictionary representation of data source.
+        """
+        self.from_json(spec)
+
+    def from_json(self, json_obj: Dict[str, Any]) -> None:
+        """Sets fields in object based on json.
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of data source.
+ """ + self.compression_type: str = json_obj["compression_type"] + self.s3_data_type: str = json_obj["s3_data_type"] + self.s3_uri: str = json_obj["s3_uri"] + self.model_access_config: ModelAccessConfig = ( + ModelAccessConfig(json_obj["model_access_config"]) + if json_obj.get("model_access_config") + else None + ) + self.hub_access_config: HubAccessConfig = ( + HubAccessConfig(json_obj["hub_access_config"]) + if json_obj.get("hub_access_config") + else None + ) + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of S3DataSource object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + elif cur_val: + json_obj[att] = cur_val + return json_obj + + def set_bucket(self, bucket: str) -> None: + """Sets bucket name from S3 URI.""" + + if self.s3_uri.startswith(S3_PREFIX): + s3_path = self.s3_uri[len(S3_PREFIX) :] + old_bucket = s3_path.split("/")[0] + key = s3_path[len(old_bucket) :] + self.s3_uri = f"{S3_PREFIX}{bucket}{key}" # pylint: disable=W0201 + return + + if not bucket.endswith("/"): + bucket += "/" + + self.s3_uri = f"{S3_PREFIX}{bucket}{self.s3_uri}" # pylint: disable=W0201 + + +class AdditionalModelDataSource(JumpStartDataHolderType): + """Data class of additional model data source mirrors CreateModel API.""" + + SERIALIZATION_EXCLUSION_SET: Set[str] = set() + + __slots__ = ["channel_name", "s3_data_source"] + + def __init__(self, spec: Dict[str, Any]): + """Initializes a AdditionalModelDataSource object. + + Args: + spec (Dict[str, Any]): Dictionary representation of data source. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of data source. 
+ """ + self.channel_name: str = json_obj["channel_name"] + self.s3_data_source: S3DataSource = S3DataSource(json_obj["s3_data_source"]) + + def to_json(self, exclude_keys=True) -> Dict[str, Any]: + """Returns json representation of AdditionalModelDataSource object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + if exclude_keys and att not in self.SERIALIZATION_EXCLUSION_SET or not exclude_keys: + cur_val = getattr(self, att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + else: + json_obj[att] = cur_val + return json_obj + + +class JumpStartModelDataSource(AdditionalModelDataSource): + """Data class JumpStart additional model data source.""" + + SERIALIZATION_EXCLUSION_SET = {"artifact_version"} + + __slots__ = list(SERIALIZATION_EXCLUSION_SET) + AdditionalModelDataSource.__slots__ + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of data source. + """ + super().from_json(json_obj) + self.artifact_version: str = json_obj["artifact_version"] + + class JumpStartBenchmarkStat(JumpStartDataHolderType): """Data class JumpStart benchmark stat.""" - __slots__ = ["name", "value", "unit"] + __slots__ = ["name", "value", "unit", "concurrency"] def __init__(self, spec: Dict[str, Any]): """Initializes a JumpStartBenchmarkStat object. 
@@ -916,6 +1154,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.name: str = json_obj["name"] self.value: str = json_obj["value"] self.unit: Union[int, str] = json_obj["unit"] + self.concurrency: Union[int, str] = json_obj["concurrency"] def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartBenchmarkStat object.""" @@ -928,12 +1167,14 @@ class JumpStartConfigRanking(JumpStartDataHolderType): __slots__ = ["description", "rankings"] - def __init__(self, spec: Optional[Dict[str, Any]]): + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False): """Initializes a JumpStartConfigRanking object. Args: spec (Dict[str, Any]): Dictionary representation of training config ranking. """ + if is_hub_content: + spec = {camel_to_snake(key): val for key, val in spec.items()} self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -1012,6 +1253,9 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "default_payloads", "gated_bucket", "model_subscription_link", + "hosting_additional_data_sources", + "hosting_neuron_model_id", + "hosting_neuron_model_version", "hub_content_type", "_is_hub_content", ] @@ -1041,7 +1285,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj.get("incremental_training_supported", False) ) if self._is_hub_content: - self.hosting_ecr_uri: Optional[str] = json_obj["hosting_ecr_uri"] + self.hosting_ecr_uri: Optional[str] = json_obj.get("hosting_ecr_uri") self._non_serializable_slots.append("hosting_ecr_specs") else: self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( @@ -1129,7 +1373,10 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_eula_key: Optional[str] = json_obj.get("hosting_eula_key") - self.hosting_model_package_arns: Optional[Dict] = json_obj.get("hosting_model_package_arns") + model_package_arns = json_obj.get("hosting_model_package_arns") + self.hosting_model_package_arns: Optional[Dict] = ( + model_package_arns if 
model_package_arns is not None else {} + ) self.hosting_use_script_uri: bool = json_obj.get("hosting_use_script_uri", True) self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( @@ -1139,6 +1386,15 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if json_obj.get("hosting_instance_type_variants") else None ) + self.hosting_additional_data_sources: Optional[JumpStartAdditionalDataSources] = ( + JumpStartAdditionalDataSources(json_obj["hosting_additional_data_sources"]) + if json_obj.get("hosting_additional_data_sources") + else None + ) + self.hosting_neuron_model_id: Optional[str] = json_obj.get("hosting_neuron_model_id") + self.hosting_neuron_model_version: Optional[str] = json_obj.get( + "hosting_neuron_model_version" + ) if self.training_supported: if self._is_hub_content: @@ -1235,9 +1491,7 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields): __slots__ = slots + JumpStartMetadataBaseFields.__slots__ def __init__( - self, - component_name: str, - component: Optional[Dict[str, Any]], + self, component_name: str, component: Optional[Dict[str, Any]], is_hub_content=False ): """Initializes a JumpStartConfigComponent object from its json representation. @@ -1248,8 +1502,10 @@ def __init__( Raises: ValueError: If the component field is invalid. 
""" - super().__init__(component) + if is_hub_content: + component = walk_and_apply_json(component, camel_to_snake) self.component_name = component_name + super().__init__(component, is_hub_content) self.from_json(component) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -1270,30 +1526,61 @@ class JumpStartMetadataConfig(JumpStartDataHolderType): __slots__ = [ "base_fields", "benchmark_metrics", + "acceleration_configs", "config_components", "resolved_metadata_config", + "config_name", + "default_inference_config", + "default_incremental_training_config", + "supported_inference_configs", + "supported_incremental_training_configs", ] def __init__( self, + config_name: str, + config: Dict[str, Any], base_fields: Dict[str, Any], config_components: Dict[str, JumpStartConfigComponent], - benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]], + is_hub_content=False, ): """Initializes a JumpStartMetadataConfig object from its json representation. Args: + config_name (str): Name of the config, + config (Dict[str, Any]): + Dictionary representation of the config. base_fields (Dict[str, Any]): - The default base fields that are used to construct the final resolved config. + The default base fields that are used to construct the resolved config. config_components (Dict[str, JumpStartConfigComponent]): The list of components that are used to construct the resolved config. - benchmark_metrics (Dict[str, List[JumpStartBenchmarkStat]]): - The dictionary of benchmark metrics with name being the key. 
""" + if is_hub_content: + config = walk_and_apply_json(config, camel_to_snake) + base_fields = walk_and_apply_json(base_fields, camel_to_snake) self.base_fields = base_fields self.config_components: Dict[str, JumpStartConfigComponent] = config_components - self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = benchmark_metrics + self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = ( + { + stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] + for stat_name, stats in config.get("benchmark_metrics").items() + } + if config and config.get("benchmark_metrics") + else None + ) + self.acceleration_configs = config.get("acceleration_configs") self.resolved_metadata_config: Optional[Dict[str, Any]] = None + self.config_name: Optional[str] = config_name + self.default_inference_config: Optional[str] = config.get("default_inference_config") + self.default_incremental_training_config: Optional[str] = config.get( + "default_incremental_training_config" + ) + self.supported_inference_configs: Optional[List[str]] = config.get( + "supported_inference_configs" + ) + self.supported_incremental_training_configs: Optional[List[str]] = config.get( + "supported_incremental_training_configs" + ) def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartMetadataConfig object.""" @@ -1317,6 +1604,12 @@ def resolved_config(self) -> Dict[str, Any]: deepcopy(component.to_json()), component.OVERRIDING_DENY_LIST, ) + + # Remove environment variables from resolved config if using model packages + hosting_model_pacakge_arns = resolved_config.get("hosting_model_package_arns") + if hosting_model_pacakge_arns is not None and hosting_model_pacakge_arns != {}: + resolved_config["inference_environment_variables"] = [] + self.resolved_metadata_config = resolved_config return resolved_config @@ -1359,6 +1652,8 @@ def get_top_config_from_ranking( ) -> Optional[JumpStartMetadataConfig]: """Gets the best the config based on config ranking. 
+ Fallback to use the ordering in config names if + ranking is not available. Args: ranking_name (str): The ranking name that config priority is based on. @@ -1366,13 +1661,8 @@ def get_top_config_from_ranking( The instance type which the config selection is based on. Raises: - ValueError: If the config exists but missing config ranking. NotImplementedError: If the scope is unrecognized. """ - if self.configs and ( - not self.config_rankings or not self.config_rankings.get(ranking_name) - ): - raise ValueError(f"Config exists but missing config ranking {ranking_name}.") if self.scope == JumpStartScriptScope.INFERENCE: instance_type_attribute = "supported_inference_instance_types" @@ -1381,8 +1671,14 @@ def get_top_config_from_ranking( else: raise NotImplementedError(f"Unknown script scope {self.scope}") - rankings = self.config_rankings.get(ranking_name) - for config_name in rankings.rankings: + if self.configs and ( + not self.config_rankings or not self.config_rankings.get(ranking_name) + ): + ranked_config_names = sorted(list(self.configs.keys())) + else: + rankings = self.config_rankings.get(ranking_name) + ranked_config_names = rankings.rankings + for config_name in ranked_config_names: resolved_config = self.configs[config_name].resolved_config if instance_type and instance_type not in getattr( resolved_config, instance_type_attribute @@ -1444,6 +1740,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( { alias: JumpStartMetadataConfig( + alias, + config, json_obj, ( { @@ -1453,14 +1751,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if config and config.get("component_names") else None ), - ( - { - stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] - for stat_name, stats in config.get("benchmark_metrics").items() - } - if config and config.get("benchmark_metrics") - else None - ), ) for alias, config in json_obj["inference_configs"].items() } @@ -1496,6 
+1786,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( { alias: JumpStartMetadataConfig( + alias, + config, json_obj, ( { @@ -1505,14 +1797,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if config and config.get("component_names") else None ), - ( - { - stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] - for stat_name, stats in config.get("benchmark_metrics").items() - } - if config and config.get("benchmark_metrics") - else None - ), ) for alias, config in json_obj["training_configs"].items() } @@ -1588,6 +1872,21 @@ def supports_incremental_training(self) -> bool: """Returns True if the model supports incremental training.""" return self.incremental_training_supported + def get_speculative_decoding_s3_data_sources(self) -> List[JumpStartModelDataSource]: + """Returns data sources for speculative decoding.""" + if not self.hosting_additional_data_sources: + return [] + return self.hosting_additional_data_sources.speculative_decoding or [] + + def get_additional_s3_data_sources(self) -> List[JumpStartAdditionalDataSources]: + """Returns a list of the additional S3 data sources for use by the model.""" + additional_data_sources = [] + if self.hosting_additional_data_sources: + for data_source in self.hosting_additional_data_sources.to_json(): + data_sources = getattr(self.hosting_additional_data_sources, data_source) or [] + additional_data_sources.extend(data_sources) + return additional_data_sources + class JumpStartVersionedModelId(JumpStartDataHolderType): """Data class for versioned model IDs.""" @@ -1675,7 +1974,6 @@ def extract_region_from_arn(arn: str) -> Optional[str]: hub_region = None if match: hub_region = match.group(2) - return hub_region match = re.match(HUB_ARN_REGEX, arn) @@ -1717,11 +2015,11 @@ class JumpStartKwargs(JumpStartDataHolderType): SERIALIZATION_EXCLUSION_SET: Set[str] = set() - def to_kwargs_dict(self): + def to_kwargs_dict(self, 
exclude_keys: bool = True): """Serializes object to dictionary to be used for kwargs for method arguments.""" kwargs_dict = {} for field in self.__slots__: - if field not in self.SERIALIZATION_EXCLUSION_SET: + if exclude_keys and field not in self.SERIALIZATION_EXCLUSION_SET or not exclude_keys: att_value = getattr(self, field) if att_value is not None: kwargs_dict[field] = getattr(self, field) @@ -1760,6 +2058,8 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "model_package_arn", "training_instance_type", "resources", + "config_name", + "additional_model_data_sources", "hub_content_type", "model_reference_arn", ] @@ -1775,6 +2075,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "region", "model_package_arn", "training_instance_type", + "config_name", "hub_content_type", } @@ -1808,6 +2109,8 @@ def __init__( model_package_arn: Optional[str] = None, training_instance_type: Optional[str] = None, resources: Optional[ResourceRequirements] = None, + config_name: Optional[str] = None, + additional_model_data_sources: Optional[Dict[str, Any]] = None, ) -> None: """Instantiates JumpStartModelInitKwargs object.""" @@ -1839,6 +2142,8 @@ def __init__( self.model_package_arn = model_package_arn self.training_instance_type = training_instance_type self.resources = resources + self.config_name = config_name + self.additional_model_data_sources = additional_model_data_sources class JumpStartModelDeployKwargs(JumpStartKwargs): @@ -1877,6 +2182,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "endpoint_logging", "resources", "endpoint_type", + "config_name", "routing_config", ] @@ -1890,6 +2196,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "sagemaker_session", "training_instance_type", + "config_name", } def __init__( @@ -1926,6 +2233,7 @@ def __init__( endpoint_logging: Optional[bool] = None, resources: Optional[ResourceRequirements] = None, endpoint_type: Optional[EndpointType] = None, + config_name: Optional[str] = None, 
routing_config: Optional[Dict[str, Any]] = None, ) -> None: """Instantiates JumpStartModelDeployKwargs object.""" @@ -1962,6 +2270,7 @@ def __init__( self.endpoint_logging = endpoint_logging self.resources = resources self.endpoint_type = endpoint_type + self.config_name = config_name self.routing_config = routing_config @@ -2024,6 +2333,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "disable_output_compression", "enable_infra_check", "enable_remote_debug", + "config_name", "enable_session_tag_chaining", ] @@ -2035,6 +2345,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "config_name", } def __init__( @@ -2094,6 +2405,7 @@ def __init__( disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + config_name: Optional[str] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -2155,6 +2467,7 @@ def __init__( self.disable_output_compression = disable_output_compression self.enable_infra_check = enable_infra_check self.enable_remote_debug = enable_remote_debug + self.config_name = config_name self.enable_session_tag_chaining = enable_session_tag_chaining @@ -2175,6 +2488,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "tolerate_deprecated_model", "tolerate_vulnerable_model", "sagemaker_session", + "config_name", ] SERIALIZATION_EXCLUSION_SET = { @@ -2186,6 +2500,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "tolerate_deprecated_model", "tolerate_vulnerable_model", "sagemaker_session", + "config_name", } def __init__( @@ -2203,6 +2518,7 @@ def __init__( tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, + config_name: Optional[str] = None, ) -> None: """Instantiates 
JumpStartEstimatorInitKwargs object.""" @@ -2219,6 +2535,7 @@ def __init__( self.tolerate_deprecated_model = tolerate_deprecated_model self.tolerate_vulnerable_model = tolerate_vulnerable_model self.sagemaker_session = sagemaker_session + self.config_name = config_name class JumpStartEstimatorDeployKwargs(JumpStartKwargs): @@ -2265,6 +2582,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "model_name", "use_compiled_model", + "config_name", ] SERIALIZATION_EXCLUSION_SET = { @@ -2275,6 +2593,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): "model_version", "hub_arn", "sagemaker_session", + "config_name", } def __init__( @@ -2319,6 +2638,7 @@ def __init__( tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, use_compiled_model: bool = False, + config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -2362,6 +2682,7 @@ def __init__( self.tolerate_deprecated_model = tolerate_deprecated_model self.tolerate_vulnerable_model = tolerate_vulnerable_model self.use_compiled_model = use_compiled_model + self.config_name = config_name class JumpStartModelRegisterKwargs(JumpStartKwargs): @@ -2398,6 +2719,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "data_input_configuration", "skip_model_validation", "source_uri", + "config_name", "model_card", "accept_eula", ] @@ -2410,6 +2732,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "model_version", "hub_arn", "sagemaker_session", + "config_name", } def __init__( @@ -2444,6 +2767,7 @@ def __init__( data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, source_uri: Optional[str] = None, + config_name: Optional[str] = None, model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None, accept_eula: Optional[bool] = None, ) -> None: @@ -2480,5 +2804,128 @@ def __init__( self.data_input_configuration = data_input_configuration 
self.skip_model_validation = skip_model_validation self.source_uri = source_uri + self.config_name = config_name self.model_card = model_card self.accept_eula = accept_eula + + +class BaseDeploymentConfigDataHolder(JumpStartDataHolderType): + """Base class for Deployment Config Data.""" + + def _convert_to_pascal_case(self, attr_name: str) -> str: + """Converts a snake_case attribute name into a PascalCased string. + + Args: + attr_name (str): The snake_case attribute name. + Returns: + str: The PascalCased attribute name. + """ + return attr_name.replace("_", " ").title().replace(" ", "") + + def to_json(self) -> Dict[str, Any]: + """Represents ``this`` object as JSON.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + att = self._convert_to_pascal_case(att) + json_obj[att] = self._val_to_json(cur_val) + return json_obj + + def _val_to_json(self, val: Any) -> Any: + """Converts the given value to JSON. + + Args: + val (Any): The value to convert. + Returns: + Any: The converted json value.
+ """ + if issubclass(type(val), JumpStartDataHolderType): + if isinstance(val, JumpStartBenchmarkStat): + val.name = val.name.replace("_", " ").title() + return val.to_json() + if isinstance(val, list): + list_obj = [] + for obj in val: + list_obj.append(self._val_to_json(obj)) + return list_obj + if isinstance(val, dict): + dict_obj = {} + for k, v in val.items(): + if isinstance(v, JumpStartDataHolderType): + dict_obj[self._convert_to_pascal_case(k)] = self._val_to_json(v) + else: + dict_obj[k] = self._val_to_json(v) + return dict_obj + return val + + +class DeploymentArgs(BaseDeploymentConfigDataHolder): + """Dataclass representing a Deployment Args.""" + + __slots__ = [ + "image_uri", + "model_data", + "model_package_arn", + "environment", + "instance_type", + "compute_resource_requirements", + "model_data_download_timeout", + "container_startup_health_check_timeout", + "additional_data_sources", + ] + + def __init__( + self, + init_kwargs: Optional[JumpStartModelInitKwargs] = None, + deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None, + resolved_config: Optional[Dict[str, Any]] = None, + ): + """Instantiates DeploymentArgs object.""" + if init_kwargs is not None: + self.image_uri = init_kwargs.image_uri + self.model_data = init_kwargs.model_data + self.model_package_arn = init_kwargs.model_package_arn + self.instance_type = init_kwargs.instance_type + self.environment = init_kwargs.env + if init_kwargs.resources is not None: + self.compute_resource_requirements = ( + init_kwargs.resources.get_compute_resource_requirements() + ) + if deploy_kwargs is not None: + self.model_data_download_timeout = deploy_kwargs.model_data_download_timeout + self.container_startup_health_check_timeout = ( + deploy_kwargs.container_startup_health_check_timeout + ) + if resolved_config is not None: + self.default_instance_type = resolved_config.get("default_inference_instance_type") + self.supported_instance_types = resolved_config.get( + 
"supported_inference_instance_types" + ) + self.additional_data_sources = resolved_config.get("hosting_additional_data_sources") + + +class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder): + """Dataclass representing a Deployment Config Metadata""" + + __slots__ = [ + "deployment_config_name", + "deployment_args", + "acceleration_configs", + "benchmark_metrics", + ] + + def __init__( + self, + config_name: Optional[str] = None, + metadata_config: Optional[JumpStartMetadataConfig] = None, + init_kwargs: Optional[JumpStartModelInitKwargs] = None, + deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None, + ): + """Instantiates DeploymentConfigMetadata object.""" + self.deployment_config_name = config_name + self.deployment_args = DeploymentArgs( + init_kwargs, deploy_kwargs, metadata_config.resolved_config + ) + self.benchmark_metrics = metadata_config.benchmark_metrics + self.acceleration_configs = metadata_config.acceleration_configs diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 989ca426b5..83425d62b3 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -15,9 +15,11 @@ from copy import copy import logging import os +from functools import lru_cache, wraps from typing import Any, Dict, List, Set, Optional, Tuple, Union from urllib.parse import urlparse import boto3 +from botocore.exceptions import ClientError from packaging.version import Version import botocore import sagemaker @@ -43,10 +45,11 @@ JumpStartModelHeader, JumpStartModelSpecs, JumpStartVersionedModelId, + DeploymentConfigMetadata, ) from sagemaker.session import Session from sagemaker.config import load_sagemaker_config -from sagemaker.utils import resolve_value_from_config, TagsDict +from sagemaker.utils import resolve_value_from_config, TagsDict, get_instance_rate_per_hour from sagemaker.workflow import is_pipeline_variable from sagemaker.user_agent import get_user_agent_extra_suffix @@ -153,7 +156,7 @@ def 
get_jumpstart_content_bucket( except KeyError: formatted_launched_regions_str = get_jumpstart_launched_regions_message() raise ValueError( - f"Unable to get content bucket for JumpStart in {region} region. " + f"Unable to get content bucket for JumpStart in {region} region. " f"{formatted_launched_regions_str}" ) @@ -167,6 +170,34 @@ def get_jumpstart_content_bucket( return bucket_to_return +def get_neo_content_bucket( + region: str = constants.NEO_DEFAULT_REGION_NAME, +) -> str: + """Returns the regionalized S3 bucket name for Neo service. + + Raises: + ValueError: If Neo is not launched in ``region``. + """ + + bucket_to_return: Optional[str] = None + if ( + constants.ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE in os.environ + and len(os.environ[constants.ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE]) > 0 + ): + bucket_to_return = os.environ[constants.ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE] + info_log = f"Using Neo bucket override: '{bucket_to_return}'" + constants.JUMPSTART_LOGGER.info(info_log) + else: + try: + bucket_to_return = constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[ + region + ].neo_content_bucket + except KeyError: + raise ValueError(f"Unable to get content bucket for Neo in {region} region.") + + return bucket_to_return + + def get_formatted_manifest( manifest: List[Dict], ) -> Dict[JumpStartVersionedModelId, JumpStartModelHeader]: @@ -321,6 +352,8 @@ def add_single_jumpstart_tag( tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags) + or tag_key_in_array(enums.JumpStartTag.INFERENCE_CONFIG_NAME, curr_tags) + or tag_key_in_array(enums.JumpStartTag.TRAINING_CONFIG_NAME, curr_tags) ) if is_uri else False @@ -351,11 +384,13 @@ def get_jumpstart_base_name_if_jumpstart_model( return None -def add_jumpstart_model_id_version_tags( +def add_jumpstart_model_info_tags( tags: Optional[List[TagsDict]], model_id: str, model_version:
str, model_type: Optional[enums.JumpStartModelType] = None, + config_name: Optional[str] = None, + scope: Optional[enums.JumpStartScriptScope] = None, ) -> List[TagsDict]: """Add custom model ID and version tags to JumpStart related resources.""" if model_id is None or model_version is None: @@ -379,6 +414,20 @@ def add_jumpstart_model_id_version_tags( tags, is_uri=False, ) + if config_name and scope == enums.JumpStartScriptScope.INFERENCE: + tags = add_single_jumpstart_tag( + config_name, + enums.JumpStartTag.INFERENCE_CONFIG_NAME, + tags, + is_uri=False, + ) + if config_name and scope == enums.JumpStartScriptScope.TRAINING: + tags = add_single_jumpstart_tag( + config_name, + enums.JumpStartTag.TRAINING_CONFIG_NAME, + tags, + is_uri=False, + ) return tags @@ -566,6 +615,7 @@ def verify_model_region_and_return_specs( tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> JumpStartModelSpecs: """Verifies that an acceptable model_id, version, scope, and region combination is provided. @@ -590,6 +640,7 @@ def verify_model_region_and_return_specs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Raises: NotImplementedError: If the scope is not supported.
@@ -657,6 +708,9 @@ def verify_model_region_and_return_specs( scope=constants.JumpStartScriptScope.TRAINING, ) + if model_specs and config_name: + model_specs.set_config(config_name, scope) + return model_specs @@ -821,52 +875,80 @@ def validate_model_id_and_get_type( return None -def get_jumpstart_model_id_version_from_resource_arn( +def _extract_value_from_list_of_tags( + tag_keys: List[str], + list_tags_result: List[str], + resource_name: str, + resource_arn: str, +): + """Extracts value from list of tags with check of duplicate tags. + + Returns None if no value is found. + """ + resolved_value = None + for tag_key in tag_keys: + try: + value_from_tag = get_tag_value(tag_key, list_tags_result) + except KeyError: + continue + if value_from_tag is not None: + if resolved_value is not None and value_from_tag != resolved_value: + constants.JUMPSTART_LOGGER.warning( + "Found multiple %s tags on the following resource: %s", + resource_name, + resource_arn, + ) + resolved_value = None + break + resolved_value = value_from_tag + return resolved_value + + +def get_jumpstart_model_info_from_resource_arn( resource_arn: str, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[Optional[str], Optional[str]]: - """Returns the JumpStart model ID and version if in resource tags. +) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: + """Returns the JumpStart model ID, version and config name if in resource tags. - Returns 'None' if model ID or version cannot be inferred from tags. + Returns 'None' if model ID or version or config name cannot be inferred from tags. 
""" list_tags_result = sagemaker_session.list_tags(resource_arn) - model_id: Optional[str] = None - model_version: Optional[str] = None - model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS] model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS] + inference_config_name_keys = [enums.JumpStartTag.INFERENCE_CONFIG_NAME] + training_config_name_keys = [enums.JumpStartTag.TRAINING_CONFIG_NAME] + + model_id: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=model_id_keys, + list_tags_result=list_tags_result, + resource_name="model ID", + resource_arn=resource_arn, + ) - for model_id_key in model_id_keys: - try: - model_id_from_tag = get_tag_value(model_id_key, list_tags_result) - except KeyError: - continue - if model_id_from_tag is not None: - if model_id is not None and model_id_from_tag != model_id: - constants.JUMPSTART_LOGGER.warning( - "Found multiple model ID tags on the following resource: %s", resource_arn - ) - model_id = None - break - model_id = model_id_from_tag + model_version: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=model_version_keys, + list_tags_result=list_tags_result, + resource_name="model version", + resource_arn=resource_arn, + ) - for model_version_key in model_version_keys: - try: - model_version_from_tag = get_tag_value(model_version_key, list_tags_result) - except KeyError: - continue - if model_version_from_tag is not None: - if model_version is not None and model_version_from_tag != model_version: - constants.JUMPSTART_LOGGER.warning( - "Found multiple model version tags on the following resource: %s", resource_arn - ) - model_version = None - break - model_version = model_version_from_tag + inference_config_name: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=inference_config_name_keys, + list_tags_result=list_tags_result, + resource_name="inference config name", + resource_arn=resource_arn, + ) - return model_id, model_version + 
training_config_name: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=training_config_name_keys, + list_tags_result=list_tags_result, + resource_name="training config name", + resource_arn=resource_arn, + ) + + return model_id, model_version, inference_config_name, training_config_name def get_region_fallback( @@ -916,7 +998,11 @@ def get_config_names( scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, ) -> List[str]: - """Returns a list of config names for the given model ID and region.""" + """Returns a list of config names for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. + """ model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, @@ -931,7 +1017,7 @@ def get_config_names( elif scope == enums.JumpStartScriptScope.TRAINING: metadata_configs = model_specs.training_configs else: - raise ValueError(f"Unknown script scope {scope}.") + raise ValueError(f"Unknown script scope: {scope}.") return list(metadata_configs.configs.keys()) if metadata_configs else [] @@ -946,7 +1032,11 @@ def get_benchmark_stats( scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, ) -> Dict[str, List[JumpStartBenchmarkStat]]: - """Returns benchmark stats for the given model ID and region.""" + """Returns benchmark stats for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. 
+ """ model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, @@ -962,7 +1052,7 @@ def get_benchmark_stats( elif scope == enums.JumpStartScriptScope.TRAINING: metadata_configs = model_specs.training_configs else: - raise ValueError(f"Unknown script scope {scope}.") + raise ValueError(f"Unknown script scope: {scope}.") if not config_names: config_names = metadata_configs.configs.keys() if metadata_configs else [] @@ -970,7 +1060,7 @@ def get_benchmark_stats( benchmark_stats = {} for config_name in config_names: if config_name not in metadata_configs.configs: - raise ValueError(f"Unknown config name: '{config_name}'") + raise ValueError(f"Unknown config name: {config_name}") benchmark_stats[config_name] = metadata_configs.configs.get(config_name).benchmark_metrics return benchmark_stats @@ -984,8 +1074,12 @@ def get_jumpstart_configs( sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, -) -> Dict[str, List[JumpStartMetadataConfig]]: - """Returns metadata configs for the given model ID and region.""" +) -> Dict[str, JumpStartMetadataConfig]: + """Returns metadata configs for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. 
+ """ model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, @@ -1000,10 +1094,12 @@ def get_jumpstart_configs( elif scope == enums.JumpStartScriptScope.TRAINING: metadata_configs = model_specs.training_configs else: - raise ValueError(f"Unknown script scope {scope}.") + raise ValueError(f"Unknown script scope: {scope}.") if not config_names: - config_names = metadata_configs.configs.keys() if metadata_configs else [] + config_names = ( + metadata_configs.config_rankings.get("overall").rankings if metadata_configs else [] + ) return ( {config_name: metadata_configs.configs[config_name] for config_name in config_names} @@ -1046,3 +1142,252 @@ def get_default_jumpstart_session_with_user_agent_suffix( config=botocore_config, ) return session + + +def add_instance_rate_stats_to_benchmark_metrics( + region: str, + benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]], +) -> Optional[Tuple[Dict[str, str], Dict[str, List[JumpStartBenchmarkStat]]]]: + """Adds instance types metric stats to the given benchmark_metrics dict. + + Args: + region (str): AWS region. + benchmark_metrics (Optional[Dict[str, List[JumpStartBenchmarkStat]]]): + Benchmark metrics keyed by instance type. + Returns: + Optional[Tuple[Dict[str, str], Dict[str, List[JumpStartBenchmarkStat]]]]: + Contains Error and metrics.
+ """ + if not benchmark_metrics: + return None + + err_message = None + final_benchmark_metrics = {} + for instance_type, benchmark_metric_stats in benchmark_metrics.items(): + instance_type = instance_type if instance_type.startswith("ml.") else f"ml.{instance_type}" + + if not has_instance_rate_stat(benchmark_metric_stats) and not err_message: + try: + instance_type_rate = get_instance_rate_per_hour( + instance_type=instance_type, region=region + ) + + if not benchmark_metric_stats: + benchmark_metric_stats = [] + benchmark_metric_stats.append( + JumpStartBenchmarkStat({"concurrency": None, **instance_type_rate}) + ) + + final_benchmark_metrics[instance_type] = benchmark_metric_stats + except ClientError as e: + final_benchmark_metrics[instance_type] = benchmark_metric_stats + err_message = e.response["Error"] + except Exception: # pylint: disable=W0703 + final_benchmark_metrics[instance_type] = benchmark_metric_stats + else: + final_benchmark_metrics[instance_type] = benchmark_metric_stats + + return err_message, final_benchmark_metrics + + +def has_instance_rate_stat(benchmark_metric_stats: Optional[List[JumpStartBenchmarkStat]]) -> bool: + """Determines whether a benchmark metric stats contains instance rate metric stat. + + Args: + benchmark_metric_stats (Optional[List[JumpStartBenchmarkStat]]): + List of benchmark metric stats. + Returns: + bool: Whether the benchmark metric stats contains instance rate metric stat. + """ + if benchmark_metric_stats is None: + return True + for benchmark_metric_stat in benchmark_metric_stats: + if benchmark_metric_stat.name.lower() == "instance rate": + return True + return False + + +def get_metrics_from_deployment_configs( + deployment_configs: Optional[List[DeploymentConfigMetadata]], +) -> Dict[str, List[str]]: + """Extracts benchmark metrics from deployment configs metadata. + + Args: + deployment_configs (Optional[List[DeploymentConfigMetadata]]): + List of deployment configs metadata. 
+ Returns: + Dict[str, List[str]]: Deployment configs bench metrics dict. + """ + if not deployment_configs: + return {} + + data = {"Instance Type": [], "Config Name": [], "Concurrent Users": []} + instance_rate_data = {} + for index, deployment_config in enumerate(deployment_configs): + benchmark_metrics = deployment_config.benchmark_metrics + if not deployment_config.deployment_args or not benchmark_metrics: + continue + + for current_instance_type, current_instance_type_metrics in benchmark_metrics.items(): + instance_type_rate, concurrent_users = _normalize_benchmark_metrics( + current_instance_type_metrics + ) + + for concurrent_user, metrics in concurrent_users.items(): + instance_type_to_display = ( + f"{current_instance_type} (Default)" + if index == 0 + and concurrent_user + and int(concurrent_user) == 1 + and current_instance_type + == deployment_config.deployment_args.default_instance_type + else current_instance_type + ) + + data["Config Name"].append(deployment_config.deployment_config_name) + data["Instance Type"].append(instance_type_to_display) + data["Concurrent Users"].append(concurrent_user) + + if instance_type_rate: + instance_rate_column_name = ( + f"{instance_type_rate.name} ({instance_type_rate.unit})" + ) + instance_rate_data[instance_rate_column_name] = instance_rate_data.get( + instance_rate_column_name, [] + ) + instance_rate_data[instance_rate_column_name].append(instance_type_rate.value) + + for metric in metrics: + column_name = _normalize_benchmark_metric_column_name(metric.name, metric.unit) + data[column_name] = data.get(column_name, []) + data[column_name].append(metric.value) + + data = {**data, **instance_rate_data} + return data + + +def _normalize_benchmark_metric_column_name(name: str, unit: str) -> str: + """Normalizes benchmark metric column name. + + Args: + name (str): Name of the metric. + unit (str): Unit of the metric. + Returns: + str: Normalized metric column name. 
+ """ + if "latency" in name.lower(): + name = f"Latency, TTFT (P50 in {unit.lower()})" + elif "throughput" in name.lower(): + name = f"Throughput (P50 in {unit.lower()}/user)" + return name + + +def _normalize_benchmark_metrics( + benchmark_metric_stats: List[JumpStartBenchmarkStat], +) -> Tuple[JumpStartBenchmarkStat, Dict[str, List[JumpStartBenchmarkStat]]]: + """Normalizes benchmark metrics dict. + + Args: + benchmark_metric_stats (List[JumpStartBenchmarkStat]): + List of benchmark metrics stats. + Returns: + Tuple[JumpStartBenchmarkStat, Dict[str, List[JumpStartBenchmarkStat]]]: + Normalized benchmark metrics dict. + """ + instance_type_rate = None + concurrent_users = {} + for current_instance_type_metric in benchmark_metric_stats: + if "instance rate" in current_instance_type_metric.name.lower(): + instance_type_rate = current_instance_type_metric + elif current_instance_type_metric.concurrency not in concurrent_users: + concurrent_users[current_instance_type_metric.concurrency] = [ + current_instance_type_metric + ] + else: + concurrent_users[current_instance_type_metric.concurrency].append( + current_instance_type_metric + ) + + return instance_type_rate, concurrent_users + + +def deployment_config_response_data( + deployment_configs: Optional[List[DeploymentConfigMetadata]], +) -> List[Dict[str, Any]]: + """Deployment config api response data. + + Args: + deployment_configs (Optional[List[DeploymentConfigMetadata]]): + List of deployment configs metadata. + Returns: + List[Dict[str, Any]]: List of deployment config api response data. 
+ """ + configs = [] + if not deployment_configs: + return configs + + for deployment_config in deployment_configs: + deployment_config_json = deployment_config.to_json() + benchmark_metrics = deployment_config_json.get("BenchmarkMetrics") + if benchmark_metrics and deployment_config.deployment_args: + deployment_config_json["BenchmarkMetrics"] = { + deployment_config.deployment_args.instance_type: benchmark_metrics.get( + deployment_config.deployment_args.instance_type + ) + } + + configs.append(deployment_config_json) + return configs + + +def _deployment_config_lru_cache(_func=None, *, maxsize: int = 128, typed: bool = False): + """LRU cache for deployment configs.""" + + def has_instance_rate_metric(config: DeploymentConfigMetadata) -> bool: + """Determines whether metadata config contains instance rate metric stat. + + Args: + config (DeploymentConfigMetadata): Metadata config metadata. + Returns: + bool: Whether the metadata config contains instance rate metric stat. + """ + if config.benchmark_metrics is None: + return True + for benchmark_metric_stats in config.benchmark_metrics.values(): + if not has_instance_rate_stat(benchmark_metric_stats): + return False + return True + + def wrapper_cache(f): + f = lru_cache(maxsize=maxsize, typed=typed)(f) + + @wraps(f) + def wrapped_f(*args, **kwargs): + res = f(*args, **kwargs) + + # Clear cache on first call if + # - The output does not contain Instance rate metrics + # as this is caused by missing policy.
+ if f.cache_info().hits == 0 and f.cache_info().misses == 1: + if isinstance(res, list): + for item in res: + if isinstance( + item, DeploymentConfigMetadata + ) and not has_instance_rate_metric(item): + f.cache_clear() + break + elif isinstance(res, dict): + keys = list(res.keys()) + if len(keys) == 0 or "Instance Rate" not in keys[-1]: + f.cache_clear() + elif len(res[keys[1]]) > len(res[keys[-1]]): + del res[keys[-1]] + f.cache_clear() + return res + + wrapped_f.cache_info = f.cache_info + wrapped_f.cache_clear = f.cache_clear + return wrapped_f + + if _func is None: + return wrapper_cache + return wrapper_cache(_func) diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index 1a4849522c..ea8041d1ee 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -172,6 +172,7 @@ def validate_hyperparameters( sagemaker_session: Optional[session.Session] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + config_name: Optional[str] = None, ) -> None: """Validate hyperparameters for JumpStart models. @@ -194,6 +195,7 @@ def validate_hyperparameters( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). 
Raises: JumpStartHyperparametersError: If the hyperparameters are not formatted correctly, @@ -220,6 +222,7 @@ def validate_hyperparameters( sagemaker_session=sagemaker_session, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + config_name=config_name, ) hyperparameters_specs = model_specs.hyperparameters diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py index a31d5d930d..dbf7ef7650 100644 --- a/src/sagemaker/metric_definitions.py +++ b/src/sagemaker/metric_definitions.py @@ -34,6 +34,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model matching the given arguments. @@ -59,6 +60,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: The default metric definitions to use for the model or None. 
@@ -80,4 +82,5 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 4292c3efdf..5d7ee5b378 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -73,6 +73,8 @@ format_tags, Tags, _resolve_routing_config, + _validate_new_tags, + remove_tag_with_key, ) from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.predictor_async import AsyncPredictor @@ -166,6 +168,7 @@ def __init__( dependencies: Optional[List[str]] = None, git_config: Optional[Dict[str, str]] = None, resources: Optional[ResourceRequirements] = None, + additional_model_data_sources: Optional[Dict[str, Any]] = None, model_reference_arn: Optional[str] = None, ): """Initialize an SageMaker ``Model``. @@ -330,11 +333,14 @@ def __init__( for a model to be deployed to an endpoint. Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature. (Default: None). + additional_model_data_sources (Optional[Dict[str, Any]]): Additional location + of SageMaker model data (default: None). model_reference_arn (Optional [str]): Hub Content Arn of a Model Reference type content (default: None). """ self.model_data = model_data + self.additional_model_data_sources = additional_model_data_sources self.image_uri = image_uri self.predictor_cls = predictor_cls self.name = name @@ -411,6 +417,23 @@ def __init__( self.response_types = None self.accept_eula = None self.model_reference_arn = model_reference_arn + self._tags: Optional[Tags] = None + + def add_tags(self, tags: Tags) -> None: + """Add tags to this ``Model`` + + Args: + tags (Tags): Tags to add. + """ + self._tags = _validate_new_tags(tags, self._tags) + + def remove_tag_with_key(self, key: str) -> None: + """Remove a tag with the given key from the list of tags. + + Args: + key (str): The key of the tag to remove. 
+ """ + self._tags = remove_tag_with_key(key, self._tags) @classmethod def attach( @@ -700,6 +723,7 @@ def prepare_container_def( accept_eula=( accept_eula if accept_eula is not None else getattr(self, "accept_eula", None) ), + additional_model_data_sources=self.additional_model_data_sources, model_reference_arn=( model_reference_arn if model_reference_arn is not None @@ -1491,7 +1515,8 @@ def deploy( sagemaker_session=self.sagemaker_session, ) - tags = format_tags(tags) + self.add_tags(tags) + tags = format_tags(self._tags) if ( getattr(self.sagemaker_session, "settings", None) is not None diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index a2177c0ec5..2949fbaf5f 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -35,6 +35,7 @@ def retrieve( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieves the model artifact Amazon S3 URI for the model matching the given arguments. @@ -60,6 +61,8 @@ def retrieve( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + Returns: str: The model artifact S3 URI for the corresponding model. 
@@ -85,4 +88,5 @@ def retrieve( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index d3f41bd9a6..df8554f7e8 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -18,7 +18,7 @@ from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.factory.model import get_default_predictor -from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint +from sagemaker.jumpstart.session_utils import get_model_info_from_endpoint from sagemaker.session import Session @@ -44,6 +44,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Predictor: """Retrieves the default predictor for the model matching the given arguments. @@ -68,6 +69,8 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + config_name (Optional[str]): The name of the configuration to use for the + predictor. (Default: None) Returns: Predictor: The default predictor to use for the model. 
@@ -81,9 +84,9 @@ def retrieve_default( inferred_model_id, inferred_model_version, inferred_inference_component_name, - ) = get_model_id_version_from_endpoint( - endpoint_name, inference_component_name, sagemaker_session - ) + inferred_config_name, + _, + ) = get_model_info_from_endpoint(endpoint_name, inference_component_name, sagemaker_session) if not inferred_model_id: raise ValueError( @@ -95,6 +98,7 @@ def retrieve_default( model_id = inferred_model_id model_version = model_version or inferred_model_version or "*" inference_component_name = inference_component_name or inferred_inference_component_name + config_name = config_name or inferred_config_name or None else: model_version = model_version or "*" @@ -114,4 +118,5 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index 396a158939..d0ddea4432 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py @@ -38,6 +38,7 @@ def retrieve_default( model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> ResourceRequirements: """Retrieves the default resource requirements for the model matching the given arguments. @@ -65,6 +66,7 @@ def retrieve_default( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get host requirements specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default resource requirements to use for the model. 
@@ -91,4 +93,5 @@ def retrieve_default( model_type=model_type, sagemaker_session=sagemaker_session, instance_type=instance_type, + config_name=config_name, ) diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 91a5a97b1f..d60095b521 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -34,6 +34,7 @@ def retrieve( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -58,6 +59,7 @@ def retrieve( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The model script URI for the corresponding model. @@ -82,4 +84,5 @@ def retrieve( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index 4ffd121ad8..ef502dc6f3 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -46,6 +46,7 @@ def retrieve_options( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> List[BaseSerializer]: """Retrieves the supported serializers for the model matching the given arguments. @@ -69,6 +70,7 @@ def retrieve_options( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). 
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: List[SimpleBaseSerializer]: The supported serializers to use for the model. @@ -89,6 +91,7 @@ def retrieve_options( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) @@ -101,6 +104,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> BaseSerializer: """Retrieves the default serializer for the model matching the given arguments. @@ -124,6 +128,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: SimpleBaseSerializer: The default serializer to use for the model. 
@@ -145,4 +150,5 @@ def retrieve_default( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/serve/builder/djl_builder.py b/src/sagemaker/serve/builder/djl_builder.py index d234259db9..75acd0d1fe 100644 --- a/src/sagemaker/serve/builder/djl_builder.py +++ b/src/sagemaker/serve/builder/djl_builder.py @@ -24,6 +24,7 @@ LocalModelOutOfMemoryException, LocalModelInvocationException, ) +from sagemaker.serve.utils.optimize_utils import _is_optimized from sagemaker.serve.utils.tuning import ( _serial_benchmark, _concurrent_benchmark, @@ -90,6 +91,7 @@ def __init__(self): self.env_vars = None self.nb_instance_type = None self.ram_usage_model_load = None + self.role_arn = None @abstractmethod def _prepare_for_mode(self): @@ -213,9 +215,10 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa del kwargs["role"] # set model_data to uncompressed s3 dict - self.pysdk_model.model_data, env_vars = self._prepare_for_mode() - self.env_vars.update(env_vars) - self.pysdk_model.env.update(self.env_vars) + if not _is_optimized(self.pysdk_model): + self.pysdk_model.model_data, env_vars = self._prepare_for_mode() + self.env_vars.update(env_vars) + self.pysdk_model.env.update(self.env_vars) # if the weights have been cached via local container mode -> set to offline if str(Mode.LOCAL_CONTAINER) in self.modes: @@ -449,4 +452,8 @@ def _build_for_djl(self): self.pysdk_model = self._build_for_hf_djl() self.pysdk_model.tune = self._tune_for_hf_djl + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session return self.pysdk_model diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index bc31e8d323..07885792d2 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ 
b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -14,11 +14,15 @@ from __future__ import absolute_import import copy +import re from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import Type +from typing import Type, Any, List, Dict, Optional import logging +from botocore.exceptions import ClientError + +from sagemaker.enums import Tag from sagemaker.model import Model from sagemaker import model_uris from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources @@ -33,6 +37,18 @@ LocalModelLoadException, SkipTuningComboException, ) +from sagemaker.serve.utils.optimize_utils import ( + _generate_model_source, + _update_environment_variables, + _extract_speculative_draft_model_provider, + _is_image_compatible_with_optimization_job, + _generate_channel_name, + _extract_optimization_config_and_env, + _is_optimized, + _custom_speculative_decoding, + SPECULATIVE_DRAFT_MODEL, + _is_inferentia_or_trainium, +) from sagemaker.serve.utils.predictors import ( DjlLocalModePredictor, TgiLocalModePredictor, @@ -53,6 +69,7 @@ from sagemaker.serve.utils.types import ModelServer from sagemaker.base_predictor import PredictorBase from sagemaker.jumpstart.model import JumpStartModel +from sagemaker.utils import Tags _DJL_MODEL_BUILDER_ENTRY_POINT = "inference.py" _NO_JS_MODEL_EX = "HuggingFace JumpStart Model ID not detected. Building for HuggingFace Model ID." 
@@ -94,12 +111,19 @@ def __init__(self): self.prepared_for_djl = None self.prepared_for_mms = None self.schema_builder = None + self.instance_type = None self.nb_instance_type = None self.ram_usage_model_load = None - self.jumpstart = None + self.model_hub = None + self.model_metadata = None + self.role_arn = None + self.is_fine_tuned = None + self.is_compiled = False + self.is_quantized = False + self.speculative_decoding_draft_model_source = None @abstractmethod - def _prepare_for_mode(self): + def _prepare_for_mode(self, **kwargs): """Placeholder docstring""" @abstractmethod @@ -108,6 +132,9 @@ def _get_client_translators(self): def _is_jumpstart_model_id(self) -> bool: """Placeholder docstring""" + if self.model is None: + return False + try: model_uris.retrieve(model_id=self.model, model_version="*", model_scope=_JS_SCOPE) except KeyError: @@ -119,8 +146,9 @@ def _is_jumpstart_model_id(self) -> bool: def _create_pre_trained_js_model(self) -> Type[Model]: """Placeholder docstring""" - pysdk_model = JumpStartModel(self.model, vpc_config=self.vpc_config) - pysdk_model.sagemaker_session = self.sagemaker_session + pysdk_model = JumpStartModel( + self.model, vpc_config=self.vpc_config, sagemaker_session=self.sagemaker_session + ) self._original_deploy = pysdk_model.deploy pysdk_model.deploy = self._js_builder_deploy_wrapper @@ -129,6 +157,7 @@ def _create_pre_trained_js_model(self) -> Type[Model]: @_capture_telemetry("jumpstart.deploy") def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]: """Placeholder docstring""" + env = {} if "mode" in kwargs and kwargs.get("mode") != self.mode: overwrite_mode = kwargs.get("mode") # mode overwritten by customer during model.deploy() @@ -145,7 +174,8 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]: or not hasattr(self, "prepared_for_tgi") or not hasattr(self, "prepared_for_mms") ): - self.pysdk_model.model_data, env = self._prepare_for_mode() + if not 
_is_optimized(self.pysdk_model): + self.pysdk_model.model_data, env = self._prepare_for_mode() elif overwrite_mode == Mode.LOCAL_CONTAINER: self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER @@ -176,7 +206,6 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]: ) self._prepare_for_mode() - env = {} else: raise ValueError("Mode %s is not supported!" % overwrite_mode) @@ -249,7 +278,7 @@ def _build_for_djl_jumpstart(self): ) self._prepare_for_mode() elif self.mode == Mode.SAGEMAKER_ENDPOINT and hasattr(self, "prepared_for_djl"): - self.nb_instance_type = _get_nb_instance() + self.nb_instance_type = self.instance_type or _get_nb_instance() self.pysdk_model.model_data, env = self._prepare_for_mode() self.pysdk_model.env.update(env) @@ -467,17 +496,133 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800): sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration ) + def set_deployment_config(self, config_name: str, instance_type: str) -> None: + """Sets the deployment config to apply to the model. + + Args: + config_name (str): + The name of the deployment config to apply to the model. + Call list_deployment_configs to see the list of config names. + instance_type (str): + The instance_type that the model will use after setting + the config. 
+ """ + if not hasattr(self, "pysdk_model") or self.pysdk_model is None: + raise Exception("Cannot set deployment config to an uninitialized model.") + + self.pysdk_model.set_deployment_config(config_name, instance_type) + + self.instance_type = instance_type + + # JS-benchmarked models only include SageMaker-provided SD models + if self.pysdk_model.additional_model_data_sources: + self.speculative_decoding_draft_model_source = "sagemaker" + self.pysdk_model.add_tags( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "sagemaker"}, + ) + self.pysdk_model.remove_tag_with_key(Tag.OPTIMIZATION_JOB_NAME) + self.pysdk_model.remove_tag_with_key(Tag.FINE_TUNING_MODEL_PATH) + self.pysdk_model.remove_tag_with_key(Tag.FINE_TUNING_JOB_NAME) + + def get_deployment_config(self) -> Optional[Dict[str, Any]]: + """Gets the deployment config to apply to the model. + + Returns: + Optional[Dict[str, Any]]: Deployment config to apply to this model. + """ + if not hasattr(self, "pysdk_model") or self.pysdk_model is None: + self._build_for_jumpstart() + + return self.pysdk_model.deployment_config + + def display_benchmark_metrics(self, **kwargs): + """Display Markdown Benchmark Metrics for deployment configs.""" + if not hasattr(self, "pysdk_model") or self.pysdk_model is None: + self._build_for_jumpstart() + + self.pysdk_model.display_benchmark_metrics(**kwargs) + + def list_deployment_configs(self) -> List[Dict[str, Any]]: + """List deployment configs for ``This`` model in the current region. + + Returns: + List[Dict[str, Any]]: A list of deployment configs. 
+        """
+        if not hasattr(self, "pysdk_model") or self.pysdk_model is None:
+            self._build_for_jumpstart()
+
+        return self.pysdk_model.list_deployment_configs()
+
+    def _is_fine_tuned_model(self) -> bool:
+        """Checks whether a fine-tuned model exists."""
+        return self.model_metadata and (
+            self.model_metadata.get("FINE_TUNING_MODEL_PATH")
+            or self.model_metadata.get("FINE_TUNING_JOB_NAME")
+        )
+
+    def _update_model_data_for_fine_tuned_model(self, pysdk_model: Type[Model]) -> Type[Model]:
+        """Set the model path and data and add fine-tuning tags for the model."""
+        # TODO: determine precedence of FINE_TUNING_MODEL_PATH and FINE_TUNING_JOB_NAME
+        if fine_tuning_model_path := self.model_metadata.get("FINE_TUNING_MODEL_PATH"):
+            if not re.match("^(https|s3)://([^/]+)/?(.*)$", fine_tuning_model_path):
+                raise ValueError(
+                    f"Invalid path for FINE_TUNING_MODEL_PATH: {fine_tuning_model_path}."
+                )
+            pysdk_model.model_data["S3DataSource"]["S3Uri"] = fine_tuning_model_path
+            pysdk_model.add_tags(
+                {"Key": Tag.FINE_TUNING_MODEL_PATH, "Value": fine_tuning_model_path}
+            )
+            logger.info(
+                "FINE_TUNING_MODEL_PATH detected. Using fine-tuned model found in %s.",
+                fine_tuning_model_path,
+            )
+            return pysdk_model
+
+        if fine_tuning_job_name := self.model_metadata.get("FINE_TUNING_JOB_NAME"):
+            try:
+                response = self.sagemaker_session.sagemaker_client.describe_training_job(
+                    TrainingJobName=fine_tuning_job_name
+                )
+                fine_tuning_model_path = response["ModelArtifacts"]["S3ModelArtifacts"]
+                pysdk_model.model_data["S3DataSource"]["S3Uri"] = fine_tuning_model_path
+                pysdk_model.add_tags(
+                    [
+                        {"Key": Tag.FINE_TUNING_JOB_NAME, "Value": fine_tuning_job_name},
+                        {"Key": Tag.FINE_TUNING_MODEL_PATH, "Value": fine_tuning_model_path},
+                    ]
+                )
+                logger.info(
+                    "FINE_TUNING_JOB_NAME detected. Using fine-tuned model found in %s.",
+                    fine_tuning_model_path,
+                )
+                return pysdk_model
+            except ClientError:
+                raise ValueError(
+                    f"Invalid job name for FINE_TUNING_JOB_NAME: {fine_tuning_job_name}."
+ ) + + raise ValueError( + "Input model not found. Please provide either `model_path`, or " + "`FINE_TUNING_MODEL_PATH` or `FINE_TUNING_JOB_NAME` under `model_metadata`." + ) + def _build_for_jumpstart(self): """Placeholder docstring""" + if hasattr(self, "pysdk_model") and self.pysdk_model is not None: + return self.pysdk_model + # we do not pickle for jumpstart. set to none self.secret_key = None - self.jumpstart = True pysdk_model = self._create_pre_trained_js_model() image_uri = pysdk_model.image_uri logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri) + if self._is_fine_tuned_model(): + self.is_fine_tuned = True + pysdk_model = self._update_model_data_for_fine_tuned_model(pysdk_model) + if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT: raise ValueError( "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode." @@ -514,9 +659,142 @@ def _build_for_jumpstart(self): "with djl-inference, tgi-inference, or mms-inference container." ) + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session return self.pysdk_model - def _is_gated_model(self, model) -> bool: + def _optimize_for_jumpstart( + self, + output_path: Optional[str] = None, + instance_type: Optional[str] = None, + role_arn: Optional[str] = None, + tags: Optional[Tags] = None, + job_name: Optional[str] = None, + accept_eula: Optional[bool] = None, + quantization_config: Optional[Dict] = None, + compilation_config: Optional[Dict] = None, + speculative_decoding_config: Optional[Dict] = None, + env_vars: Optional[Dict] = None, + vpc_config: Optional[Dict] = None, + kms_key: Optional[str] = None, + max_runtime_in_sec: Optional[int] = None, + ) -> Optional[Dict[str, Any]]: + """Runs a model optimization job. + + Args: + output_path (Optional[str]): Specifies where to store the compiled/quantized model. 
+ instance_type (Optional[str]): Target deployment instance type that + the model is optimized for. + role_arn (Optional[str]): Execution role. Defaults to ``None``. + tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. + job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). + quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``. + compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. + speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. + Defaults to ``None`` + env_vars (Optional[Dict]): Additional environment variables to run the optimization + container. Defaults to ``None``. + vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. + kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading + to S3. Defaults to ``None``. + max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to + ``None``. + + Returns: + Dict[str, Any]: Model optimization job input arguments. + """ + if self._is_gated_model() and accept_eula is not True: + raise ValueError( + f"Model '{self.model}' requires accepting end-user license agreement (EULA)." 
+ ) + + is_compilation = (quantization_config is None) and ( + (compilation_config is not None) or _is_inferentia_or_trainium(instance_type) + ) + + pysdk_model_env_vars = dict() + if is_compilation: + pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type) + + optimization_config, override_env = _extract_optimization_config_and_env( + quantization_config, compilation_config + ) + if not optimization_config and is_compilation: + override_env = override_env or pysdk_model_env_vars + optimization_config = { + "ModelCompilationConfig": { + "OverrideEnvironment": override_env, + } + } + + if speculative_decoding_config: + self._set_additional_model_source(speculative_decoding_config) + else: + deployment_config = self._find_compatible_deployment_config(None) + if deployment_config: + self.pysdk_model.set_deployment_config( + config_name=deployment_config.get("DeploymentConfigName"), + instance_type=deployment_config.get("InstanceType"), + ) + pysdk_model_env_vars = self.pysdk_model.env + + model_source = _generate_model_source(self.pysdk_model.model_data, accept_eula) + optimization_env_vars = _update_environment_variables(pysdk_model_env_vars, env_vars) + + output_config = {"S3OutputLocation": output_path} + if kms_key: + output_config["KmsKeyId"] = kms_key + + deployment_config_instance_type = ( + self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get("InstanceType") + if self.pysdk_model.deployment_config + else None + ) + self.instance_type = instance_type or deployment_config_instance_type or _get_nb_instance() + self.role_arn = role_arn or self.role_arn + + create_optimization_job_args = { + "OptimizationJobName": job_name, + "ModelSource": model_source, + "DeploymentInstanceType": self.instance_type, + "OptimizationConfigs": [optimization_config], + "OutputConfig": output_config, + "RoleArn": self.role_arn, + } + + if optimization_env_vars: + create_optimization_job_args["OptimizationEnvironment"] = optimization_env_vars + if 
max_runtime_in_sec: + create_optimization_job_args["StoppingCondition"] = { + "MaxRuntimeInSeconds": max_runtime_in_sec + } + if tags: + create_optimization_job_args["Tags"] = tags + if vpc_config: + create_optimization_job_args["VpcConfig"] = vpc_config + + if accept_eula: + self.pysdk_model.accept_eula = accept_eula + if isinstance(self.pysdk_model.model_data, dict): + self.pysdk_model.model_data["S3DataSource"]["ModelAccessConfig"] = { + "AcceptEula": True + } + + if quantization_config or is_compilation: + self.pysdk_model.env = _update_environment_variables( + optimization_env_vars, override_env + ) + return create_optimization_job_args + return None + + def _is_gated_model(self, model=None) -> bool: """Determine if ``this`` Model is Gated Args: @@ -524,10 +802,116 @@ def _is_gated_model(self, model) -> bool: Returns: bool: ``True`` if ``this`` Model is Gated """ - s3_uri = model.model_data + s3_uri = model.model_data if model else self.pysdk_model.model_data if isinstance(s3_uri, dict): s3_uri = s3_uri.get("S3DataSource").get("S3Uri") if s3_uri is None: return False return "private" in s3_uri + + def _set_additional_model_source( + self, + speculative_decoding_config: Optional[Dict[str, Any]] = None, + accept_eula: Optional[bool] = None, + ) -> None: + """Set Additional Model Source to ``this`` model. + + Args: + speculative_decoding_config (Optional[Dict[str, Any]]): Speculative decoding config. + accept_eula (Optional[bool]): For models that require a Model Access Config. 
+ """ + if speculative_decoding_config: + model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config) + channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources) + + if model_provider == "sagemaker": + additional_model_data_sources = ( + self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get( + "AdditionalDataSources" + ) + if self.pysdk_model.deployment_config + else None + ) + if additional_model_data_sources is None: + deployment_config = self._find_compatible_deployment_config( + speculative_decoding_config + ) + if deployment_config: + self.pysdk_model.set_deployment_config( + config_name=deployment_config.get("DeploymentConfigName"), + instance_type=deployment_config.get("InstanceType"), + ) + else: + raise ValueError( + "Cannot find deployment config compatible for optimization job." + ) + + self.pysdk_model.env.update( + {"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}"} + ) + self.pysdk_model.add_tags( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "sagemaker"}, + ) + else: + self.pysdk_model = _custom_speculative_decoding( + self.pysdk_model, speculative_decoding_config, accept_eula + ) + + def _find_compatible_deployment_config( + self, speculative_decoding_config: Optional[Dict] = None + ) -> Optional[Dict[str, Any]]: + """Finds compatible model deployment config for optimization job. + + Args: + speculative_decoding_config (Optional[Dict]): Speculative decoding config. + + Returns: + Optional[Dict[str, Any]]: A compatible model deployment config for optimization job. 
+ """ + model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config) + for deployment_config in self.pysdk_model.list_deployment_configs(): + image_uri = deployment_config.get("deployment_config", {}).get("ImageUri") + + if _is_image_compatible_with_optimization_job(image_uri): + if ( + model_provider == "sagemaker" + and deployment_config.get("DeploymentArgs", {}).get("AdditionalDataSources") + ) or model_provider == "custom": + return deployment_config + + # There's no matching config from jumpstart to add sagemaker draft model location + if model_provider == "sagemaker": + return None + + # fall back to the default jumpstart model deployment config for optimization job + return self.pysdk_model.deployment_config + + def _get_neuron_model_env_vars( + self, instance_type: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + """Gets Neuron model env vars. + + Args: + instance_type (Optional[str]): Instance type. + + Returns: + Optional[Dict[str, Any]]: Neuron Model environment variables. 
+ """ + metadata_configs = self.pysdk_model._metadata_configs + if metadata_configs: + metadata_config = metadata_configs.get(self.pysdk_model.config_name) + resolve_config = metadata_config.resolved_config if metadata_config else None + if resolve_config and instance_type not in resolve_config.get( + "supported_inference_instance_types", [] + ): + neuro_model_id = resolve_config.get("hosting_neuron_model_id") + neuro_model_version = resolve_config.get("hosting_neuron_model_version", "*") + if neuro_model_id: + job_model = JumpStartModel( + neuro_model_id, + model_version=neuro_model_version, + vpc_config=self.vpc_config, + ) + return job_model.env + return None diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index a3762e6638..01b2b96f68 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -23,6 +23,7 @@ from pathlib import Path +from sagemaker.enums import Tag from sagemaker.s3 import S3Downloader from sagemaker import Session @@ -67,6 +68,15 @@ from sagemaker.serve.utils import task from sagemaker.serve.utils.exceptions import TaskNotFoundException from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model +from sagemaker.serve.utils.optimize_utils import ( + _generate_optimized_model, + _generate_model_source, + _extract_optimization_config_and_env, + _is_s3_uri, + _normalize_local_model_path, + _custom_speculative_decoding, + _extract_speculative_draft_model_provider, +) from sagemaker.serve.utils.predictors import _get_local_mode_predictor from sagemaker.serve.utils.hardware_detector import ( _get_gpu_info, @@ -81,15 +91,19 @@ from sagemaker.serve.model_server.torchserve.prepare import prepare_for_torchserve from sagemaker.serve.model_server.triton.triton_builder import Triton from sagemaker.serve.utils.telemetry_logger import _capture_telemetry -from sagemaker.serve.utils.types import ModelServer +from 
sagemaker.serve.utils.types import ModelServer, ModelHub from sagemaker.serve.validations.check_image_uri import is_1p_image_uri from sagemaker.serve.save_retrive.version_1_0_0.save.save_handler import SaveHandler from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import get_metadata from sagemaker.serve.validations.check_image_and_hardware_type import ( validate_image_uri_and_hardware, ) +from sagemaker.utils import Tags from sagemaker.workflow.entities import PipelineVariable -from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata +from sagemaker.huggingface.llm_utils import ( + get_huggingface_model_metadata, + download_huggingface_model_metadata, +) logger = logging.getLogger(__name__) @@ -184,7 +198,11 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, new models without task metadata in the Hub, adding unsupported task types will throw an exception. ``MLFLOW_MODEL_PATH`` is available for providing local path or s3 path to MLflow artifacts. However, ``MLFLOW_MODEL_PATH`` is experimental and is not - intended for production use at this moment. + intended for production use at this moment. ``CUSTOM_MODEL_PATH`` is available for + providing local path or s3 path to model artifacts. ``FINE_TUNING_MODEL_PATH`` is + available for providing s3 path to fine-tuned model artifacts. ``FINE_TUNING_JOB_NAME`` + is available for providing fine-tuned job name. Both ``FINE_TUNING_MODEL_PATH`` and + ``FINE_TUNING_JOB_NAME`` are mutually exclusive. """ model_path: Optional[str] = field( @@ -285,9 +303,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, default=None, metadata={ "help": "Define the model metadata to override, currently supports `HF_TASK`, " - "`MLFLOW_MODEL_PATH`, and `MLFLOW_TRACKING_ARN`. 
HF_TASK should be set for new " - "models without task metadata in the Hub, Adding unsupported task types will " - "throw an exception" + "`MLFLOW_MODEL_PATH`, `FINE_TUNING_MODEL_PATH`, `FINE_TUNING_JOB_NAME`, and " + "`CUSTOM_MODEL_PATH`. HF_TASK should be set for new models without task metadata " + "in the Hub, Adding unsupported task types will throw an exception." }, ) @@ -364,8 +382,15 @@ def _get_serve_setting(self): sagemaker_session=self.sagemaker_session, ) - def _prepare_for_mode(self): - """Placeholder docstring""" + def _prepare_for_mode( + self, model_path: Optional[str] = None, should_upload_artifacts: Optional[bool] = False + ): + """Prepare this `Model` for serving. + + Args: + model_path (Optional[str]): Model path + should_upload_artifacts (Optional[bool]): Whether to upload artifacts to S3. + """ # TODO: move mode specific prepare steps under _model_builder_deploy_wrapper self.s3_upload_path = None if self.mode == Mode.SAGEMAKER_ENDPOINT: @@ -376,12 +401,13 @@ def _prepare_for_mode(self): self.s3_upload_path, env_vars_sagemaker = self.modes[ str(Mode.SAGEMAKER_ENDPOINT) ].prepare( - self.model_path, + (model_path or self.model_path), self.secret_key, self.serve_settings.s3_model_data_url, self.sagemaker_session, self.image_uri, - self.jumpstart if hasattr(self, "jumpstart") else False, + getattr(self, "model_hub", None) == ModelHub.JUMPSTART, + should_upload_artifacts=should_upload_artifacts, ) self.env_vars.update(env_vars_sagemaker) return self.s3_upload_path, env_vars_sagemaker @@ -460,6 +486,10 @@ def _create_model(self): self.pysdk_model.mode = self.mode self.pysdk_model.modes = self.modes self.pysdk_model.serve_settings = self.serve_settings + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session # dynamically generate a method to direct model.deploy() logic based on mode # unique method to models created via ModelBuilder() @@ -621,11 +651,6 @@ 
def _handle_mlflow_input(self): mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH) artifact_path = self._get_artifact_path(mlflow_model_path) if not self._mlflow_metadata_exists(artifact_path): - logger.info( - "MLflow model metadata not detected in %s. ModelBuilder is not " - "handling MLflow model input", - mlflow_model_path, - ) return self._initialize_for_mlflow(artifact_path) @@ -799,7 +824,15 @@ def build( # pylint: disable=R0911 self.mode = mode if role_arn: self.role_arn = role_arn - self.sagemaker_session = sagemaker_session or Session() + + self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() + + self.sagemaker_session.settings._local_download_dir = self.model_path + + # DJL expects `HF_TOKEN` key. This allows backward compatibility + # until we deprecate HUGGING_FACE_HUB_TOKEN. + if self.env_vars.get("HUGGING_FACE_HUB_TOKEN") and not self.env_vars.get("HF_TOKEN"): + self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN") self.sagemaker_session.settings._local_download_dir = self.model_path @@ -812,22 +845,25 @@ def build( # pylint: disable=R0911 ) self.serve_settings = self._get_serve_setting() - self._is_custom_image_uri = self.image_uri is not None self._handle_mlflow_input() self._build_validations() - if self.model_server: + if not self._is_jumpstart_model_id() and self.model_server: return self._build_for_model_server() if isinstance(self.model, str): model_task = None - if self.model_metadata: - model_task = self.model_metadata.get("HF_TASK") if self._is_jumpstart_model_id(): + self.model_hub = ModelHub.JUMPSTART return self._build_for_jumpstart() + self.model_hub = ModelHub.HUGGINGFACE + + if self.model_metadata: + model_task = self.model_metadata.get("HF_TASK") + if self._is_djl(): return self._build_for_djl() else: @@ -917,8 +953,15 @@ def save( This function is available for models served by DJL serving. Args: - save_path (Optional[str]): The path where you want to save resources. 
- s3_path (Optional[str]): The path where you want to upload resources. + save_path (Optional[str]): The path where you want to save resources. Defaults to + ``None``. + s3_path (Optional[str]): The path where you want to upload resources. Defaults to + ``None``. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. Defaults to + ``None``. + role_arn (Optional[str]): The IAM role arn. Defaults to ``None``. """ self.sagemaker_session = sagemaker_session or Session() @@ -1041,3 +1084,303 @@ def _try_fetch_gpu_info(self): raise ValueError( f"Unable to determine single GPU size for instance: [{self.instance_type}]" ) + + def optimize( + self, + output_path: Optional[str] = None, + instance_type: Optional[str] = None, + role_arn: Optional[str] = None, + tags: Optional[Tags] = None, + job_name: Optional[str] = None, + accept_eula: Optional[bool] = None, + quantization_config: Optional[Dict] = None, + compilation_config: Optional[Dict] = None, + speculative_decoding_config: Optional[Dict] = None, + env_vars: Optional[Dict] = None, + vpc_config: Optional[Dict] = None, + kms_key: Optional[str] = None, + max_runtime_in_sec: Optional[int] = 36000, + sagemaker_session: Optional[Session] = None, + ) -> Model: + """Create an optimized deployable ``Model`` instance with ``ModelBuilder``. + + Args: + output_path (str): Specifies where to store the compiled/quantized model. + instance_type (str): Target deployment instance type that the model is optimized for. + role_arn (Optional[str]): Execution role arn. Defaults to ``None``. + tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. + job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. 
+ accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). + quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``. + compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. + speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. + Defaults to ``None`` + env_vars (Optional[Dict]): Additional environment variables to run the optimization + container. Defaults to ``None``. + vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. + kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading + to S3. Defaults to ``None``. + max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to + 36000 seconds. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Model: A deployable ``Model`` object. 
+ """ + + # need to get telemetry_opt_out info before telemetry decorator is called + self.serve_settings = self._get_serve_setting() + + return self._model_builder_optimize_wrapper( + output_path=output_path, + instance_type=instance_type, + role_arn=role_arn, + tags=tags, + job_name=job_name, + accept_eula=accept_eula, + quantization_config=quantization_config, + compilation_config=compilation_config, + speculative_decoding_config=speculative_decoding_config, + env_vars=env_vars, + vpc_config=vpc_config, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + sagemaker_session=sagemaker_session, + ) + + @_capture_telemetry("optimize") + def _model_builder_optimize_wrapper( + self, + output_path: Optional[str] = None, + instance_type: Optional[str] = None, + role_arn: Optional[str] = None, + tags: Optional[Tags] = None, + job_name: Optional[str] = None, + accept_eula: Optional[bool] = None, + quantization_config: Optional[Dict] = None, + compilation_config: Optional[Dict] = None, + speculative_decoding_config: Optional[Dict] = None, + env_vars: Optional[Dict] = None, + vpc_config: Optional[Dict] = None, + kms_key: Optional[str] = None, + max_runtime_in_sec: Optional[int] = 36000, + sagemaker_session: Optional[Session] = None, + ) -> Model: + """Runs a model optimization job. + + Args: + output_path (str): Specifies where to store the compiled/quantized model. + instance_type (str): Target deployment instance type that the model is optimized for. + role_arn (Optional[str]): Execution role arn. Defaults to ``None``. + tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. + job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. 
+ The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). + quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``. + compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. + speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. + Defaults to ``None`` + env_vars (Optional[Dict]): Additional environment variables to run the optimization + container. Defaults to ``None``. + vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. + kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading + to S3. Defaults to ``None``. + max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to + 36000 seconds. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Model: A deployable ``Model`` object. 
+ """ + self.is_compiled = compilation_config is not None + self.is_quantized = quantization_config is not None + self.speculative_decoding_draft_model_source = _extract_speculative_draft_model_provider( + speculative_decoding_config + ) + + if self.mode != Mode.SAGEMAKER_ENDPOINT: + raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.") + + if quantization_config and compilation_config: + raise ValueError("Quantization config and compilation config are mutually exclusive.") + + self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() + + self.instance_type = instance_type or self.instance_type + self.role_arn = role_arn or self.role_arn + + self.build(mode=self.mode, sagemaker_session=self.sagemaker_session) + job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}" + + if self._is_jumpstart_model_id(): + input_args = self._optimize_for_jumpstart( + output_path=output_path, + instance_type=instance_type, + role_arn=self.role_arn, + tags=tags, + job_name=job_name, + accept_eula=accept_eula, + quantization_config=quantization_config, + compilation_config=compilation_config, + speculative_decoding_config=speculative_decoding_config, + env_vars=env_vars, + vpc_config=vpc_config, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + ) + else: + input_args = self._optimize_for_hf( + output_path=output_path, + instance_type=instance_type, + role_arn=self.role_arn, + tags=tags, + job_name=job_name, + quantization_config=quantization_config, + compilation_config=compilation_config, + speculative_decoding_config=speculative_decoding_config, + env_vars=env_vars, + vpc_config=vpc_config, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + ) + + if input_args: + self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args) + job_status = self.sagemaker_session.wait_for_optimization_job(job_name) + return _generate_optimized_model(self.pysdk_model, job_status) + + 
self.pysdk_model.remove_tag_with_key(Tag.OPTIMIZATION_JOB_NAME) + if not speculative_decoding_config: + self.pysdk_model.remove_tag_with_key(Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER) + + return self.pysdk_model + + def _optimize_for_hf( + self, + output_path: str, + instance_type: Optional[str] = None, + role_arn: Optional[str] = None, + tags: Optional[Tags] = None, + job_name: Optional[str] = None, + quantization_config: Optional[Dict] = None, + compilation_config: Optional[Dict] = None, + speculative_decoding_config: Optional[Dict] = None, + env_vars: Optional[Dict] = None, + vpc_config: Optional[Dict] = None, + kms_key: Optional[str] = None, + max_runtime_in_sec: Optional[int] = None, + ) -> Optional[Dict[str, Any]]: + """Runs a model optimization job. + + Args: + output_path (str): Specifies where to store the compiled/quantized model. + instance_type (Optional[str]): Target deployment instance type that + the model is optimized for. + role_arn (Optional[str]): Execution role. Defaults to ``None``. + tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. + job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. + quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``. + compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. + speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. + Defaults to ``None`` + env_vars (Optional[Dict]): Additional environment variables to run the optimization + container. Defaults to ``None``. + vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. + kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading + to S3. Defaults to ``None``. + max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to + ``None``. + + Returns: + Optional[Dict[str, Any]]: Model optimization job input arguments. 
+ """ + if self.model_server != ModelServer.DJL_SERVING: + logger.info("Overwriting model server to DJL.") + self.model_server = ModelServer.DJL_SERVING + + self.role_arn = role_arn or self.role_arn + self.instance_type = instance_type or self.instance_type + + self.pysdk_model = _custom_speculative_decoding( + self.pysdk_model, speculative_decoding_config, False + ) + + if quantization_config or compilation_config: + create_optimization_job_args = { + "OptimizationJobName": job_name, + "DeploymentInstanceType": self.instance_type, + "RoleArn": self.role_arn, + } + + if env_vars: + self.pysdk_model.env.update(env_vars) + create_optimization_job_args["OptimizationEnvironment"] = env_vars + + self._optimize_prepare_for_hf() + model_source = _generate_model_source(self.pysdk_model.model_data, False) + create_optimization_job_args["ModelSource"] = model_source + + optimization_config, override_env = _extract_optimization_config_and_env( + quantization_config, compilation_config + ) + create_optimization_job_args["OptimizationConfigs"] = [optimization_config] + self.pysdk_model.env.update(override_env) + + output_config = {"S3OutputLocation": output_path} + if kms_key: + output_config["KmsKeyId"] = kms_key + create_optimization_job_args["OutputConfig"] = output_config + + if max_runtime_in_sec: + create_optimization_job_args["StoppingCondition"] = { + "MaxRuntimeInSeconds": max_runtime_in_sec + } + if tags: + create_optimization_job_args["Tags"] = tags + if vpc_config: + create_optimization_job_args["VpcConfig"] = vpc_config + + # HF_MODEL_ID needs not to be present, otherwise, + # HF model artifacts will be re-downloaded during deployment + if "HF_MODEL_ID" in self.pysdk_model.env: + del self.pysdk_model.env["HF_MODEL_ID"] + + return create_optimization_job_args + return None + + def _optimize_prepare_for_hf(self): + """Prepare huggingface model data for optimization.""" + custom_model_path: str = ( + self.model_metadata.get("CUSTOM_MODEL_PATH") if self.model_metadata 
else None + ) + if _is_s3_uri(custom_model_path): + # Remove slash by the end of s3 uri, as it may lead to / subfolder during upload. + custom_model_path = ( + custom_model_path[:-1] if custom_model_path.endswith("/") else custom_model_path + ) + else: + if not custom_model_path: + custom_model_path = f"/tmp/sagemaker/model-builder/{self.model}/code" + download_huggingface_model_metadata( + self.model, + custom_model_path, + self.env_vars.get("HUGGING_FACE_HUB_TOKEN"), + ) + custom_model_path = _normalize_local_model_path(custom_model_path) + + self.pysdk_model.model_data, env = self._prepare_for_mode( + model_path=custom_model_path, + should_upload_artifacts=True, + ) + self.pysdk_model.env.update(env) diff --git a/src/sagemaker/serve/builder/tei_builder.py b/src/sagemaker/serve/builder/tei_builder.py index 53f41a6891..a1f4567eb9 100644 --- a/src/sagemaker/serve/builder/tei_builder.py +++ b/src/sagemaker/serve/builder/tei_builder.py @@ -25,6 +25,7 @@ _get_nb_instance, ) from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure +from sagemaker.serve.utils.optimize_utils import _is_optimized from sagemaker.serve.utils.predictors import TeiLocalModePredictor from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.mode.function_pointers import Mode @@ -63,7 +64,6 @@ def __init__(self): self.nb_instance_type = None self.ram_usage_model_load = None self.secret_key = None - self.jumpstart = None self.role_arn = None @abstractmethod @@ -163,10 +163,8 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa self.pysdk_model.role = kwargs.get("role") del kwargs["role"] - # set model_data to uncompressed s3 dict - self.pysdk_model.model_data, env_vars = self._prepare_for_mode() - self.env_vars.update(env_vars) - self.pysdk_model.env.update(self.env_vars) + if not _is_optimized(self.pysdk_model): + self._prepare_for_mode() # if the weights have been cached via local container mode -> set to offline if 
str(Mode.LOCAL_CONTAINER) in self.modes: @@ -222,4 +220,8 @@ def _build_for_tei(self): self._set_to_tei() self.pysdk_model = self._build_for_hf_tei() + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session return self.pysdk_model diff --git a/src/sagemaker/serve/builder/tf_serving_builder.py b/src/sagemaker/serve/builder/tf_serving_builder.py index 42c548f4e4..9b171b1d98 100644 --- a/src/sagemaker/serve/builder/tf_serving_builder.py +++ b/src/sagemaker/serve/builder/tf_serving_builder.py @@ -102,6 +102,10 @@ def _create_tensorflow_model(self): self.pysdk_model.mode = self.mode self.pysdk_model.modes = self.modes self.pysdk_model.serve_settings = self.serve_settings + if hasattr(self, "role_arn") and self.role_arn: + self.pysdk_model.role = self.role_arn + if hasattr(self, "sagemaker_session") and self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session self._original_deploy = self.pysdk_model.deploy self.pysdk_model.deploy = self._model_builder_deploy_wrapper diff --git a/src/sagemaker/serve/builder/tgi_builder.py b/src/sagemaker/serve/builder/tgi_builder.py index a74c07f1e1..558a560a74 100644 --- a/src/sagemaker/serve/builder/tgi_builder.py +++ b/src/sagemaker/serve/builder/tgi_builder.py @@ -25,6 +25,7 @@ LocalModelInvocationException, SkipTuningComboException, ) +from sagemaker.serve.utils.optimize_utils import _is_optimized from sagemaker.serve.utils.tuning import ( _serial_benchmark, _concurrent_benchmark, @@ -90,7 +91,6 @@ def __init__(self): self.nb_instance_type = None self.ram_usage_model_load = None self.secret_key = None - self.jumpstart = None self.role_arn = None @abstractmethod @@ -202,10 +202,8 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa self.pysdk_model.role = kwargs.get("role") del kwargs["role"] - # set model_data to uncompressed s3 dict - self.pysdk_model.model_data, env_vars = 
self._prepare_for_mode() - self.env_vars.update(env_vars) - self.pysdk_model.env.update(self.env_vars) + if not _is_optimized(self.pysdk_model): + self._prepare_for_mode() # if the weights have been cached via local container mode -> set to offline if str(Mode.LOCAL_CONTAINER) in self.modes: @@ -474,4 +472,8 @@ def _build_for_tgi(self): self.pysdk_model = self._build_for_hf_tgi() self.pysdk_model.tune = self._tune_for_hf_tgi + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session return self.pysdk_model diff --git a/src/sagemaker/serve/builder/transformers_builder.py b/src/sagemaker/serve/builder/transformers_builder.py index eef5800d98..e618b54e44 100644 --- a/src/sagemaker/serve/builder/transformers_builder.py +++ b/src/sagemaker/serve/builder/transformers_builder.py @@ -27,6 +27,7 @@ from sagemaker.serve.model_server.multi_model_server.prepare import ( _create_dir_structure, ) +from sagemaker.serve.utils.optimize_utils import _is_optimized from sagemaker.serve.utils.predictors import TransformersLocalModePredictor from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.mode.function_pointers import Mode @@ -151,11 +152,11 @@ def _get_hf_metadata_create_model(self) -> Type[Model]: vpc_config=self.vpc_config, ) - if self.mode == Mode.LOCAL_CONTAINER: + if not self.image_uri and self.mode == Mode.LOCAL_CONTAINER: self.image_uri = pysdk_model.serving_image_uri( self.sagemaker_session.boto_region_name, "local" ) - else: + elif not self.image_uri: self.image_uri = pysdk_model.serving_image_uri( self.sagemaker_session.boto_region_name, self.instance_type ) @@ -223,10 +224,8 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr self.pysdk_model.role = kwargs.get("role") del kwargs["role"] - # set model_data to uncompressed s3 dict - self.pysdk_model.model_data, env_vars = self._prepare_for_mode() - self.env_vars.update(env_vars) - 
self.pysdk_model.env.update(self.env_vars) + if not _is_optimized(self.pysdk_model): + self._prepare_for_mode() if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True @@ -303,4 +302,8 @@ def _build_for_transformers(self): self._build_transformers_env() + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session return self.pysdk_model diff --git a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py index b8f1d0529b..6f9bf8307f 100644 --- a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py +++ b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py @@ -59,6 +59,7 @@ def prepare( sagemaker_session: Session = None, image: str = None, jumpstart: bool = False, + should_upload_artifacts: bool = False, ): """Placeholder docstring""" try: @@ -69,7 +70,7 @@ def prepare( + "session to be created or supply `sagemaker_session` into @serve.invoke." 
) from e - upload_artifacts = None + upload_artifacts = None, None if self.model_server == ModelServer.TORCHSERVE: upload_artifacts = self._upload_torchserve_artifacts( model_path=model_path, @@ -77,6 +78,7 @@ def prepare( secret_key=secret_key, s3_model_data_url=s3_model_data_url, image=image, + should_upload_artifacts=True, ) if self.model_server == ModelServer.TRITON: @@ -86,6 +88,7 @@ def prepare( secret_key=secret_key, s3_model_data_url=s3_model_data_url, image=image, + should_upload_artifacts=True, ) if self.model_server == ModelServer.DJL_SERVING: @@ -94,32 +97,41 @@ def prepare( sagemaker_session=sagemaker_session, s3_model_data_url=s3_model_data_url, image=image, + should_upload_artifacts=True, ) - if self.model_server == ModelServer.TGI: - upload_artifacts = self._upload_tgi_artifacts( + if self.model_server == ModelServer.TENSORFLOW_SERVING: + upload_artifacts = self._upload_tensorflow_serving_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, + secret_key=secret_key, s3_model_data_url=s3_model_data_url, image=image, - jumpstart=jumpstart, + should_upload_artifacts=True, ) - if self.model_server == ModelServer.MMS: - upload_artifacts = self._upload_server_artifacts( + # By default, we do not want to upload artifacts in S3 for the below server. + # In Case of Optimization, artifacts need to be uploaded into s3. + # In that case, `should_upload_artifacts` arg needs to come from + # the caller of prepare. 
+ + if self.model_server == ModelServer.TGI: + upload_artifacts = self._upload_tgi_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, s3_model_data_url=s3_model_data_url, image=image, + jumpstart=jumpstart, + should_upload_artifacts=should_upload_artifacts, ) - if self.model_server == ModelServer.TENSORFLOW_SERVING: - upload_artifacts = self._upload_tensorflow_serving_artifacts( + if self.model_server == ModelServer.MMS: + upload_artifacts = self._upload_server_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, - secret_key=secret_key, s3_model_data_url=s3_model_data_url, image=image, + should_upload_artifacts=should_upload_artifacts, ) if self.model_server == ModelServer.TEI: @@ -128,9 +140,10 @@ def prepare( sagemaker_session=sagemaker_session, s3_model_data_url=s3_model_data_url, image=image, + should_upload_artifacts=should_upload_artifacts, ) - if upload_artifacts: + if upload_artifacts or isinstance(self.model_server, ModelServer): return upload_artifacts raise ValueError("%s model server is not supported" % self.model_server) diff --git a/src/sagemaker/serve/model_server/djl_serving/server.py b/src/sagemaker/serve/model_server/djl_serving/server.py index 80214332b0..4ba7dd227d 100644 --- a/src/sagemaker/serve/model_server/djl_serving/server.py +++ b/src/sagemaker/serve/model_server/djl_serving/server.py @@ -12,6 +12,7 @@ from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join from sagemaker.s3 import S3Uploader from sagemaker.local.utils import get_docker_host +from sagemaker.serve.utils.optimize_utils import _is_s3_uri logger = logging.getLogger(__name__) MODE_DIR_BINDING = "/opt/ml/model/" @@ -91,39 +92,48 @@ def _upload_djl_artifacts( s3_model_data_url: str = None, image: str = None, env_vars: dict = None, + should_upload_artifacts: bool = False, ): """Placeholder docstring""" - if s3_model_data_url: - bucket, key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, 
key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) - - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + model_data_url = None + if _is_s3_uri(model_path): + model_data_url = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - code_dir = Path(model_path).joinpath("code") + code_dir = Path(model_path).joinpath("code") - s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") + s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") - logger.debug("Uploading DJL Model Resources uncompressed to: %s", s3_location) + logger.debug("Uploading DJL Model Resources uncompressed to: %s", s3_location) - model_data_url = S3Uploader.upload( - str(code_dir), - s3_location, - None, - sagemaker_session, - ) + model_data_url = S3Uploader.upload( + str(code_dir), + s3_location, + None, + sagemaker_session, + ) - model_data = { - "S3DataSource": { - "CompressionType": "None", - "S3DataType": "S3Prefix", - "S3Uri": model_data_url + "/", + model_data = ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": model_data_url + "/", + } } - } + if model_data_url + else None + ) return (model_data, _update_env_vars(env_vars)) diff --git a/src/sagemaker/serve/model_server/multi_model_server/server.py b/src/sagemaker/serve/model_server/multi_model_server/server.py index b78e01f5c3..91d585b4cf 100644 --- a/src/sagemaker/serve/model_server/multi_model_server/server.py +++ b/src/sagemaker/serve/model_server/multi_model_server/server.py @@ -11,6 +11,7 @@ from 
sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join from sagemaker.s3 import S3Uploader from sagemaker.local.utils import get_docker_host +from sagemaker.serve.utils.optimize_utils import _is_s3_uri MODE_DIR_BINDING = "/opt/ml/model/" _DEFAULT_ENV_VARS = {} @@ -84,38 +85,48 @@ def _upload_server_artifacts( s3_model_data_url: str = None, image: str = None, env_vars: dict = None, + should_upload_artifacts: bool = False, ): - if s3_model_data_url: - bucket, key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + model_data_url = None + if _is_s3_uri(model_path): + model_data_url = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + code_dir = Path(model_path).joinpath("code") - code_dir = Path(model_path).joinpath("code") + s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") - s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") + logger.debug("Uploading Multi Model Server Resources uncompressed to: %s", s3_location) - logger.debug("Uploading Multi Model Server Resources uncompressed to: %s", s3_location) + model_data_url = S3Uploader.upload( + str(code_dir), + s3_location, + None, + sagemaker_session, + ) - model_data_url = S3Uploader.upload( - str(code_dir), - s3_location, - None, - sagemaker_session, + model_data = ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": 
model_data_url + "/", + } + } + if model_data_url + else None ) - model_data = { - "S3DataSource": { - "CompressionType": "None", - "S3DataType": "S3Prefix", - "S3Uri": model_data_url + "/", - } - } return model_data, _update_env_vars(env_vars) diff --git a/src/sagemaker/serve/model_server/tei/server.py b/src/sagemaker/serve/model_server/tei/server.py index 25c27e6dda..54abbea0da 100644 --- a/src/sagemaker/serve/model_server/tei/server.py +++ b/src/sagemaker/serve/model_server/tei/server.py @@ -12,7 +12,7 @@ from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join from sagemaker.s3 import S3Uploader from sagemaker.local.utils import get_docker_host - +from sagemaker.serve.utils.optimize_utils import _is_s3_uri MODE_DIR_BINDING = "/opt/ml/model/" _SHM_SIZE = "2G" @@ -107,6 +107,7 @@ def _upload_tei_artifacts( s3_model_data_url: str = None, image: str = None, env_vars: dict = None, + should_upload_artifacts: bool = False, ): """Uploads the model artifacts to S3. 
@@ -116,38 +117,48 @@ def _upload_tei_artifacts( s3_model_data_url: S3 model data URL image: Image to use env_vars: Environment variables to set + model_data_s3_path: S3 path to model data + should_upload_artifacts: Whether to upload artifacts """ - if s3_model_data_url: - bucket, key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) - - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + model_data_url = None + if _is_s3_uri(model_path): + model_data_url = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - code_dir = Path(model_path).joinpath("code") + code_dir = Path(model_path).joinpath("code") - s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") + s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") - logger.debug("Uploading TEI Model Resources uncompressed to: %s", s3_location) + logger.debug("Uploading TEI Model Resources uncompressed to: %s", s3_location) - model_data_url = S3Uploader.upload( - str(code_dir), - s3_location, - None, - sagemaker_session, - ) + model_data_url = S3Uploader.upload( + str(code_dir), + s3_location, + None, + sagemaker_session, + ) - model_data = { - "S3DataSource": { - "CompressionType": "None", - "S3DataType": "S3Prefix", - "S3Uri": model_data_url + "/", + model_data = ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": model_data_url + "/", + } } - } + if model_data_url + else None + ) return (model_data, 
_update_env_vars(env_vars)) diff --git a/src/sagemaker/serve/model_server/tensorflow_serving/server.py b/src/sagemaker/serve/model_server/tensorflow_serving/server.py index 2392287c61..45931e9afc 100644 --- a/src/sagemaker/serve/model_server/tensorflow_serving/server.py +++ b/src/sagemaker/serve/model_server/tensorflow_serving/server.py @@ -7,6 +7,7 @@ import platform from pathlib import Path from sagemaker.base_predictor import PredictorBase +from sagemaker.serve.utils.optimize_utils import _is_s3_uri from sagemaker.session import Session from sagemaker.serve.utils.exceptions import LocalModelInvocationException from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url @@ -101,6 +102,7 @@ def _upload_tensorflow_serving_artifacts( secret_key: str, s3_model_data_url: str = None, image: str = None, + should_upload_artifacts: bool = False, ): """Uploads the model artifacts to S3. @@ -110,23 +112,30 @@ def _upload_tensorflow_serving_artifacts( secret_key: Secret key to use for authentication s3_model_data_url: S3 model data URL image: Image to use + model_data_s3_path: S3 model data URI """ - if s3_model_data_url: - bucket, key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) - - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + s3_upload_path = None + if _is_s3_uri(model_path): + s3_upload_path = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - logger.debug( - "Uploading the model resources to bucket=%s, 
key_prefix=%s.", bucket, code_key_prefix - ) - s3_upload_path = upload(sagemaker_session, model_path, bucket, code_key_prefix) - logger.debug("Model resources uploaded to: %s", s3_upload_path) + logger.debug( + "Uploading the model resources to bucket=%s, key_prefix=%s.", + bucket, + code_key_prefix, + ) + s3_upload_path = upload(sagemaker_session, model_path, bucket, code_key_prefix) + logger.debug("Model resources uploaded to: %s", s3_upload_path) env_vars = { "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", diff --git a/src/sagemaker/serve/model_server/tgi/server.py b/src/sagemaker/serve/model_server/tgi/server.py index 75cf3bd402..4d9686a89c 100644 --- a/src/sagemaker/serve/model_server/tgi/server.py +++ b/src/sagemaker/serve/model_server/tgi/server.py @@ -12,6 +12,7 @@ from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join from sagemaker.s3 import S3Uploader from sagemaker.local.utils import get_docker_host +from sagemaker.serve.utils.optimize_utils import _is_s3_uri MODE_DIR_BINDING = "/opt/ml/model/" _SHM_SIZE = "2G" @@ -111,38 +112,47 @@ def _upload_tgi_artifacts( s3_model_data_url: str = None, image: str = None, env_vars: dict = None, + should_upload_artifacts: bool = False, ): - if s3_model_data_url: - bucket, key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) - - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + model_data_url = None + if _is_s3_uri(model_path): + model_data_url = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, 
key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - code_dir = Path(model_path).joinpath("code") + code_dir = Path(model_path).joinpath("code") - s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") + s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") - logger.debug("Uploading TGI Model Resources uncompressed to: %s", s3_location) + logger.debug("Uploading TGI Model Resources uncompressed to: %s", s3_location) - model_data_url = S3Uploader.upload( - str(code_dir), - s3_location, - None, - sagemaker_session, - ) + model_data_url = S3Uploader.upload( + str(code_dir), + s3_location, + None, + sagemaker_session, + ) - model_data = { - "S3DataSource": { - "CompressionType": "None", - "S3DataType": "S3Prefix", - "S3Uri": model_data_url + "/", + model_data = ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": model_data_url + "/", + } } - } + if model_data_url + else None + ) if jumpstart: return (model_data, {}) return (model_data, _update_env_vars(env_vars)) diff --git a/src/sagemaker/serve/model_server/torchserve/server.py b/src/sagemaker/serve/model_server/torchserve/server.py index 5aef136355..74e37cd70b 100644 --- a/src/sagemaker/serve/model_server/torchserve/server.py +++ b/src/sagemaker/serve/model_server/torchserve/server.py @@ -7,6 +7,7 @@ import platform from pathlib import Path from sagemaker.base_predictor import PredictorBase +from sagemaker.serve.utils.optimize_utils import _is_s3_uri from sagemaker.session import Session from sagemaker.serve.utils.exceptions import LocalModelInvocationException from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url @@ -84,24 +85,31 @@ def _upload_torchserve_artifacts( secret_key: str, s3_model_data_url: str = None, image: str = None, + should_upload_artifacts: bool = False, ): """Tar the model artifact and upload to S3 bucket, then prepare for the environment variables""" - if s3_model_data_url: - bucket, 
key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) - - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + s3_upload_path = None + if _is_s3_uri(model_path): + s3_upload_path = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - logger.debug( - "Uploading the model resources to bucket=%s, key_prefix=%s.", bucket, code_key_prefix - ) - s3_upload_path = upload(sagemaker_session, model_path, bucket, code_key_prefix) - logger.debug("Model resources uploaded to: %s", s3_upload_path) + logger.debug( + "Uploading the model resources to bucket=%s, key_prefix=%s.", + bucket, + code_key_prefix, + ) + s3_upload_path = upload(sagemaker_session, model_path, bucket, code_key_prefix) + logger.debug("Model resources uploaded to: %s", s3_upload_path) env_vars = { "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", diff --git a/src/sagemaker/serve/model_server/triton/server.py b/src/sagemaker/serve/model_server/triton/server.py index 62dfb4759a..e2f3c20d7a 100644 --- a/src/sagemaker/serve/model_server/triton/server.py +++ b/src/sagemaker/serve/model_server/triton/server.py @@ -9,6 +9,7 @@ from sagemaker import fw_utils from sagemaker import Session from sagemaker.base_predictor import PredictorBase +from sagemaker.serve.utils.optimize_utils import _is_s3_uri from sagemaker.serve.utils.uploader import upload from sagemaker.serve.utils.exceptions import LocalModelInvocationException from sagemaker.s3_utils import determine_bucket_and_prefix, 
parse_s3_url @@ -115,25 +116,32 @@ def _upload_triton_artifacts( secret_key: str, s3_model_data_url: str = None, image: str = None, + should_upload_artifacts: bool = False, ): """Tar triton artifacts and upload to s3""" - if s3_model_data_url: - bucket, key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) - - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + s3_upload_path = None + if _is_s3_uri(model_path): + s3_upload_path = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - logger.debug( - "Uploading the model resources to bucket=%s, key_prefix=%s.", bucket, code_key_prefix - ) - model_repository = model_path + "/model_repository" - s3_upload_path = upload(sagemaker_session, model_repository, bucket, code_key_prefix) - logger.debug("Model resources uploaded to: %s", s3_upload_path) + logger.debug( + "Uploading the model resources to bucket=%s, key_prefix=%s.", + bucket, + code_key_prefix, + ) + model_repository = model_path + "/model_repository" + s3_upload_path = upload(sagemaker_session, model_repository, bucket, code_key_prefix) + logger.debug("Model resources uploaded to: %s", s3_upload_path) env_vars = { "SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model", diff --git a/src/sagemaker/serve/model_server/triton/triton_builder.py b/src/sagemaker/serve/model_server/triton/triton_builder.py index ed0ec49204..a19235767f 100644 --- a/src/sagemaker/serve/model_server/triton/triton_builder.py +++ 
b/src/sagemaker/serve/model_server/triton/triton_builder.py @@ -428,6 +428,10 @@ def _create_triton_model(self) -> Type[Model]: self.pysdk_model.mode = self.mode self.pysdk_model.modes = self.modes self.pysdk_model.serve_settings = self.serve_settings + if hasattr(self, "role_arn") and self.role_arn: + self.pysdk_model.role = self.role_arn + if hasattr(self, "sagemaker_session") and self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session # dynamically generate a method to direct model.deploy() logic based on mode # unique method to models created via ModelBuilder() diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py new file mode 100644 index 0000000000..35a937407e --- /dev/null +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -0,0 +1,338 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Holds the util functions used for the optimize function""" +from __future__ import absolute_import + +import re +import logging +from typing import Dict, Any, Optional, Union, List, Tuple + +from sagemaker import Model +from sagemaker.enums import Tag + +logger = logging.getLogger(__name__) + + +SPECULATIVE_DRAFT_MODEL = "/opt/ml/additional-model-data-sources" + + +def _is_inferentia_or_trainium(instance_type: Optional[str]) -> bool: + """Checks whether an instance is compatible with Inferentia. + + Args: + instance_type (str): The instance type used for the compilation job. 
+ + Returns: + bool: Whether the given instance type is Inferentia or Trainium. + """ + if isinstance(instance_type, str): + match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) + if match: + if match[1].startswith("inf") or match[1].startswith("trn"): + return True + return False + + +def _is_image_compatible_with_optimization_job(image_uri: Optional[str]) -> bool: + """Checks whether an instance is compatible with an optimization job. + + Args: + image_uri (str): The image URI of the optimization job. + + Returns: + bool: Whether the given instance type is compatible with an optimization job. + """ + # TODO: Use specific container type instead. + if image_uri is None: + return True + return "djl-inference:" in image_uri and ("-lmi" in image_uri or "-neuronx-" in image_uri) + + +def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) -> Model: + """Generates a new optimization model. + + Args: + pysdk_model (Model): A PySDK model. + optimization_response (dict): The optimization response. + + Returns: + Model: A deployable optimized model. + """ + recommended_image_uri = optimization_response.get("OptimizationOutput", {}).get( + "RecommendedInferenceImage" + ) + s3_uri = optimization_response.get("OutputConfig", {}).get("S3OutputLocation") + deployment_instance_type = optimization_response.get("DeploymentInstanceType") + + if recommended_image_uri: + pysdk_model.image_uri = recommended_image_uri + if s3_uri: + pysdk_model.model_data["S3DataSource"]["S3Uri"] = s3_uri + if deployment_instance_type: + pysdk_model.instance_type = deployment_instance_type + + pysdk_model.add_tags( + {"Key": Tag.OPTIMIZATION_JOB_NAME, "Value": optimization_response["OptimizationJobName"]} + ) + return pysdk_model + + +def _is_optimized(pysdk_model: Model) -> bool: + """Checks whether an optimization model is optimized. + + Args: + pysdk_model (Model): A PySDK model. + + Return: + bool: Whether the given model type is optimized. 
+ """ + optimized_tags = [Tag.OPTIMIZATION_JOB_NAME, Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER] + if hasattr(pysdk_model, "_tags") and pysdk_model._tags: + if isinstance(pysdk_model._tags, dict): + return pysdk_model._tags.get("Key") in optimized_tags + for tag in pysdk_model._tags: + if tag.get("Key") in optimized_tags: + return True + return False + + +def _generate_model_source( + model_data: Optional[Union[Dict[str, Any], str]], accept_eula: Optional[bool] +) -> Optional[Dict[str, Any]]: + """Extracts model source from model data. + + Args: + model_data (Optional[Union[Dict[str, Any], str]]): A model data. + + Returns: + Optional[Dict[str, Any]]: Model source data. + """ + if model_data is None: + raise ValueError("Model Optimization Job only supports model with S3 data source.") + + s3_uri = model_data + if isinstance(s3_uri, dict): + s3_uri = s3_uri.get("S3DataSource").get("S3Uri") + + model_source = {"S3": {"S3Uri": s3_uri}} + if accept_eula: + model_source["S3"]["ModelAccessConfig"] = {"AcceptEula": True} + return model_source + + +def _update_environment_variables( + env: Optional[Dict[str, str]], new_env: Optional[Dict[str, str]] +) -> Optional[Dict[str, str]]: + """Updates environment variables based on environment variables. + + Args: + env (Optional[Dict[str, str]]): The environment variables. + new_env (Optional[Dict[str, str]]): The new environment variables. + + Returns: + Optional[Dict[str, str]]: The updated environment variables. + """ + if new_env: + if env: + env.update(new_env) + else: + env = new_env + return env + + +def _extract_speculative_draft_model_provider( + speculative_decoding_config: Optional[Dict] = None, +) -> Optional[str]: + """Extracts speculative draft model provider from speculative decoding config. + + Args: + speculative_decoding_config (Optional[Dict]): A speculative decoding config. + + Returns: + Optional[str]: The speculative draft model provider. 
+ """ + if speculative_decoding_config is None: + return None + + if speculative_decoding_config.get( + "ModelProvider" + ) == "Custom" or speculative_decoding_config.get("ModelSource"): + return "custom" + + return "sagemaker" + + +def _extracts_and_validates_speculative_model_source( + speculative_decoding_config: Dict, +) -> str: + """Extracts model source from speculative decoding config. + + Args: + speculative_decoding_config (Optional[Dict]): A speculative decoding config. + + Returns: + str: Model source. + + Raises: + ValueError: If model source is none. + """ + model_source: str = speculative_decoding_config.get("ModelSource") + + if not model_source: + raise ValueError("ModelSource must be provided in speculative decoding config.") + return model_source + + +def _generate_channel_name(additional_model_data_sources: Optional[List[Dict]]) -> str: + """Generates a channel name. + + Args: + additional_model_data_sources (Optional[List[Dict]]): The additional model data sources. + + Returns: + str: The channel name. + """ + channel_name = "draft_model" + if additional_model_data_sources and len(additional_model_data_sources) > 0: + channel_name = additional_model_data_sources[0].get("ChannelName", channel_name) + + return channel_name + + +def _generate_additional_model_data_sources( + model_source: str, + channel_name: str, + accept_eula: bool = False, + s3_data_type: Optional[str] = "S3Prefix", + compression_type: Optional[str] = "None", +) -> List[Dict]: + """Generates additional model data sources. + + Args: + model_source (Optional[str]): The model source. + channel_name (Optional[str]): The channel name. + accept_eula (Optional[bool]): Whether to accept eula or not. + s3_data_type (Optional[str]): The S3 data type, defaults to 'S3Prefix'. + compression_type (Optional[str]): The compression type, defaults to None. + + Returns: + List[Dict]: The additional model data sources. 
+ """ + + additional_model_data_source = { + "ChannelName": channel_name, + "S3DataSource": { + "S3Uri": model_source, + "S3DataType": s3_data_type, + "CompressionType": compression_type, + }, + } + if accept_eula: + additional_model_data_source["S3DataSource"]["ModelAccessConfig"] = {"ACCEPT_EULA": True} + + return [additional_model_data_source] + + +def _is_s3_uri(s3_uri: Optional[str]) -> bool: + """Checks whether an S3 URI is valid. + + Args: + s3_uri (Optional[str]): The S3 URI. + + Returns: + bool: Whether the S3 URI is valid. + """ + if s3_uri is None: + return False + + return re.match("^s3://([^/]+)/?(.*)$", s3_uri) is not None + + +def _extract_optimization_config_and_env( + quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None +) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]: + """Extracts optimization config and environment variables. + + Args: + quantization_config (Optional[Dict]): The quantization config. + compilation_config (Optional[Dict]): The compilation config. + + Returns: + Optional[Tuple[Optional[Dict], Optional[Dict]]]: + The optimization config and environment variables. + """ + if quantization_config: + return {"ModelQuantizationConfig": quantization_config}, quantization_config.get( + "OverrideEnvironment" + ) + if compilation_config: + return {"ModelCompilationConfig": compilation_config}, compilation_config.get( + "OverrideEnvironment" + ) + return None, None + + +def _normalize_local_model_path(local_model_path: Optional[str]) -> Optional[str]: + """Normalizes the local model path. + + Args: + local_model_path (Optional[str]): The local model path. + + Returns: + Optional[str]: The normalized model path. + """ + if local_model_path is None: + return local_model_path + + # Removes /code or /code/ path at the end of local_model_path, + # as it is appended during artifacts upload. 
+ pattern = r"/code/?$" + if re.search(pattern, local_model_path): + return re.sub(pattern, "", local_model_path) + return local_model_path + + +def _custom_speculative_decoding( + model: Model, + speculative_decoding_config: Optional[Dict], + accept_eula: Optional[bool] = False, +) -> Model: + """Modifies the given model for speculative decoding config with custom provider. + + Args: + model (Model): The model. + speculative_decoding_config (Optional[Dict]): The speculative decoding config. + accept_eula (Optional[bool]): Whether to accept eula or not. + """ + + if speculative_decoding_config: + additional_model_source = _extracts_and_validates_speculative_model_source( + speculative_decoding_config + ) + + if _is_s3_uri(additional_model_source): + channel_name = _generate_channel_name(model.additional_model_data_sources) + speculative_draft_model = f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}" + + model.additional_model_data_sources = _generate_additional_model_data_sources( + additional_model_source, channel_name, accept_eula + ) + else: + speculative_draft_model = additional_model_source + + model.env.update({"OPTION_SPECULATIVE_DRAFT_MODEL": speculative_draft_model}) + model.add_tags( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "custom"}, + ) + + return model diff --git a/src/sagemaker/serve/utils/telemetry_logger.py b/src/sagemaker/serve/utils/telemetry_logger.py index 99aeb4ff26..0ea6ec3f26 100644 --- a/src/sagemaker/serve/utils/telemetry_logger.py +++ b/src/sagemaker/serve/utils/telemetry_logger.py @@ -29,7 +29,12 @@ MLFLOW_REGISTRY_PATH, ) from sagemaker.serve.utils.lineage_utils import _get_mlflow_model_path_type -from sagemaker.serve.utils.types import ModelServer, ImageUriOption +from sagemaker.serve.utils.types import ( + ModelServer, + ImageUriOption, + ModelHub, + SpeculativeDecodingDraftModelSource, +) from sagemaker.serve.validations.check_image_uri import is_1p_image_uri from sagemaker.user_agent import SDK_VERSION @@ -69,6 +74,16 @@ 
MLFLOW_REGISTRY_PATH: 5, } +MODEL_HUB_TO_CODE = { + str(ModelHub.JUMPSTART): 1, + str(ModelHub.HUGGINGFACE): 2, +} + +SD_DRAFT_MODEL_SOURCE_TO_CODE = { + str(SpeculativeDecodingDraftModelSource.SAGEMAKER): 1, + str(SpeculativeDecodingDraftModelSource.CUSTOM): 2, +} + def _capture_telemetry(func_name: str): """Placeholder docstring""" @@ -79,16 +94,46 @@ def wrapper(self, *args, **kwargs): logger.info(TELEMETRY_OPT_OUT_MESSAGING) response = None caught_ex = None + status = "1" + failure_reason = None + failure_type = None + extra = f"{func_name}" + + start_timer = perf_counter() + try: + response = func(self, *args, **kwargs) + except ( + ModelBuilderException, + exceptions.CapacityError, + exceptions.UnexpectedStatusException, + exceptions.AsyncInferenceError, + ) as e: + status = "0" + caught_ex = e + failure_reason = str(e) + failure_type = e.__class__.__name__ + except Exception as e: # pylint: disable=W0703 + raise e - image_uri_tail = self.image_uri.split("/")[1] - image_uri_option = _get_image_uri_option(self.image_uri, self._is_custom_image_uri) - extra = ( - f"{func_name}" - f"&x-modelServer={MODEL_SERVER_TO_CODE[str(self.model_server)]}" - f"&x-imageTag={image_uri_tail}" - f"&x-sdkVersion={SDK_VERSION}" - f"&x-defaultImageUsage={image_uri_option}" - ) + stop_timer = perf_counter() + elapsed = stop_timer - start_timer + + if self.model_server: + extra += f"&x-modelServer={MODEL_SERVER_TO_CODE[str(self.model_server)]}" + + if self.image_uri: + image_uri_tail = self.image_uri.split("/")[1] + image_uri_option = _get_image_uri_option( + self.image_uri, getattr(self, "_is_custom_image_uri", False) + ) + + if self.image_uri: + extra += f"&x-imageTag={image_uri_tail}" + + extra += f"&x-sdkVersion={SDK_VERSION}" + + if self.image_uri: + extra += f"&x-defaultImageUsage={image_uri_option}" if self.model_server == ModelServer.DJL_SERVING or self.model_server == ModelServer.TGI: extra += f"&x-modelName={self.model}" @@ -101,46 +146,41 @@ def wrapper(self, *args, 
**kwargs): mlflow_model_path_type = _get_mlflow_model_path_type(mlflow_model_path) extra += f"&x-mlflowModelPathType={MLFLOW_MODEL_PATH_CODE[mlflow_model_path_type]}" - start_timer = perf_counter() - try: - response = func(self, *args, **kwargs) - stop_timer = perf_counter() - elapsed = stop_timer - start_timer - extra += f"&x-latency={round(elapsed, 2)}" - if not self.serve_settings.telemetry_opt_out: - _send_telemetry( - "1", - MODE_TO_CODE[str(self.mode)], - self.sagemaker_session, - None, - None, - extra, - ) - except ( - ModelBuilderException, - exceptions.CapacityError, - exceptions.UnexpectedStatusException, - exceptions.AsyncInferenceError, - ) as e: - stop_timer = perf_counter() - elapsed = stop_timer - start_timer - extra += f"&x-latency={round(elapsed, 2)}" - if not self.serve_settings.telemetry_opt_out: - _send_telemetry( - "0", - MODE_TO_CODE[str(self.mode)], - self.sagemaker_session, - str(e), - e.__class__.__name__, - extra, - ) - caught_ex = e - except Exception as e: # pylint: disable=W0703 - caught_ex = e - finally: - if caught_ex: - raise caught_ex - return response # pylint: disable=W0150 + if getattr(self, "model_hub", False): + extra += f"&x-modelHub={MODEL_HUB_TO_CODE[str(self.model_hub)]}" + + if getattr(self, "is_fine_tuned", False): + extra += "&x-fineTuned=1" + + if getattr(self, "is_compiled", False): + extra += "&x-compiled=1" + if getattr(self, "is_quantized", False): + extra += "&x-quantized=1" + if getattr(self, "speculative_decoding_draft_model_source", False): + model_provider_enum = ( + SpeculativeDecodingDraftModelSource.SAGEMAKER + if self.speculative_decoding_draft_model_source == "sagemaker" + else SpeculativeDecodingDraftModelSource.CUSTOM + ) + model_provider_value = SD_DRAFT_MODEL_SOURCE_TO_CODE[str(model_provider_enum)] + extra += f"&x-sdDraftModelSource={model_provider_value}" + + extra += f"&x-latency={round(elapsed, 2)}" + + if not self.serve_settings.telemetry_opt_out: + _send_telemetry( + status, + 
MODE_TO_CODE[str(self.mode)], + self.sagemaker_session, + failure_reason, + failure_type, + extra, + ) + + if caught_ex: + raise caught_ex + + return response return wrapper diff --git a/src/sagemaker/serve/utils/types.py b/src/sagemaker/serve/utils/types.py index adb9fb57e8..e50be62440 100644 --- a/src/sagemaker/serve/utils/types.py +++ b/src/sagemaker/serve/utils/types.py @@ -45,3 +45,25 @@ def __str__(self) -> str: CUSTOM_IMAGE = 1 CUSTOM_1P_IMAGE = 2 DEFAULT_IMAGE = 3 + + +class ModelHub(Enum): + """Enum type for model hub source""" + + def __str__(self) -> str: + """Convert enum to string""" + return str(self.name) + + JUMPSTART = 1 + HUGGINGFACE = 2 + + +class SpeculativeDecodingDraftModelSource(Enum): + """Enum type for speculative decoding draft model source""" + + def __str__(self) -> str: + """Convert enum to string""" + return str(self.name) + + SAGEMAKER = 1 + CUSTOM = 2 diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index cd36cc739c..4209f37ae6 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -680,7 +680,6 @@ def general_bucket_check_if_user_has_permission( s3 (str): S3 object from boto session region (str): The region in which to create the bucket. bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not - """ try: s3.meta.client.head_bucket(Bucket=bucket_name) @@ -2626,6 +2625,24 @@ def wait_for_auto_ml_job(self, job, poll=5): _check_job_status(job, desc, "AutoMLJobStatus") return desc + def wait_for_optimization_job(self, job, poll=5): + """Wait for an Amazon SageMaker Optimization job to complete. + + Args: + job (str): Name of optimization job to wait for. + poll (int): Polling interval in seconds (default: 5). + + Returns: + (dict): Return value from the ``DescribeOptimizationJob`` API. + + Raises: + exceptions.ResourceNotFound: If optimization job fails with CapacityError. + exceptions.UnexpectedStatusException: If optimization job fails. 
+ """ + desc = _wait_until(lambda: _optimization_job_status(self.sagemaker_client, job), poll) + _check_job_status(job, desc, "OptimizationJobStatus") + return desc + def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this method self, job_name, wait=False, poll=10 ): @@ -7510,6 +7527,7 @@ def container_def( container_mode=None, image_config=None, accept_eula=None, + additional_model_data_sources=None, model_reference_arn=None, ): """Create a definition for executing a container as part of a SageMaker model. @@ -7533,6 +7551,8 @@ def container_def( The `accept_eula` value must be explicitly defined as `True` in order to accept the end-user license agreement (EULA) that some models require. (Default: None). + additional_model_data_sources (PipelineVariable or dict): Additional location + of SageMaker model data (default: None). Returns: dict[str, str]: A complete container definition object usable with the CreateModel API if @@ -7542,6 +7562,9 @@ def container_def( env = {} c_def = {"Image": image_uri, "Environment": env} + if additional_model_data_sources: + c_def["AdditionalModelDataSources"] = additional_model_data_sources + if isinstance(model_data_url, str) and ( not (model_data_url.startswith("s3://") and model_data_url.endswith("tar.gz")) or accept_eula is None @@ -7982,6 +8005,31 @@ def _auto_ml_job_status(sagemaker_client, job_name): return desc +def _optimization_job_status(sagemaker_client, job_name): + """Placeholder docstring""" + optimization_job_status_codes = { + "INPROGRESS": ".", + "COMPLETED": "!", + "FAILED": "*", + "STARTING": ".", + "STOPPING": "_", + "STOPPED": "s", + } + in_progress_statuses = ["INPROGRESS", "STARTING", "STOPPING"] + + desc = sagemaker_client.describe_optimization_job(OptimizationJobName=job_name) + status = desc["OptimizationJobStatus"] + + print(optimization_job_status_codes.get(status, "?"), end="") + sys.stdout.flush() + + if status in in_progress_statuses: + return None + + print("") + return desc + 
+ def _create_model_package_status(sagemaker_client, model_package_name): """Placeholder docstring""" in_progress_statuses = ["InProgress", "Pending"] diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 430effefa3..45509f65f6 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -25,6 +25,7 @@ import tarfile import tempfile import time +from functools import lru_cache from typing import Union, Any, List, Optional, Dict import json import abc @@ -33,10 +34,12 @@ from os.path import abspath, realpath, dirname, normpath, join as joinpath from importlib import import_module + +import boto3 import botocore from botocore.utils import merge_dicts from six.moves.urllib import parse -import pandas as pd +from six import viewitems from sagemaker import deprecations from sagemaker.config import validate_sagemaker_config @@ -1451,10 +1454,15 @@ def volume_size_supported(instance_type: str) -> bool: if len(parts) != 2: raise ValueError(f"Failed to parse instance type '{instance_type}'") - # Any instance type with a "d" in the instance family (i.e. c5d, p4d, etc) + g5 - # does not support attaching an EBS volume. + # Any instance type with a "d" in the instance family (i.e. c5d, p4d, etc) + # + g5 or g6 or p5 does not support attaching an EBS volume. family = parts[0] - return "d" not in family and not family.startswith("g5") + return ( + "d" not in family + and not family.startswith("g5") + and not family.startswith("g6") + and not family.startswith("p5") + ) except Exception as e: raise ValueError(f"Failed to parse instance type '{instance_type}': {str(e)}") @@ -1603,44 +1611,78 @@ def can_model_package_source_uri_autopopulate(source_uri: str): ) -def flatten_dict(source_dict: Dict[str, Any], sep: str = ".") -> Dict[str, Any]: - """Flatten a nested dictionary. +def flatten_dict( + d: Dict[str, Any], + max_flatten_depth=None, +) -> Dict[str, Any]: + """Flatten a dictionary object. - Args: - source_dict (dict): The dictionary to be flattened. 
- sep (str): The separator to be used in the flattened dictionary. - Returns: - transformed_dict: The flattened dictionary. + d (Dict[str, Any]): + The dict that will be flattened. + max_flatten_depth (Optional[int]): + Maximum depth to merge. """ - flat_dict_list = pd.json_normalize(source_dict, sep=sep).to_dict(orient="records") - if flat_dict_list: - return flat_dict_list[0] - return {} + def tuple_reducer(k1, k2): + if k1 is None: + return (k2,) + return k1 + (k2,) -def unflatten_dict(source_dict: Dict[str, Any], sep: str = ".") -> Dict[str, Any]: - """Unflatten a flattened dictionary back into a nested dictionary. + # check max_flatten_depth + if max_flatten_depth is not None and max_flatten_depth < 1: + raise ValueError("max_flatten_depth should not be less than 1.") - Args: - source_dict (dict): The input flattened dictionary. - sep (str): The separator used in the flattened keys. + reducer = tuple_reducer - Returns: - transformed_dict: The reconstructed nested dictionary. + flat_dict = {} + + def _flatten(_d, depth, parent=None): + key_value_iterable = viewitems(_d) + has_item = False + for key, value in key_value_iterable: + has_item = True + flat_key = reducer(parent, key) + if isinstance(value, dict) and (max_flatten_depth is None or depth < max_flatten_depth): + has_child = _flatten(value, depth=depth + 1, parent=flat_key) + if has_child: + continue + + if flat_key in flat_dict: + raise ValueError("duplicated key '{}'".format(flat_key)) + flat_dict[flat_key] = value + + return has_item + + _flatten(d, depth=1) + return flat_dict + + +def nested_set_dict(d: Dict[str, Any], keys: List[str], value: Any) -> None: + """Set a value to a sequence of nested keys.""" + + key = keys[0] + + if len(keys) == 1: + d[key] = value + return + + d = d.setdefault(key, {}) + nested_set_dict(d, keys[1:], value) + + +def unflatten_dict(d: Dict[str, Any]) -> Dict[str, Any]: + """Unflatten dict-like object. + + d (Dict[str, Any]) : + The dict that will be unflattened. 
""" - if not source_dict: - return {} - result = {} - for key, value in source_dict.items(): - keys = key.split(sep) - current = result - for k in keys[:-1]: - if k not in current: - current[k] = {} - current = current[k] if current[k] is not None else current - current[keys[-1]] = value - return result + unflattened_dict = {} + for flat_key, value in viewitems(d): + key_tuple = flat_key + nested_set_dict(unflattened_dict, key_tuple, value) + + return unflattened_dict def deep_override_dict( @@ -1651,6 +1693,7 @@ def deep_override_dict( skip_keys = [] flattened_dict1 = flatten_dict(dict1) + flattened_dict1 = {key: value for key, value in flattened_dict1.items() if value is not None} flattened_dict2 = flatten_dict( {key: value for key, value in dict2.items() if key not in skip_keys} ) @@ -1686,3 +1729,179 @@ def _resolve_routing_config(routing_config: Optional[Dict[str, Any]]) -> Optiona "or RoutingStrategy.LEAST_OUTSTANDING_REQUESTS" ) return None + + +@lru_cache +def get_instance_rate_per_hour( + instance_type: str, + region: str, +) -> Optional[Dict[str, str]]: + """Gets instance rate per hour for the given instance type. + + Args: + instance_type (str): The instance type. + region (str): The region. + Returns: + Optional[Dict[str, str]]: Instance rate per hour. + Example: {'name': 'Instance Rate', 'unit': 'USD/Hrs', 'value': '1.125'}. + + Raises: + Exception: An exception is raised if + the IAM role is not authorized to perform pricing:GetProducts. + or unexpected event happened. 
+ """ + region_name = "us-east-1" + if region.startswith("eu") or region.startswith("af"): + region_name = "eu-central-1" + elif region.startswith("ap") or region.startswith("cn"): + region_name = "ap-south-1" + + pricing_client: boto3.client = boto3.client("pricing", region_name=region_name) + res = pricing_client.get_products( + ServiceCode="AmazonSageMaker", + Filters=[ + {"Type": "TERM_MATCH", "Field": "instanceName", "Value": instance_type}, + {"Type": "TERM_MATCH", "Field": "locationType", "Value": "AWS Region"}, + {"Type": "TERM_MATCH", "Field": "regionCode", "Value": region}, + ], + ) + + price_list = res.get("PriceList", []) + if len(price_list) > 0: + price_data = price_list[0] + if isinstance(price_data, str): + price_data = json.loads(price_data) + + instance_rate_per_hour = extract_instance_rate_per_hour(price_data) + if instance_rate_per_hour is not None: + return instance_rate_per_hour + raise Exception(f"Unable to get instance rate per hour for instance type: {instance_type}.") + + +def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Optional[Dict[str, str]]: + """Extract instance rate per hour for the given Price JSON data. + + Args: + price_data (Dict[str, Any]): The Price JSON data. + Returns: + Optional[Dict[str, str], None]: Instance rate per hour. + """ + + if price_data is not None: + price_dimensions = price_data.get("terms", {}).get("OnDemand", {}).values() + for dimension in price_dimensions: + for price in dimension.get("priceDimensions", {}).values(): + for currency in price.get("pricePerUnit", {}).keys(): + value = price.get("pricePerUnit", {}).get(currency) + if value is not None: + value = str(round(float(value), 3)) + return { + "unit": f"{currency}/Hr", + "value": value, + "name": "On-demand Instance Rate", + } + return None + + +def camel_case_to_pascal_case(data: Dict[str, Any]) -> Dict[str, Any]: + """Iteratively updates a dictionary to convert all keys from snake_case to PascalCase. 
+ + Args: + data (dict): The dictionary to be updated. + + Returns: + dict: The updated dictionary with keys in PascalCase. + """ + result = {} + + def convert_key(key): + """Converts a snake_case key to PascalCase.""" + return "".join(part.capitalize() for part in key.split("_")) + + def convert_value(value): + """Recursively processes the value of a key-value pair.""" + if isinstance(value, dict): + return camel_case_to_pascal_case(value) + if isinstance(value, list): + return [convert_value(item) for item in value] + + return value + + for key, value in data.items(): + result[convert_key(key)] = convert_value(value) + + return result + + +def tag_exists(tag: TagsDict, curr_tags: Optional[Tags]) -> bool: + """Returns True if ``tag`` already exists. + + Args: + tag (TagsDict): The tag dictionary. + curr_tags (Optional[Tags]): The current tags. + + Returns: + bool: True if the tag exists. + """ + if curr_tags is None: + return False + + for curr_tag in curr_tags: + if tag["Key"] == curr_tag["Key"]: + return True + + return False + + +def _validate_new_tags(new_tags: Optional[Tags], curr_tags: Optional[Tags]) -> Optional[Tags]: + """Validates new tags against existing tags. + + Args: + new_tags (Optional[Tags]): The new tags. + curr_tags (Optional[Tags]): The current tags. + + Returns: + Optional[Tags]: The updated tags. + """ + if curr_tags is None: + return new_tags + + if curr_tags and isinstance(curr_tags, dict): + curr_tags = [curr_tags] + + if isinstance(new_tags, dict): + if not tag_exists(new_tags, curr_tags): + curr_tags.append(new_tags) + elif isinstance(new_tags, list): + for new_tag in new_tags: + if not tag_exists(new_tag, curr_tags): + curr_tags.append(new_tag) + + return curr_tags + + +def remove_tag_with_key(key: str, tags: Optional[Tags]) -> Optional[Tags]: + """Remove a tag with the given key from the list of tags. + + Args: + key (str): The key of the tag to remove. + tags (Optional[Tags]): The current list of tags. 
+ + Returns: + Optional[Tags]: The updated list of tags with the tag removed. + """ + if tags is None: + return tags + if isinstance(tags, dict): + tags = [tags] + + updated_tags = [] + for tag in tags: + if tag["Key"] != key: + updated_tags.append(tag) + + if not updated_tags: + return None + if len(updated_tags) == 1: + return updated_tags[0] + return updated_tags diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 6bc0a5c996..5ee0abd41f 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -11,7 +11,10 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import + +import io import os +import sys import time from unittest import mock @@ -349,3 +352,47 @@ def test_register_gated_jumpstart_model(setup): predictor.delete_predictor() assert response is not None + + +@pytest.mark.skipif( + True, + reason="Only enable after metadata is fully deployed.", +) +def test_jumpstart_model_with_deployment_configs(setup): + model_id = "meta-textgeneration-llama-2-13b" + + model = JumpStartModel( + model_id=model_id, + model_version="*", + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + captured_output = io.StringIO() + sys.stdout = captured_output + model.display_benchmark_metrics() + sys.stdout = sys.__stdout__ + assert captured_output.getvalue() is not None + + configs = model.list_deployment_configs() + assert len(configs) > 0 + + model.set_deployment_config( + configs[0]["ConfigName"], + "ml.g5.2xlarge", + ) + assert model.config_name == configs[0]["ConfigName"] + + predictor = model.deploy( + accept_eula=True, + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + payload = { + 
"inputs": "some-payload", + "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6}, + } + + response = predictor.predict(payload, custom_attributes="accept_eula=true") + + assert response is not None diff --git a/tests/integ/sagemaker/serve/test_serve_js_happy.py b/tests/integ/sagemaker/serve/test_serve_js_happy.py index ad0527fcc0..807a5ad691 100644 --- a/tests/integ/sagemaker/serve/test_serve_js_happy.py +++ b/tests/integ/sagemaker/serve/test_serve_js_happy.py @@ -12,6 +12,9 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import io +import sys + import pytest from sagemaker.serve.builder.model_builder import ModelBuilder @@ -54,6 +57,19 @@ def happy_model_builder(sagemaker_session): ) +@pytest.fixture +def meta_textgeneration_llama_2_7b_f_schema(): + prompt = "Hello, I'm a language model," + response = "Hello, I'm a language model, and I'm here to help you with your English." + sample_input = {"inputs": prompt} + sample_output = [{"generated_text": response}] + + return SchemaBuilder( + sample_input=sample_input, + sample_output=sample_output, + ) + + @pytest.fixture def happy_mms_model_builder(sagemaker_session): iam_client = sagemaker_session.boto_session.client("iam") @@ -125,3 +141,59 @@ def test_happy_mms_sagemaker_endpoint(happy_mms_model_builder, gpu_instance_type ) if caught_ex: raise caught_ex + + +@pytest.mark.skipif( + True, + reason="Only enable after metadata is fully deployed.", +) +def test_js_model_with_deployment_configs( + meta_textgeneration_llama_2_7b_f_schema, + sagemaker_session, +): + logger.info("Running in SAGEMAKER_ENDPOINT mode...") + caught_ex = None + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-2-13b", + schema_builder=meta_textgeneration_llama_2_7b_f_schema, + ) + configs = 
model_builder.list_deployment_configs() + + assert len(configs) > 0 + + captured_output = io.StringIO() + sys.stdout = captured_output + model_builder.display_benchmark_metrics() + sys.stdout = sys.__stdout__ + assert captured_output.getvalue() is not None + + model_builder.set_deployment_config( + configs[0]["ConfigName"], + "ml.g5.2xlarge", + ) + model = model_builder.build(role_arn=role_arn, sagemaker_session=sagemaker_session) + assert model.config_name == configs[0]["ConfigName"] + assert model_builder.get_deployment_config() is not None + + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + try: + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") + predictor = model.deploy(accept_eula=True) + logger.info("Endpoint successfully deployed.") + + updated_sample_input = model_builder.schema_builder.sample_input + + predictor.predict(updated_sample_input) + except Exception as e: + caught_ex = e + finally: + cleanup_model_resources( + sagemaker_session=sagemaker_session, + model_name=model.name, + endpoint_name=model.endpoint_name, + ) + if caught_ex: + raise caught_ex diff --git a/tests/unit/sagemaker/huggingface/test_llm_utils.py b/tests/unit/sagemaker/huggingface/test_llm_utils.py index 3c4cdde3f6..675a6fd885 100644 --- a/tests/unit/sagemaker/huggingface/test_llm_utils.py +++ b/tests/unit/sagemaker/huggingface/test_llm_utils.py @@ -15,7 +15,10 @@ from unittest import TestCase from urllib.error import HTTPError from unittest.mock import Mock, patch -from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata +from sagemaker.huggingface.llm_utils import ( + get_huggingface_model_metadata, + download_huggingface_model_metadata, +) MOCK_HF_ID = "mock_hf_id" MOCK_HF_HUB_TOKEN = "mock_hf_hub_token" @@ -74,3 +77,25 @@ def test_huggingface_model_metadata_general_exception(self, mock_urllib): f"Did not find model metadata for the following HuggingFace Model ID {MOCK_HF_ID}" ) self.assertEquals(expected_error_msg, 
str(context.exception)) + + @patch("huggingface_hub.snapshot_download") + def test_download_huggingface_model_metadata(self, mock_snapshot_download): + mock_snapshot_download.side_effect = None + + download_huggingface_model_metadata(MOCK_HF_ID, "local_path", MOCK_HF_HUB_TOKEN) + + mock_snapshot_download.assert_called_once_with( + repo_id=MOCK_HF_ID, local_dir="local_path", token=MOCK_HF_HUB_TOKEN + ) + + @patch("importlib.util.find_spec") + def test_download_huggingface_model_metadata_ex(self, mock_find_spec): + mock_find_spec.side_effect = lambda *args, **kwargs: False + + self.assertRaisesRegex( + ImportError, + "Unable to import huggingface_hub, check if huggingface_hub is installed", + lambda: download_huggingface_model_metadata( + MOCK_HF_ID, "local_path", MOCK_HF_HUB_TOKEN + ), + ) diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 6298a06db2..9117b2d26d 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -8294,7 +8294,7 @@ "training_model_package_artifact_uris": None, "deprecate_warn_message": None, "deprecated_message": None, - "hosting_model_package_arns": None, + "hosting_model_package_arns": {}, "hosting_eula_key": None, "model_subscription_link": None, "hyperparameters": [ @@ -8452,6 +8452,40 @@ "training_config_components": None, "inference_config_rankings": None, "training_config_rankings": None, + "hosting_additional_data_sources": None, + "hosting_neuron_model_id": None, + "hosting_neuron_model_version": None, +} + +BASE_HOSTING_ADDITIONAL_DATA_SOURCES = { + "hosting_additional_data_sources": { + "speculative_decoding": [ + { + "channel_name": "speculative_decoding_channel", + "artifact_version": "version", + "s3_data_source": { + "compression_type": "None", + "s3_data_type": "S3Prefix", + "s3_uri": "s3://bucket/path1", + "hub_access_config": None, + "model_access_config": None, + }, + } + ], + "scripts": [ + { + "channel_name": 
"scripts_channel", + "artifact_version": "version", + "s3_data_source": { + "compression_type": "None", + "s3_data_type": "S3Prefix", + "s3_uri": "s3://bucket/path1", + "hub_access_config": None, + "model_access_config": None, + }, + } + ], + }, } BASE_HEADER = { @@ -8599,28 +8633,52 @@ "inference_configs": { "neuron-inference": { "benchmark_metrics": { - "ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + "ml.inf2.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] }, "component_names": ["neuron-inference"], }, "neuron-inference-budget": { "benchmark_metrics": { - "ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + "ml.inf2.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] }, "component_names": ["neuron-base"], }, "gpu-inference-budget": { "benchmark_metrics": { - "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] }, "component_names": ["gpu-inference-budget"], }, "gpu-inference": { "benchmark_metrics": { - "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] }, "component_names": ["gpu-inference"], }, + "gpu-inference-model-package": { + "benchmark_metrics": { + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] + }, + "component_names": ["gpu-inference-model-package"], + }, + "gpu-accelerated": { + "benchmark_metrics": { + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] + }, + "component_names": ["gpu-accelerated"], + }, }, "inference_config_components": { "neuron-base": { @@ -8639,13 +8697,23 @@ "regional_aliases": { "us-west-2": { "neuron-ecr-uri": 
"763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" } }, "variants": {"inf2": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, }, }, - "neuron-budget": {"inference_environment_variables": {"BUDGET": "1234"}}, + "neuron-budget": { + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + } + ], + }, "gpu-inference": { "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference/model/", @@ -8653,7 +8721,7 @@ "regional_aliases": { "us-west-2": { "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" } }, "variants": { @@ -8662,9 +8730,32 @@ }, }, }, + "gpu-inference-model-package": { + "default_inference_instance_type": "ml.p2.xlarge", + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_model_package_arns": { + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/ll" + "ama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + }, + }, "gpu-inference-budget": { "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference-budget/model/", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting:1.13.1-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, + }, + 
"gpu-accelerated": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], "hosting_instance_type_variants": { "regional_aliases": { "us-west-2": { @@ -8677,6 +8768,20 @@ "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, }, }, + "hosting_additional_data_sources": { + "speculative_decoding": [ + { + "channel_name": "draft_model_name", + "artifact_version": "1.2.1", + "s3_data_source": { + "compression_type": "None", + "model_access_config": {"accept_eula": False}, + "s3_data_type": "S3Prefix", + "s3_uri": "key/to/draft/model/artifact/", + }, + } + ], + }, }, }, } @@ -8685,35 +8790,70 @@ "training_configs": { "neuron-training": { "benchmark_metrics": { - "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], - "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], + "ml.tr1n1.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ], + "ml.tr1n1.4xlarge": [ + {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1} + ], }, "component_names": ["neuron-training"], + "default_inference_config": "neuron-inference", + "default_incremental_training_config": "neuron-training", + "supported_inference_configs": ["neuron-inference", "neuron-inference-budget"], + "supported_incremental_training_configs": ["neuron-training", "neuron-training-budget"], }, "neuron-training-budget": { "benchmark_metrics": { - "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], - "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], + "ml.tr1n1.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ], + "ml.tr1n1.4xlarge": [ + {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1} + ], }, "component_names": ["neuron-training-budget"], + "default_inference_config": "neuron-inference-budget", + "default_incremental_training_config": "neuron-training-budget", + 
"supported_inference_configs": ["neuron-inference", "neuron-inference-budget"], + "supported_incremental_training_configs": ["neuron-training", "neuron-training-budget"], }, "gpu-training": { "benchmark_metrics": { - "ml.p3.2xlarge": [{"name": "Latency", "value": "200", "unit": "Tokens/S"}], + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "200", "unit": "Tokens/S", "concurrency": "1"} + ], }, "component_names": ["gpu-training"], + "default_inference_config": "gpu-inference", + "default_incremental_training_config": "gpu-training", + "supported_inference_configs": ["gpu-inference", "gpu-inference-budget"], + "supported_incremental_training_configs": ["gpu-training", "gpu-training-budget"], }, "gpu-training-budget": { "benchmark_metrics": { - "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": "1"} + ] }, "component_names": ["gpu-training-budget"], + "default_inference_config": "gpu-inference-budget", + "default_incremental_training_config": "gpu-training-budget", + "supported_inference_configs": ["gpu-inference", "gpu-inference-budget"], + "supported_incremental_training_configs": ["gpu-training", "gpu-training-budget"], }, }, "training_config_components": { "neuron-training": { + "default_training_instance_type": "ml.trn1.2xlarge", "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, "training_instance_type_variants": { "regional_aliases": { "us-west-2": { @@ -8725,13 +8865,14 @@ }, }, "gpu-training": { + "default_training_instance_type": "ml.p2.xlarge", "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], "training_artifact_key": 
"artifacts/meta-textgeneration-llama-2-7b/gpu-training/model/", "training_instance_type_variants": { "regional_aliases": { "us-west-2": { "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + "huggingface-pytorch-training:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" } }, "variants": { @@ -8741,6 +8882,7 @@ }, }, "neuron-training-budget": { + "default_training_instance_type": "ml.trn1.2xlarge", "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training-budget/model/", "training_instance_type_variants": { @@ -8754,13 +8896,14 @@ }, }, "gpu-training-budget": { + "default_training_instance_type": "ml.p2.xlarge", "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training-budget/model/", "training_instance_type_variants": { "regional_aliases": { "us-west-2": { "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + "pytorch-training:1.13.1-py310-sdk2.14.1-ubuntu20.04" } }, "variants": { @@ -8782,6 +8925,7 @@ "neuron-inference-budget", "gpu-inference", "gpu-inference-budget", + "gpu-accelerated", ], }, "performance": { @@ -8845,6 +8989,173 @@ } } + +DEPLOYMENT_CONFIGS = [ + { + "DeploymentConfigName": "neuron-inference", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": 
"3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.p2.xlarge", + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + }, + "AccelerationConfigs": None, + "BenchmarkMetrics": [ + {"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1} + ], + }, + { + "DeploymentConfigName": "neuron-inference-budget", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.p2.xlarge", + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + }, + "AccelerationConfigs": None, + "BenchmarkMetrics": [ + {"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1} + ], + }, + { + "DeploymentConfigName": "gpu-inference-budget", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": 
"s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.p2.xlarge", + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + }, + "AccelerationConfigs": None, + "BenchmarkMetrics": [ + {"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1} + ], + }, + { + "DeploymentConfigName": "gpu-inference", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.p2.xlarge", + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + }, + "AccelerationConfigs": None, + "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], + }, +] + + 
+INIT_KWARGS = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu" + "-py310-cu121-ubuntu20.04", + "model_data": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface-textgeneration" + "-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "instance_type": "ml.p2.xlarge", + "env": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "role": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20230707T131628", + "name": "hf-textgeneration-bloom-1b1-2024-04-22-20-23-48-799", + "enable_network_isolation": True, +} + HUB_MODEL_DOCUMENT_DICTS = { "huggingface-llm-gemma-2b-instruct": { "Url": "https://huggingface.co/google/gemma-2b-it", @@ -9494,6 +9805,60 @@ "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, }, }, + "InferenceConfigRankings": { + "overall": {"Description": "default", "Rankings": ["variant1"]} + }, + "InferenceConfigs": { + "variant1": { + "ComponentNames": ["variant1"], + "BenchmarkMetrics": { + "ml.g5.12xlarge": [ + {"Name": "latency", "Unit": "sec", "Value": "0.19", "Concurrency": "1"}, + ] + }, + }, + }, + "InferenceConfigComponents": { + "variant1": { + "HostingEcrUri": "123456789012.ecr.us-west-2.amazon.com/repository", + "HostingArtifactUri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration-llama-2-7b/artifacts/variant1/v1.0.0/", # noqa: E501 + "HostingScriptUri": 
"s3://jumpstart-monarch-test-hub-bucket/monarch-curated-hub-1714579993.88695/curated_models/meta-textgeneration-llama-2-7b/4.0.0/source-directory-tarballs/meta/inference/textgeneration/v1.2.3/sourcedir.tar.gz", # noqa: E501 + "InferenceDependencies": [], + "InferenceEnvironmentVariables": [ + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + "RequiredForModelClass": True, + } + ], + "HostingAdditionalDataSources": { + "speculative_decoding": [ + { + "ArtifactVersion": 1, + "ChannelName": "speculative_decoding_channel_1", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/path/1", + }, + }, + { + "ArtifactVersion": 1, + "ChannelName": "speculative_decoding_channel_2", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/path/2", + }, + }, + ] + }, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, "ContextualHelp": { "HubFormatTrainData": [ "A train and an optional validation directories. Each directory contains a CSV/JSON/TXT. 
", diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 062209e3a0..3678685db5 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -46,6 +46,8 @@ from sagemaker.model import Model from sagemaker.predictor import Predictor from tests.unit.sagemaker.jumpstart.utils import ( + get_prototype_manifest, + get_prototype_spec_with_configs, get_special_model_spec, overwrite_dictionary, ) @@ -700,7 +702,6 @@ def test_estimator_use_kwargs(self): "input_mode": "File", "output_path": "Optional[Union[str, PipelineVariable]] = None", "output_kms_key": "Optional[Union[str, PipelineVariable]] = None", - "base_job_name": "Optional[str] = None", "sagemaker_session": DEFAULT_JUMPSTART_SAGEMAKER_SESSION, "hyperparameters": {"hyp1": "val1"}, "tags": [], @@ -1033,6 +1034,8 @@ def test_jumpstart_estimator_attach_eula_model( additional_kwargs={ "model_id": "gemma-model", "model_version": "*", + "tolerate_vulnerable_model": True, + "tolerate_deprecated_model": True, "environment": {"accept_eula": "true"}, "tolerate_vulnerable_model": True, "tolerate_deprecated_model": True, @@ -1040,7 +1043,7 @@ def test_jumpstart_estimator_attach_eula_model( ) @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") - @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") + @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1048,15 +1051,17 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, - 
get_model_id_version_from_training_job: mock.Mock, + get_model_info_from_training_job: mock.Mock, mock_attach: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS - get_model_id_version_from_training_job.return_value = ( + get_model_info_from_training_job.return_value = ( "js-trainable-model-prepacked", "1.0.0", + None, + None, ) mock_get_model_specs.side_effect = get_special_model_spec @@ -1067,7 +1072,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( training_job_name="some-training-job-name", sagemaker_session=mock_session ) - get_model_id_version_from_training_job.assert_called_once_with( + get_model_info_from_training_job.assert_called_once_with( training_job_name="some-training-job-name", sagemaker_session=mock_session, ) @@ -1085,7 +1090,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( ) @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") - @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") + @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1093,13 +1098,13 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, - get_model_id_version_from_training_job: mock.Mock, + get_model_info_from_training_job: mock.Mock, mock_attach: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS - get_model_id_version_from_training_job.side_effect = ValueError() + get_model_info_from_training_job.side_effect = ValueError() mock_get_model_specs.side_effect = get_special_model_spec @@ -1110,7 +1115,7 @@ def 
test_jumpstart_estimator_attach_no_model_id_sad_case( training_job_name="some-training-job-name", sagemaker_session=mock_session ) - get_model_id_version_from_training_job.assert_called_once_with( + get_model_info_from_training_job.assert_called_once_with( training_job_name="some-training-job-name", sagemaker_session=mock_session, ) @@ -1137,6 +1142,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "region", "tolerate_vulnerable_model", "tolerate_deprecated_model", + "config_name", "hub_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -1159,8 +1165,11 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): js_class_deploy = JumpStartEstimator.deploy js_class_deploy_args = set(signature(js_class_deploy).parameters.keys()) - assert js_class_deploy_args - parent_class_deploy_args == model_class_init_args - { + assert js_class_deploy_args - parent_class_deploy_args - { + "inference_config_name" + } == model_class_init_args - { "model_data", + "additional_model_data_sources", "self", "name", "resources", @@ -1243,6 +1252,7 @@ def test_no_predictor_returns_default_predictor( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=estimator.sagemaker_session, + config_name=None, hub_arn=None, ) self.assertEqual(type(predictor), Predictor) @@ -1413,6 +1423,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, + config_name=None, hub_arn=None, ) @@ -1469,6 +1480,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, + config_name=None, hub_arn=None, ) @@ -1921,6 +1933,272 @@ def test_jumpstart_estimator_session( assert len(s3_clients) == 1 assert list(s3_clients)[0] == session.s3_client + 
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_initialization_with_config_name( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator( + model_id=model_id, + config_name="gpu-training", + ) + + mock_estimator_init.assert_called_once_with( + instance_type="ml.p2.xlarge", + instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-training:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04", + model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/meta-textgeneration-llama-2-7b/" + "gpu-training/model/", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/transfer_learning/" + "eqa/v1.0.0/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={"epochs": "3", "adam-learning-rate": "2e-05", "batch-size": "4"}, + role="fake role! 
do not use!", + sagemaker_session=estimator.sagemaker_session, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "gpu-training"}, + ], + enable_network_isolation=False, + ) + + estimator.fit() + + mock_estimator_fit.assert_called_once_with(wait=True) + + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_set_config_name( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training") + + estimator.set_training_config(config_name="gpu-training-budget") + + mock_estimator_init.assert_called_with( + instance_type="ml.p2.xlarge", + instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-training:1.13.1-py310-sdk2.14.1-ubuntu20.04", + model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/meta-textgeneration-llama-2-7b/" + "gpu-training-budget/model/", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" + 
"transfer_learning/eqa/v1.0.0/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={"epochs": "3", "adam-learning-rate": "2e-05", "batch-size": "4"}, + role="fake role! do not use!", + sagemaker_session=estimator.sagemaker_session, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "gpu-training-budget"}, + ], + enable_network_isolation=False, + ) + + estimator.fit() + + mock_estimator_fit.assert_called_once_with(wait=True) + + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_default_inference_config( + self, + mock_estimator_fit: mock.Mock, + mock_estimator_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_estimator_deploy.return_value = default_predictor + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training") + + assert estimator.config_name == "gpu-training" + + estimator.deploy() + + mock_estimator_deploy.assert_called_once_with( + instance_type="ml.p2.xlarge", + initial_instance_count=1, + 
image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-hosting" + ":2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz", + entry_point="inference.py", + predictor_cls=Predictor, + wait=True, + role="fake role! do not use!", + use_compiled_model=False, + enable_network_isolation=False, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-inference"}, + ], + ) + + @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") + @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_incremental_training_config( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_model_info_from_training_job: mock.Mock, + mock_attach: mock.Mock, + ): + mock_get_model_info_from_training_job.return_value = ( + "pytorch-eqa-bert-base-cased", + "1.0.0", + None, + "gpu-training-budget", + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = 
"pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training") + + assert estimator.config_name == "gpu-training" + + JumpStartEstimator.attach( + training_job_name="some-training-job-name", sagemaker_session=mock_session + ) + + mock_attach.assert_called_once_with( + training_job_name="some-training-job-name", + sagemaker_session=mock_session, + model_channel_name="model", + additional_kwargs={ + "model_id": "pytorch-eqa-bert-base-cased", + "model_version": "1.0.0", + "tolerate_vulnerable_model": True, + "tolerate_deprecated_model": True, + "config_name": "gpu-training-budget", + }, + ) + + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_deploy_with_config( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_estimator_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_estimator_deploy.return_value = default_predictor + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training-budget") + + assert 
estimator.config_name == "gpu-training-budget" + + estimator.deploy() + + mock_estimator_deploy.assert_called_once_with( + instance_type="ml.p2.xlarge", + initial_instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting:1.13.1-py310-sdk2.14.1-ubuntu20.04", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz", + entry_point="inference.py", + predictor_cls=Predictor, + wait=True, + role="fake role! do not use!", + use_compiled_model=False, + enable_network_isolation=False, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-inference-budget"}, + ], + ) + def test_jumpstart_estimator_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py index c4b95443ec..11798bc854 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py @@ -15,9 +15,13 @@ import pytest import numpy as np from sagemaker.jumpstart.types import ( + JumpStartConfigComponent, + JumpStartConfigRanking, JumpStartHyperparameter, JumpStartInstanceTypeVariants, JumpStartEnvironmentVariable, + JumpStartMetadataConfig, + JumpStartMetadataConfigs, JumpStartPredictorSpecs, JumpStartSerializablePayload, ) @@ -32,9 +36,8 @@ def test_hub_content_document_from_json_obj(): region = "us-west-2" - gemma_model_document = HubModelDocument( - json_obj=HUB_MODEL_DOCUMENT_DICTS["huggingface-llm-gemma-2b-instruct"], region=region - ) + json_obj = HUB_MODEL_DOCUMENT_DICTS["huggingface-llm-gemma-2b-instruct"] + gemma_model_document = HubModelDocument(json_obj=json_obj, region=region) assert gemma_model_document.url == "https://huggingface.co/google/gemma-2b-it" assert 
gemma_model_document.min_sdk_version == "2.189.0" assert gemma_model_document.training_supported is True @@ -979,3 +982,69 @@ def test_hub_content_document_from_json_obj(): assert gemma_model_document.dynamic_container_deployment_supported is True assert gemma_model_document.training_model_package_artifact_uri is None assert gemma_model_document.dependencies == [] + + inference_config_rankings = { + "overall": JumpStartConfigRanking( + {"Description": "default", "Rankings": ["variant1"]}, is_hub_content=True + ) + } + + inference_config_components = { + "variant1": JumpStartConfigComponent( + "variant1", + { + "HostingEcrUri": "123456789012.ecr.us-west-2.amazon.com/repository", + "HostingArtifactUri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration-llama-2-7b/artifacts/variant1/v1.0.0/", # noqa: E501 + "HostingScriptUri": "s3://jumpstart-monarch-test-hub-bucket/monarch-curated-hub-1714579993.88695/curated_models/meta-textgeneration-llama-2-7b/4.0.0/source-directory-tarballs/meta/inference/textgeneration/v1.2.3/sourcedir.tar.gz", # noqa: E501 + "InferenceDependencies": [], + "InferenceEnvironmentVariables": [ + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + "RequiredForModelClass": True, + } + ], + "HostingAdditionalDataSources": { + "speculative_decoding": [ + { + "ArtifactVersion": 1, + "ChannelName": "speculative_decoding_channel_1", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/path/1", + }, + }, + { + "ArtifactVersion": 1, + "ChannelName": "speculative_decoding_channel_2", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/path/2", + }, + }, + ] + }, + }, + is_hub_content=True, + ) + } + + inference_configs_dict = { + "variant1": JumpStartMetadataConfig( + "variant1", + json_obj["InferenceConfigs"]["variant1"], + json_obj, + inference_config_components, + 
is_hub_content=True, + ) + } + + inference_configs = JumpStartMetadataConfigs(inference_configs_dict, inference_config_rankings) + + assert gemma_model_document.inference_config_rankings == inference_config_rankings + assert gemma_model_document.inference_config_components == inference_config_components + assert gemma_model_document.inference_configs == inference_configs diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 15c2c43bf0..56eaa0b660 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -15,6 +15,8 @@ from typing import Optional, Set from unittest import mock import unittest + +import pandas as pd from mock import MagicMock, Mock import pytest from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig @@ -40,12 +42,18 @@ from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from tests.unit.sagemaker.jumpstart.utils import ( + get_prototype_spec_with_configs, get_spec_from_base_spec, get_special_model_spec, overwrite_dictionary, get_special_model_spec_for_inference_component_based_endpoint, get_prototype_manifest, get_prototype_model_spec, + get_base_spec_with_prototype_configs, + get_mock_init_kwargs, + get_base_deployment_configs, + get_base_spec_with_prototype_configs_with_missing_benchmarks, + append_instance_stat_metrics, ) import boto3 @@ -60,9 +68,11 @@ class ModelTest(unittest.TestCase): - mock_session_empty_config = MagicMock(sagemaker_config={}) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @@ -82,6 +92,7 @@ def test_non_prepacked( mock_validate_model_id_and_get_type: mock.Mock, 
mock_sagemaker_timestamp: mock.Mock, mock_jumpstart_model_factory_logger: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -141,6 +152,9 @@ def test_non_prepacked( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( @@ -158,6 +172,7 @@ def test_non_prepacked_inference_component_based_endpoint( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -223,6 +238,9 @@ def test_non_prepacked_inference_component_based_endpoint( endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( @@ -240,6 +258,7 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -300,6 +319,9 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" @@ -315,6 +337,7 @@ def test_prepacked( mock_get_model_specs: mock.Mock, mock_session: 
mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -361,6 +384,9 @@ def test_prepacked( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.model.LOGGER.warning") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @@ -380,6 +406,7 @@ def test_no_compiled_model_warning_log_js_models( mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, mock_warning: mock.Mock(), + mock_get_jumpstart_configs: mock.Mock, ): mock_timestamp.return_value = "1234" @@ -400,6 +427,9 @@ def test_no_compiled_model_warning_log_js_models( mock_warning.assert_not_called() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @mock.patch("sagemaker.session.Session.create_model") @@ -417,6 +447,7 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( mock_create_model: mock.Mock, mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_timestamp.return_value = "1234" @@ -464,6 +495,7 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( ], ) + @mock.patch("sagemaker.jumpstart.model.get_jumpstart_configs") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @@ -485,7 +517,9 @@ def test_proprietary_model_endpoint( mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): + 
mock_get_jumpstart_configs.side_effect = lambda *args, **kwargs: {} mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) @@ -534,6 +568,7 @@ def test_proprietary_model_endpoint( container_startup_health_check_timeout=600, ) + @mock.patch("sagemaker.jumpstart.model.get_jumpstart_configs") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -545,7 +580,9 @@ def test_deprecated( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): + mock_get_jumpstart_configs.side_effect = lambda *args, **kwargs: {} mock_model_deploy.return_value = default_predictor mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -561,6 +598,9 @@ def test_deprecated( JumpStartModel(model_id=model_id, tolerate_deprecated_model=True).deploy() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -572,6 +612,7 @@ def test_vulnerable( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -635,6 +676,9 @@ def test_model_use_kwargs(self): deploy_kwargs=all_deploy_kwargs_used, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.factory.model.environment_variables.retrieve_default") 
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( @@ -652,6 +696,7 @@ def evaluate_model_workflow_with_kwargs( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_retrieve_environment_variables: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, init_kwargs: Optional[dict] = None, deploy_kwargs: Optional[dict] = None, ): @@ -742,6 +787,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): "tolerate_deprecated_model", "instance_type", "model_package_arn", + "config_name", "hub_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -755,6 +801,9 @@ def test_jumpstart_model_kwargs_match_parent_class(self): assert js_class_deploy_args - parent_class_deploy_args == set() assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @@ -763,6 +812,7 @@ def test_validate_model_id_and_get_type( mock_validate_model_id_and_get_type: mock.Mock, mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS JumpStartModel(model_id="valid_model_id") @@ -771,6 +821,9 @@ def test_validate_model_id_and_get_type( with pytest.raises(ValueError): JumpStartModel(model_id="invalid_model_id") + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( @@ -788,6 +841,7 @@ def test_no_predictor_returns_default_predictor( mock_session: mock.Mock, 
mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets @@ -817,10 +871,14 @@ def test_no_predictor_returns_default_predictor( tolerate_vulnerable_model=False, sagemaker_session=model.sagemaker_session, model_type=JumpStartModelType.OPEN_WEIGHTS, + config_name=None, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( @@ -838,6 +896,7 @@ def test_no_predictor_yes_async_inference_config( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets @@ -859,6 +918,9 @@ def test_no_predictor_yes_async_inference_config( mock_get_default_predictor.assert_not_called() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( @@ -876,6 +938,7 @@ def test_yes_predictor_returns_default_predictor( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets @@ -897,6 +960,9 @@ def test_yes_predictor_returns_default_predictor( self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", 
side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -914,6 +980,7 @@ def test_model_id_not_found_refeshes_cache_inference( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.side_effect = [False, False] @@ -986,6 +1053,9 @@ def test_model_id_not_found_refeshes_cache_inference( ] ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -993,6 +1063,7 @@ def test_jumpstart_model_tags( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1022,6 +1093,9 @@ def test_jumpstart_model_tags( [{"Key": "blah", "Value": "blahagain"}] + js_tags, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1029,6 +1103,7 @@ def test_jumpstart_model_tags_disabled( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1056,6 +1131,9 @@ 
def test_jumpstart_model_tags_disabled( [{"Key": "blah", "Value": "blahagain"}], ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1063,6 +1141,7 @@ def test_jumpstart_model_package_arn( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1090,6 +1169,9 @@ def test_jumpstart_model_package_arn( self.assertIn(tag, mock_session.create_model.call_args[1]["tags"]) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1097,6 +1179,7 @@ def test_jumpstart_model_package_arn_override( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1132,6 +1215,9 @@ def test_jumpstart_model_package_arn_override( }, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" @@ -1143,6 +1229,7 @@ def test_jumpstart_model_package_arn_unsupported_region( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, 
mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1160,6 +1247,9 @@ def test_jumpstart_model_package_arn_unsupported_region( "us-east-2. Please try one of the following regions: us-west-2, us-east-1." ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( @@ -1179,6 +1269,7 @@ def test_model_data_s3_prefix_override( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1228,6 +1319,9 @@ def test_model_data_s3_prefix_override( '"S3DataType": "S3Prefix", "CompressionType": "None"}}', ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" @@ -1245,6 +1339,7 @@ def test_model_data_s3_prefix_model( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1274,6 +1369,9 @@ def test_model_data_s3_prefix_model( mock_js_info_logger.assert_not_called() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" @@ -1291,6 +1389,7 @@ def test_model_artifact_variant_model( mock_get_model_specs: mock.Mock, 
mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1341,17 +1440,23 @@ def test_model_artifact_variant_model( enable_network_isolation=True, ) - @mock.patch("sagemaker.jumpstart.model.get_model_id_version_from_endpoint") + @mock.patch("sagemaker.jumpstart.model.get_model_info_from_endpoint") @mock.patch("sagemaker.jumpstart.model.JumpStartModel.__init__") def test_attach( self, mock_js_model_init, - mock_get_model_id_version_from_endpoint, + mock_get_model_info_from_endpoint, ): mock_js_model_init.return_value = None - mock_get_model_id_version_from_endpoint.return_value = "model-id", "model-version", None + mock_get_model_info_from_endpoint.return_value = ( + "model-id", + "model-version", + None, + None, + None, + ) val = JumpStartModel.attach("some-endpoint") - mock_get_model_id_version_from_endpoint.assert_called_once_with( + mock_get_model_info_from_endpoint.assert_called_once_with( endpoint_name="some-endpoint", inference_component_name=None, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -1363,32 +1468,36 @@ def test_attach( ) assert isinstance(val, JumpStartModel) - mock_get_model_id_version_from_endpoint.reset_mock() + mock_get_model_info_from_endpoint.reset_mock() JumpStartModel.attach("some-endpoint", model_id="some-id") - mock_get_model_id_version_from_endpoint.assert_called_once_with( + mock_get_model_info_from_endpoint.assert_called_once_with( endpoint_name="some-endpoint", inference_component_name=None, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) - mock_get_model_id_version_from_endpoint.reset_mock() + mock_get_model_info_from_endpoint.reset_mock() JumpStartModel.attach("some-endpoint", model_id="some-id", model_version="some-version") - mock_get_model_id_version_from_endpoint.assert_called_once_with( + mock_get_model_info_from_endpoint.assert_called_once_with( endpoint_name="some-endpoint", 
inference_component_name=None, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) # providing model id, version, and ic name should bypass check with endpoint tags - mock_get_model_id_version_from_endpoint.reset_mock() + mock_get_model_info_from_endpoint.reset_mock() JumpStartModel.attach( "some-endpoint", model_id="some-id", model_version="some-version", inference_component_name="some-ic-name", ) - mock_get_model_id_version_from_endpoint.assert_not_called() + mock_get_model_info_from_endpoint.assert_not_called() + + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" @@ -1404,6 +1513,7 @@ def test_model_registry_accept_and_response_types( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1425,6 +1535,9 @@ def test_model_registry_accept_and_response_types( model_package_group_name=model.model_id, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -1438,6 +1551,7 @@ def test_jumpstart_model_session( mock_deploy, mock_init, get_default_predictor, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = True @@ -1471,6 +1585,9 @@ def test_jumpstart_model_session( assert len(s3_clients) == 1 assert list(s3_clients)[0] == session.s3_client + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch.dict( "sagemaker.jumpstart.cache.os.environ", { @@ 
-1491,6 +1608,7 @@ def test_model_local_mode( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_model_specs.side_effect = get_prototype_model_spec mock_get_manifest.side_effect = ( @@ -1517,6 +1635,514 @@ def test_model_local_mode( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_initialization_with_config_name( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id, config_name="neuron-inference") + + assert model.config_name == "neuron-inference" + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.inf2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, + ], + wait=True, + endpoint_logging=False, + ) + + @mock.patch( + 
"sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_set_deployment_config( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_model_spec + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + assert model.config_name is None + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + ], + wait=True, + endpoint_logging=False, + ) + + mock_get_model_specs.reset_mock() + mock_model_deploy.reset_mock() + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + model.set_deployment_config("neuron-inference", "ml.inf2.2xlarge") + + assert model.config_name == "neuron-inference" + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.inf2.2xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": 
JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, + ], + wait=True, + endpoint_logging=False, + ) + mock_model_deploy.reset_mock() + model.set_deployment_config("neuron-inference", "ml.inf2.xlarge") + + assert model.config_name == "neuron-inference" + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.inf2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, + ], + wait=True, + endpoint_logging=False, + ) + + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.model.Model.__init__") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_deployment_config_additional_model_data_source( + self, + mock_model_init: mock.Mock, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_session.return_value = sagemaker_session + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + model = JumpStartModel(model_id=model_id, config_name="gpu-accelerated") + + mock_model_init.assert_called_once_with( + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04", + 
model_data="s3://jumpstart-cache-prod-us-west-2/pytorch-infer/" + "infer-pytorch-eqa-bert-base-cased.tar.gz", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz", + entry_point="inference.py", + predictor_cls=Predictor, + role=execution_role, + sagemaker_session=sagemaker_session, + enable_network_isolation=False, + additional_model_data_sources=[ + { + "ChannelName": "draft_model_name", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://sagemaker-sd-models-prod-us-west-2/key/to/draft/model/artifact/", + "ModelAccessConfig": {"AcceptEula": False}, + }, + } + ], + ) + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-accelerated"}, + ], + wait=True, + endpoint_logging=False, + ) + + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_set_deployment_config_model_package( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + 
) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + assert model.config_name == "neuron-inference" + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.inf2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, + ], + wait=True, + endpoint_logging=False, + ) + + mock_model_deploy.reset_mock() + + model.set_deployment_config( + config_name="gpu-inference-model-package", instance_type="ml.p2.xlarge" + ) + + assert ( + model.model_package_arn + == "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + ) + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-inference-model-package"}, + ], + wait=True, + endpoint_logging=False, + ) + + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_set_deployment_config_incompatible_instance_type_or_name( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: 
mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_model_spec + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + assert model.config_name is None + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + ], + wait=True, + endpoint_logging=False, + ) + + mock_get_model_specs.reset_mock() + mock_model_deploy.reset_mock() + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + + with pytest.raises(ValueError) as error: + model.set_deployment_config("neuron-inference-unknown-name", "ml.inf2.32xlarge") + assert "Cannot find Jumpstart config name neuron-inference-unknown-name. 
" in str(error) + + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_list_deployment_configs( + self, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks() + ) + mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + configs = model.list_deployment_configs() + + self.assertEqual(configs, get_base_deployment_configs(True)) + + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + 
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_list_deployment_configs_empty( + self, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_special_model_spec(model_id="gemma-model") + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + configs = model.list_deployment_configs() + + self.assertTrue(len(configs) == 0) + + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_retrieve_deployment_config( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + 
mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs() + ) + mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + expected = get_base_deployment_configs()[0] + config_name = expected.get("DeploymentConfigName") + instance_type = expected.get("InstanceType") + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs( + model_id, config_name + ) + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + model.set_deployment_config(config_name, instance_type) + + self.assertEqual(model.deployment_config, expected) + + mock_get_init_kwargs.reset_mock() + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_display_benchmark_metrics( + self, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = 
"pytorch-eqa-bert-base-cased", "*" + + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks() + ) + mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + model.display_benchmark_metrics() + model.display_benchmark_metrics(instance_type="g5.12xlarge") + + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_benchmark_metrics( + self, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs() + ) + 
mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + df = model.benchmark_metrics + + self.assertTrue(isinstance(df, pd.DataFrame)) + def test_jumpstart_model_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py index 70409704e6..2be4bde7e4 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py @@ -60,6 +60,9 @@ class IntelligentDefaultsModelTest(unittest.TestCase): region = "us-west-2" sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -77,6 +80,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -101,6 +105,9 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( assert "enable_network_isolation" not in mock_model_init.call_args[1] + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) 
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -118,6 +125,7 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -147,6 +155,9 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( override_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -164,6 +175,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -193,6 +205,9 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( config_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -210,6 +225,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): 
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -241,6 +257,9 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( override_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -258,6 +277,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -287,6 +307,9 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( metadata_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -304,6 +327,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -334,6 +358,9 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( override_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") 
@mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -351,6 +378,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -375,6 +403,9 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( self.assertEquals(mock_model_init.call_args[1].get("role"), execution_role) assert "enable_network_isolation" not in mock_model_init.call_args[1] + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -392,6 +423,7 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 9fd9cc8398..a06b48deb7 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -1,6 +1,7 @@ from __future__ import absolute_import -import json + import datetime +import json from unittest import TestCase from unittest.mock import Mock, patch, ANY @@ -237,7 +238,7 @@ def test_list_jumpstart_models_script_filter( get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json() ) patched_get_manifest.side_effect = ( - lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, 
model_type) + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region) ) manifest_length = len(get_prototype_manifest()) @@ -245,7 +246,7 @@ def test_list_jumpstart_models_script_filter( for val in vals: kwargs = {"filter": And(f"training_supported == {val}", "model_type is open_weights")} list_jumpstart_models(**kwargs) - assert patched_read_s3_file.call_count == manifest_length + assert patched_read_s3_file.call_count == 2 * manifest_length assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() @@ -253,7 +254,7 @@ def test_list_jumpstart_models_script_filter( kwargs = {"filter": And(f"training_supported != {val}", "model_type is open_weights")} list_jumpstart_models(**kwargs) - assert patched_read_s3_file.call_count == manifest_length + assert patched_read_s3_file.call_count == 2 * manifest_length assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() @@ -272,7 +273,7 @@ def test_list_jumpstart_models_script_filter( ("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0"), ("xgboost-classification-model", "1.0.0"), ] - assert patched_read_s3_file.call_count == manifest_length + assert patched_read_s3_file.call_count == 2 * manifest_length assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() @@ -281,7 +282,7 @@ def test_list_jumpstart_models_script_filter( kwargs = {"filter": And(f"training_supported not in {vals}", "model_type is open_weights")} models = list_jumpstart_models(**kwargs) assert [] == models - assert patched_read_s3_file.call_count == manifest_length + assert patched_read_s3_file.call_count == 2 * manifest_length assert patched_get_manifest.call_count == 2 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 6f86f724a9..8368f72d58 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ 
b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -18,7 +18,7 @@ from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec, get_spec_from_base_spec -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support( @@ -52,7 +52,7 @@ def test_jumpstart_predictor_support( assert js_predictor.accept == MIMEType.JSON -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_proprietary_predictor_support( @@ -91,13 +91,13 @@ def test_proprietary_predictor_support( @patch("sagemaker.predictor.Predictor") @patch("sagemaker.predictor.get_default_predictor") -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( patched_get_model_specs, patched_verify_model_region_and_return_specs, - patched_get_jumpstart_model_id_version_from_endpoint, + patched_get_model_info_from_endpoint, patched_get_default_predictor, patched_predictor, ): @@ -105,19 +105,19 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_get_jumpstart_model_id_version_from_endpoint.return_value = ( + patched_get_model_info_from_endpoint.return_value = ( 
"predictor-specs-model", "1.2.3", None, + None, + None, ) mock_session = Mock() predictor.retrieve_default(endpoint_name="blah", sagemaker_session=mock_session) - patched_get_jumpstart_model_id_version_from_endpoint.assert_called_once_with( - "blah", None, mock_session - ) + patched_get_model_info_from_endpoint.assert_called_once_with("blah", None, mock_session) patched_get_default_predictor.assert_called_once_with( predictor=patched_predictor.return_value, @@ -128,12 +128,13 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( tolerate_vulnerable_model=False, sagemaker_session=mock_session, model_type=JumpStartModelType.OPEN_WEIGHTS, + config_name=None, hub_arn=None, ) @patch("sagemaker.predictor.get_default_predictor") -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( @@ -160,7 +161,8 @@ def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( patched_get_default_predictor.assert_not_called() -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}) +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") @patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @@ -170,7 +172,8 @@ def test_jumpstart_serializable_payload_with_predictor( patched_verify_model_region_and_return_specs, patched_validate_model_id_and_get_type, patched_get_object_cached, - patched_get_model_id_version_from_endpoint, + patched_get_model_info_from_endpoint, + patched_get_jumpstart_configs, 
): patched_get_object_cached.return_value = base64.b64decode("encodedimage") @@ -180,7 +183,7 @@ def test_jumpstart_serializable_payload_with_predictor( patched_get_model_specs.side_effect = get_special_model_spec model_id, model_version = "default_payloads", "*" - patched_get_model_id_version_from_endpoint.return_value = model_id, model_version, None + patched_get_model_info_from_endpoint.return_value = model_id, model_version, None js_predictor = predictor.retrieve_default( endpoint_name="blah", model_id=model_id, model_version=model_version diff --git a/tests/unit/sagemaker/jumpstart/test_session_utils.py b/tests/unit/sagemaker/jumpstart/test_session_utils.py index 76ad50f31c..ce06a189bd 100644 --- a/tests/unit/sagemaker/jumpstart/test_session_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_session_utils.py @@ -4,167 +4,202 @@ import pytest from sagemaker.jumpstart.session_utils import ( - _get_model_id_version_from_inference_component_endpoint_with_inference_component_name, - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name, - _get_model_id_version_from_model_based_endpoint, - get_model_id_version_from_endpoint, - get_model_id_version_from_training_job, + _get_model_info_from_inference_component_endpoint_with_inference_component_name, + _get_model_info_from_inference_component_endpoint_without_inference_component_name, + _get_model_info_from_model_based_endpoint, + get_model_info_from_endpoint, + get_model_info_from_training_job, ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_training_job_happy_case( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_training_job_happy_case( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" 
mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", + None, + None, ) - retval = get_model_id_version_from_training_job("bLaH", sagemaker_session=mock_sm_session) + retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", None, None) - mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_training_job_no_model_id_inferred( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_training_job_config_name( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( + "model_id", + "model_version", + None, + "training_config_name", + ) + + retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) + + assert retval == ("model_id", "model_version", None, "training_config_name") + + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( + "arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session + ) + + +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def 
test_get_model_info_from_training_job_no_model_id_inferred( + mock_get_jumpstart_model_info_from_resource_arn, +): + mock_sm_session = Mock() + mock_sm_session.boto_region_name = "us-west-2" + mock_sm_session.account_id = Mock(return_value="123456789012") + + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( None, None, ) with pytest.raises(ValueError): - get_model_id_version_from_training_job("blah", sagemaker_session=mock_sm_session) + get_model_info_from_training_job("blah", sagemaker_session=mock_sm_session) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_happy_case( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_model_based_endpoint_happy_case( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", + None, + None, ) - retval = _get_model_id_version_from_model_based_endpoint( + retval = _get_model_info_from_model_based_endpoint( "bLaH", inference_component_name=None, sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", None, None) - mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:endpoint/blah", mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_inference_component_supplied( - 
mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_model_based_endpoint_inference_component_supplied( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", + None, + None, ) with pytest.raises(ValueError): - _get_model_id_version_from_model_based_endpoint( + _get_model_info_from_model_based_endpoint( "blah", inference_component_name="some-name", sagemaker_session=mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_no_model_id_inferred( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_model_based_endpoint_no_model_id_inferred( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( + None, None, None, ) with pytest.raises(ValueError): - _get_model_id_version_from_model_based_endpoint( + _get_model_info_from_model_based_endpoint( "blah", inference_component_name="some-name", sagemaker_session=mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_inference_component_endpoint_with_inference_component_name_happy_case( - 
mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_happy_case( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", + None, + None, ) - retval = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( + retval = _get_model_info_from_inference_component_endpoint_with_inference_component_name( "bLaH", sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", None, None) - mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:inference-component/bLaH", mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value 
= ( + None, + None, None, None, ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( + _get_model_info_from_inference_component_endpoint_with_inference_component_name( "blah", sagemaker_session=mock_sm_session ) @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_id_version_from_inference_component_endpoint_without_inference_component_name_happy_case( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_inference_component_endpoint_without_inference_component_name_happy_case( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -172,10 +207,8 @@ def test_get_model_id_version_from_inference_component_endpoint_without_inferenc return_value=["icname"] ) - retval = ( - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( - "blahblah", mock_sm_session - ) + retval = _get_model_info_from_inference_component_endpoint_without_inference_component_name( + "blahblah", mock_sm_session ) assert retval == ("model_id", "model_version", "icname") @@ -185,14 +218,14 @@ def test_get_model_id_version_from_inference_component_endpoint_without_inferenc @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def 
test_get_model_id_version_from_inference_component_endpoint_without_ic_name_no_ic_for_endpoint( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_inference_component_endpoint_without_ic_name_no_ic_for_endpoint( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -200,7 +233,7 @@ def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_ return_value=[] ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( + _get_model_info_from_inference_component_endpoint_without_inference_component_name( "blahblah", mock_sm_session ) @@ -210,14 +243,14 @@ def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_ @patch( - "sagemaker.jumpstart.session_utils._get_model_id" - "_version_from_inference_component_endpoint_with_inference_component_name" + "sagemaker.jumpstart.session_utils._get_model" + "_info_from_inference_component_endpoint_with_inference_component_name" ) def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_multiple_ics_for_endpoint( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -227,7 +260,7 @@ def 
test_get_model_id_version_from_ic_endpoint_without_inference_component_name_ ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( + _get_model_info_from_inference_component_endpoint_without_inference_component_name( "blahblah", mock_sm_session ) @@ -236,67 +269,119 @@ def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_ ) -@patch("sagemaker.jumpstart.session_utils._get_model_id_version_from_model_based_endpoint") -def test_get_model_id_version_from_endpoint_non_inference_component_endpoint( - mock_get_model_id_version_from_model_based_endpoint, +@patch("sagemaker.jumpstart.session_utils._get_model_info_from_model_based_endpoint") +def test_get_model_info_from_endpoint_non_inference_component_endpoint( + mock_get_model_info_from_model_based_endpoint, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = False - mock_get_model_id_version_from_model_based_endpoint.return_value = ( + mock_get_model_info_from_model_based_endpoint.return_value = ( "model_id", "model_version", + None, + None, ) - retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session) + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", None) - mock_get_model_id_version_from_model_based_endpoint.assert_called_once_with( + assert retval == ("model_id", "model_version", None, None, None) + mock_get_model_info_from_model_based_endpoint.assert_called_once_with( "blah", None, mock_sm_session ) mock_sm_session.is_inference_component_based_endpoint.assert_called_once_with("blah") @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def 
test_get_model_id_version_from_endpoint_inference_component_endpoint_with_inference_component_name( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_endpoint_inference_component_endpoint_with_inference_component_name( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = True - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", + None, + None, ) - retval = get_model_id_version_from_endpoint( + retval = get_model_info_from_endpoint( "blah", inference_component_name="icname", sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version", "icname") - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with( + assert retval == ("model_id", "model_version", "icname", None, None) + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with( "icname", mock_sm_session ) mock_sm_session.is_inference_component_based_endpoint.assert_not_called() @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_component_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" + "endpoint_without_inference_component_name" +) +def test_get_model_info_from_endpoint_inference_component_endpoint_without_inference_component_name( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, +): + mock_sm_session = Mock() + mock_sm_session.is_inference_component_based_endpoint.return_value = True + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( + 
"model_id", + "model_version", + None, + None, + "inferred-icname", + ) + + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) + + assert retval == ("model_id", "model_version", "inferred-icname", None, None) + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + + +@patch( + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" "endpoint_without_inference_component_name" ) -def test_get_model_id_version_from_endpoint_inference_component_endpoint_without_inference_component_name( - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name, +def test_get_model_info_from_endpoint_inference_component_endpoint_with_inference_config_name( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = True - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( "model_id", "model_version", + "inference_config_name", + None, + "inferred-icname", + ) + + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) + + assert retval == ("model_id", "model_version", "inferred-icname", "inference_config_name", None) + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + + +@patch( + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" + "endpoint_without_inference_component_name" +) +def test_get_model_info_from_endpoint_inference_component_endpoint_with_training_config_name( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, +): + mock_sm_session = Mock() + 
mock_sm_session.is_inference_component_based_endpoint.return_value = True + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( + "model_id", + "model_version", + None, + "training_config_name", "inferred-icname", ) - retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session) + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", "inferred-icname") - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + assert retval == ("model_id", "model_version", "inferred-icname", None, "training_config_name") + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 987feef7da..884639b5d6 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -12,23 +12,31 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import import copy +from unittest import TestCase import pytest from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.types import ( JumpStartBenchmarkStat, JumpStartECRSpecs, + JumpStartEnvironmentVariable, JumpStartHyperparameter, JumpStartInstanceTypeVariants, JumpStartModelSpecs, JumpStartModelHeader, JumpStartConfigComponent, + DeploymentConfigMetadata, + JumpStartModelInitKwargs, + S3DataSource, ) +from sagemaker.utils import S3_PREFIX from tests.unit.sagemaker.jumpstart.constants import ( BASE_SPEC, + BASE_HOSTING_ADDITIONAL_DATA_SOURCES, INFERENCE_CONFIG_RANKINGS, INFERENCE_CONFIGS, TRAINING_CONFIG_RANKINGS, TRAINING_CONFIGS, + INIT_KWARGS, ) INSTANCE_TYPE_VARIANT = JumpStartInstanceTypeVariants( @@ -438,6 +446,54 @@ def test_jumpstart_model_specs(): assert specs3 == specs1 +class TestS3DataSource(TestCase): + def setUp(self): + self.s3_data_source = S3DataSource( + { + "compression_type": "None", + "s3_data_type": "S3Prefix", + "s3_uri": "key/to/model/artifact/", + "model_access_config": {"accept_eula": False}, + } + ) + + def test_set_bucket_with_valid_s3_uri(self): + self.s3_data_source.set_bucket("my-bucket") + self.assertEqual(self.s3_data_source.s3_uri, f"{S3_PREFIX}my-bucket/key/to/model/artifact/") + + def test_set_bucket_with_existing_s3_uri(self): + self.s3_data_source.s3_uri = "s3://my-bucket/key/to/model/artifact/" + self.s3_data_source.set_bucket("random-new-bucket") + assert self.s3_data_source.s3_uri == "s3://random-new-bucket/key/to/model/artifact/" + + def test_set_bucket_with_existing_s3_uri_empty_bucket(self): + self.s3_data_source.s3_uri = "s3://my-bucket" + self.s3_data_source.set_bucket("random-new-bucket") + assert self.s3_data_source.s3_uri == "s3://random-new-bucket" + + def test_set_bucket_with_existing_s3_uri_empty(self): + self.s3_data_source.s3_uri = "s3://" + self.s3_data_source.set_bucket("random-new-bucket") + assert self.s3_data_source.s3_uri == 
"s3://random-new-bucket" + + +def test_get_speculative_decoding_s3_data_sources(): + specs = JumpStartModelSpecs({**BASE_SPEC, **BASE_HOSTING_ADDITIONAL_DATA_SOURCES}) + assert ( + specs.get_speculative_decoding_s3_data_sources() + == specs.hosting_additional_data_sources.speculative_decoding + ) + + +def test_get_additional_s3_data_sources(): + specs = JumpStartModelSpecs({**BASE_SPEC, **BASE_HOSTING_ADDITIONAL_DATA_SOURCES}) + data_sources = [ + *specs.hosting_additional_data_sources.speculative_decoding, + *specs.hosting_additional_data_sources.scripts, + ] + assert specs.get_additional_s3_data_sources() == data_sources + + def test_jumpstart_image_uri_instance_variants(): assert ( @@ -930,7 +986,9 @@ def test_inference_configs_parsing(): "neuron-inference", "neuron-budget", "gpu-inference", + "gpu-inference-model-package", "gpu-inference-budget", + "gpu-accelerated", ] # Non-overrided fields in top config @@ -1022,6 +1080,80 @@ def test_inference_configs_parsing(): } ), ] + assert specs1.inference_environment_variables == [ + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + } + ), + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + } + ), + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + } + ), + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + } + ), + JumpStartEnvironmentVariable( + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + } + ), + 
JumpStartEnvironmentVariable( + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + } + ), + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + } + ), + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + } + ), + ] # Overrided fields in top config assert specs1.supported_inference_instance_types == ["ml.inf2.xlarge", "ml.inf2.2xlarge"] @@ -1030,7 +1162,9 @@ def test_inference_configs_parsing(): assert config.benchmark_metrics == { "ml.inf2.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), ] } assert len(config.config_components) == 1 @@ -1049,7 +1183,7 @@ def test_inference_configs_parsing(): "regional_aliases": { "us-west-2": { "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" } }, "variants": {"inf2": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, @@ -1058,6 +1192,20 @@ def test_inference_configs_parsing(): ) assert list(config.config_components.keys()) == ["neuron-inference"] + config = specs1.inference_configs.configs["gpu-inference-model-package"] + assert config.config_components["gpu-inference-model-package"] == JumpStartConfigComponent( + "gpu-inference-model-package", + { + "default_inference_instance_type": "ml.p2.xlarge", + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_model_package_arns": { + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/" + 
"llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + }, + }, + ) + assert config.resolved_config.get("inference_environment_variables") == [] + spec = { **BASE_SPEC, **INFERENCE_CONFIGS, @@ -1076,7 +1224,9 @@ def test_set_inference_configs(): "neuron-inference", "neuron-budget", "gpu-inference", + "gpu-inference-model-package", "gpu-inference-budget", + "gpu-accelerated", ] with pytest.raises(ValueError) as error: @@ -1084,7 +1234,7 @@ def test_set_inference_configs(): assert "Cannot find Jumpstart config name invalid_name." "List of config names that is supported by the model: " "['neuron-inference', 'neuron-inference-budget', " - "'gpu-inference-budget', 'gpu-inference']" in str(error.value) + "'gpu-inference-budget', 'gpu-inference', 'gpu-inference-model-package']" in str(error.value) assert specs1.supported_inference_instance_types == ["ml.inf2.xlarge", "ml.inf2.2xlarge"] specs1.set_config("gpu-inference") @@ -1194,18 +1344,29 @@ def test_training_configs_parsing(): assert config.benchmark_metrics == { "ml.tr1n1.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), ], "ml.tr1n1.4xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "50", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1} + ), ], } assert len(config.config_components) == 1 assert config.config_components["neuron-training"] == JumpStartConfigComponent( "neuron-training", { + "default_training_instance_type": "ml.trn1.2xlarge", "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, 
"training_instance_type_variants": { "regional_aliases": { "us-west-2": { @@ -1220,6 +1381,83 @@ def test_training_configs_parsing(): assert list(config.config_components.keys()) == ["neuron-training"] +def test_additional_model_data_source_parsing(): + accelerated_first_rankings = { + "inference_config_rankings": { + "overall": { + "description": "Overall rankings of configs", + "rankings": [ + "gpu-accelerated", + "neuron-inference", + "neuron-inference-budget", + "gpu-inference", + "gpu-inference-budget", + ], + } + } + } + spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **accelerated_first_rankings} + specs1 = JumpStartModelSpecs(spec) + + config = specs1.inference_configs.get_top_config_from_ranking() + + assert config.benchmark_metrics == { + "ml.p3.2xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), + ] + } + assert len(config.config_components) == 1 + assert config.config_components["gpu-accelerated"] == JumpStartConfigComponent( + "gpu-accelerated", + { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, + "hosting_additional_data_sources": { + "speculative_decoding": [ + { + "channel_name": "draft_model_name", + "artifact_version": "1.2.1", + "s3_data_source": { + "compression_type": "None", + "model_access_config": {"accept_eula": False}, + "s3_data_type": "S3Prefix", + "s3_uri": "key/to/draft/model/artifact/", + }, + } + ], + }, + }, + ) + assert list(config.config_components.keys()) == ["gpu-accelerated"] + assert config.resolved_config["hosting_additional_data_sources"] == { + 
"speculative_decoding": [ + { + "channel_name": "draft_model_name", + "artifact_version": "1.2.1", + "s3_data_source": { + "compression_type": "None", + "model_access_config": {"accept_eula": False}, + "s3_data_type": "S3Prefix", + "s3_uri": "key/to/draft/model/artifact/", + }, + } + ], + } + + def test_set_inference_config(): spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} specs1 = JumpStartModelSpecs(spec) @@ -1262,3 +1500,38 @@ def test_set_training_config(): with pytest.raises(ValueError) as error: specs1.set_config("invalid_name", scope="unknown scope") + + +def test_deployment_config_metadata(): + spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} + specs = JumpStartModelSpecs(spec) + jumpstart_config = specs.inference_configs.get_top_config_from_ranking() + + deployment_config_metadata = DeploymentConfigMetadata( + jumpstart_config.config_name, + jumpstart_config, + JumpStartModelInitKwargs( + model_id=specs.model_id, + model_data=INIT_KWARGS.get("model_data"), + image_uri=INIT_KWARGS.get("image_uri"), + instance_type=INIT_KWARGS.get("instance_type"), + env=INIT_KWARGS.get("env"), + config_name=jumpstart_config.config_name, + ), + ) + + json_obj = deployment_config_metadata.to_json() + + assert isinstance(json_obj, dict) + assert json_obj["DeploymentConfigName"] == jumpstart_config.config_name + for key in json_obj["BenchmarkMetrics"]: + assert len(json_obj["BenchmarkMetrics"][key]) == len( + jumpstart_config.benchmark_metrics.get(key) + ) + assert json_obj["AccelerationConfigs"] == jumpstart_config.resolved_config.get( + "acceleration_configs" + ) + assert json_obj["DeploymentArgs"]["ImageUri"] == INIT_KWARGS.get("image_uri") + assert json_obj["DeploymentArgs"]["ModelData"] == INIT_KWARGS.get("model_data") + assert json_obj["DeploymentArgs"]["Environment"] == INIT_KWARGS.get("env") + assert json_obj["DeploymentArgs"]["InstanceType"] == INIT_KWARGS.get("instance_type") diff --git 
a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 941e2797ea..533483a497 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -13,6 +13,8 @@ from __future__ import absolute_import import os from unittest import TestCase + +from botocore.exceptions import ClientError from mock.mock import Mock, patch import pytest import boto3 @@ -24,6 +26,7 @@ ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING, ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE, ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE, + ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE, EXTRA_MODEL_ID_TAGS, EXTRA_MODEL_VERSION_TAGS, JUMPSTART_DEFAULT_REGION_NAME, @@ -31,6 +34,7 @@ JUMPSTART_LOGGER, JUMPSTART_REGION_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME, + NEO_DEFAULT_REGION_NAME, JumpStartScriptScope, ) from functools import partial @@ -49,8 +53,10 @@ get_spec_from_base_spec, get_special_model_spec, get_prototype_manifest, + get_base_deployment_configs_metadata, + get_base_deployment_configs, ) -from mock import MagicMock, call +from mock import MagicMock MOCK_CLIENT = MagicMock() @@ -60,79 +66,95 @@ def random_jumpstart_s3_uri(key): return f"s3://{random.choice(list(JUMPSTART_GATED_AND_PUBLIC_BUCKET_NAME_SET))}/{key}" -def test_get_jumpstart_content_bucket(): - bad_region = "bad_region" - assert bad_region not in JUMPSTART_REGION_NAME_SET - with pytest.raises(ValueError): - utils.get_jumpstart_content_bucket(bad_region) - - -def test_get_jumpstart_content_bucket_no_args(): - assert ( - utils.get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME) - == utils.get_jumpstart_content_bucket() - ) - +class TestBucketUtils(TestCase): + def test_get_jumpstart_content_bucket(self): + bad_region = "bad_region" + assert bad_region not in JUMPSTART_REGION_NAME_SET + with pytest.raises(ValueError): + utils.get_jumpstart_content_bucket(bad_region) -def test_get_jumpstart_content_bucket_override(): - with patch.dict(os.environ, 
{ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: "some-val"}): - with patch("logging.Logger.info") as mocked_info_log: - random_region = "random_region" - assert "some-val" == utils.get_jumpstart_content_bucket(random_region) - mocked_info_log.assert_called_with("Using JumpStart bucket override: 'some-val'") + def test_get_jumpstart_content_bucket_no_args(self): + assert ( + utils.get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME) + == utils.get_jumpstart_content_bucket() + ) + def test_get_jumpstart_content_bucket_override(self): + with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: "some-val"}): + with patch("logging.Logger.info") as mocked_info_log: + random_region = "random_region" + assert "some-val" == utils.get_jumpstart_content_bucket(random_region) + mocked_info_log.assert_called_with("Using JumpStart bucket override: 'some-val'") -def test_get_jumpstart_gated_content_bucket(): - bad_region = "bad_region" - assert bad_region not in JUMPSTART_REGION_NAME_SET - with pytest.raises(ValueError): - utils.get_jumpstart_gated_content_bucket(bad_region) + def test_get_jumpstart_gated_content_bucket(self): + bad_region = "bad_region" + assert bad_region not in JUMPSTART_REGION_NAME_SET + with pytest.raises(ValueError): + utils.get_jumpstart_gated_content_bucket(bad_region) + def test_get_jumpstart_gated_content_bucket_no_args(self): + assert ( + utils.get_jumpstart_gated_content_bucket(JUMPSTART_DEFAULT_REGION_NAME) + == utils.get_jumpstart_gated_content_bucket() + ) -def test_get_jumpstart_gated_content_bucket_no_args(): - assert ( - utils.get_jumpstart_gated_content_bucket(JUMPSTART_DEFAULT_REGION_NAME) - == utils.get_jumpstart_gated_content_bucket() - ) + def test_get_jumpstart_gated_content_bucket_override(self): + with patch.dict( + os.environ, {ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE: "some-val"} + ): + with patch("logging.Logger.info") as mocked_info_log: + random_region = "random_region" + assert "some-val" == 
utils.get_jumpstart_gated_content_bucket(random_region) + mocked_info_log.assert_called_once_with( + "Using JumpStart gated bucket override: 'some-val'" + ) + def test_get_jumpstart_launched_regions_message(self): -def test_get_jumpstart_gated_content_bucket_override(): - with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE: "some-val"}): - with patch("logging.Logger.info") as mocked_info_log: - random_region = "random_region" - assert "some-val" == utils.get_jumpstart_gated_content_bucket(random_region) - mocked_info_log.assert_called_once_with( - "Using JumpStart gated bucket override: 'some-val'" + with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {}): + assert ( + utils.get_jumpstart_launched_regions_message() + == "JumpStart is not available in any region." ) + with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"some_region"}): + assert ( + utils.get_jumpstart_launched_regions_message() + == "JumpStart is available in some_region region." + ) -def test_get_jumpstart_launched_regions_message(): + with patch( + "sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", + {"some_region1", "some_region2"}, + ): + assert ( + utils.get_jumpstart_launched_regions_message() + == "JumpStart is available in some_region1 and some_region2 regions." + ) - with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {}): - assert ( - utils.get_jumpstart_launched_regions_message() - == "JumpStart is not available in any region." - ) + with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"a", "b", "c"}): + assert ( + utils.get_jumpstart_launched_regions_message() + == "JumpStart is available in a, b, and c regions." + ) - with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"some_region"}): - assert ( - utils.get_jumpstart_launched_regions_message() - == "JumpStart is available in some_region region." 
- ) + def test_get_neo_content_bucket(self): + bad_region = "bad_region" + assert bad_region not in JUMPSTART_REGION_NAME_SET + with pytest.raises(ValueError): + utils.get_neo_content_bucket(bad_region) - with patch( - "sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"some_region1", "some_region2"} - ): + def test_get_neo_content_bucket_no_args(self): assert ( - utils.get_jumpstart_launched_regions_message() - == "JumpStart is available in some_region1 and some_region2 regions." + utils.get_neo_content_bucket(NEO_DEFAULT_REGION_NAME) == utils.get_neo_content_bucket() ) - with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"a", "b", "c"}): - assert ( - utils.get_jumpstart_launched_regions_message() - == "JumpStart is available in a, b, and c regions." - ) + def test_get_neo_content_bucket_override(self): + with patch.dict(os.environ, {ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE: "some-val"}): + with patch("logging.Logger.info") as mocked_info_log: + random_region = "random_region" + assert "some-val" == utils.get_neo_content_bucket(random_region) + mocked_info_log.assert_called_with("Using Neo bucket override: 'some-val'") def test_get_formatted_manifest(): @@ -207,16 +229,16 @@ def test_is_jumpstart_model_uri(): assert utils.is_jumpstart_model_uri(random_jumpstart_s3_uri("random_key")) -def test_add_jumpstart_model_id_version_tags(): +def test_add_jumpstart_model_info_tags(): tags = None model_id = "model_id" version = "version" + inference_config_name = "inference_config_name" + training_config_name = "training_config_name" assert [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": 
"model_id_2"}, @@ -228,9 +250,7 @@ def test_add_jumpstart_model_id_version_tags(): assert [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version_2"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "random key", "Value": "random_value"}, @@ -241,9 +261,7 @@ def test_add_jumpstart_model_id_version_tags(): {"Key": "random key", "Value": "random_value"}, {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, @@ -254,9 +272,7 @@ def test_add_jumpstart_model_id_version_tags(): assert [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "random key", "Value": "random_value"}, @@ -265,8 +281,58 @@ def test_add_jumpstart_model_id_version_tags(): version = None assert [ {"Key": "random key", "Value": "random_value"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) + + tags = [ + {"Key": "random key", "Value": "random_value"}, + ] + model_id = "model_id" + version = "version" + assert [ + {"Key": "random key", "Value": 
"random_value"}, + {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, + {"Key": "sagemaker-sdk:jumpstart-inference-config-name", "Value": "inference_config_name"}, + ] == utils.add_jumpstart_model_info_tags( + tags=tags, + model_id=model_id, + model_version=version, + config_name=inference_config_name, + scope=JumpStartScriptScope.INFERENCE, + ) + + tags = [ + {"Key": "random key", "Value": "random_value"}, + ] + model_id = "model_id" + version = "version" + assert [ + {"Key": "random key", "Value": "random_value"}, + {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, + {"Key": "sagemaker-sdk:jumpstart-training-config-name", "Value": "training_config_name"}, + ] == utils.add_jumpstart_model_info_tags( + tags=tags, + model_id=model_id, + model_version=version, + config_name=training_config_name, + scope=JumpStartScriptScope.TRAINING, + ) + + tags = [ + {"Key": "random key", "Value": "random_value"}, + ] + model_id = "model_id" + version = "version" + assert [ + {"Key": "random key", "Value": "random_value"}, + {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, + ] == utils.add_jumpstart_model_info_tags( + tags=tags, + model_id=model_id, + model_version=version, + config_name=training_config_name, ) @@ -1319,10 +1385,8 @@ def test_no_model_id_no_version_found(self): mock_list_tags.return_value = [{"Key": "blah", "Value": "blah1"}] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1336,10 +1400,8 @@ def test_model_id_no_version_found(self): ] 
self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - ("model_id", None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id", None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1353,10 +1415,66 @@ def test_no_model_id_version_found(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, "model_version"), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, "model_version", None, None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_no_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [{"Key": "blah", "Value": "blah1"}] + + self.assertEquals( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_inference_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name"}, + ] + + self.assertEquals( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, "config_name", None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_training_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "config_name"}, + ] + + self.assertEquals( + 
utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, "config_name"), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_both_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "inference_config_name"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "training_config_name"}, + ] + + self.assertEquals( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, "inference_config_name", "training_config_name"), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1371,10 +1489,8 @@ def test_model_id_version_found(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - ("model_id", "model_version"), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id", "model_version", None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1391,10 +1507,8 @@ def test_multiple_model_id_versions_found(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1411,10 +1525,8 @@ def test_multiple_model_id_versions_found_aliases_consistent(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - ("model_id_1", "model_version_1"), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id_1", "model_version_1", None, None), ) 
mock_list_tags.assert_called_once_with("some-arn") @@ -1431,10 +1543,26 @@ def test_multiple_model_id_versions_found_aliases_inconsistent(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_multiple_config_names_found_aliases_inconsistent(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.MODEL_ID, "Value": "model_id_1"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "model_version_1"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name_1"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name_2"}, + ] + + self.assertEquals( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id_1", "model_version_1", None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1529,6 +1657,8 @@ def test_get_jumpstart_config_names_success( "neuron-inference-budget", "gpu-inference-budget", "gpu-inference", + "gpu-inference-model-package", + "gpu-accelerated", ] @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1599,22 +1729,44 @@ def test_get_jumpstart_benchmark_stats_full_list( ) == { "neuron-inference": { "ml.inf2.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] }, "neuron-inference-budget": { "ml.inf2.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", 
"concurrency": 1} + ) ] }, "gpu-inference-budget": { "ml.p3.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] }, "gpu-inference": { "ml.p3.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) + ] + }, + "gpu-inference-model-package": { + "ml.p3.2xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) + ] + }, + "gpu-accelerated": { + "ml.p3.2xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] }, } @@ -1634,12 +1786,16 @@ def test_get_jumpstart_benchmark_stats_partial_list( ) == { "neuron-inference-budget": { "ml.inf2.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] }, "gpu-inference-budget": { "ml.p3.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] }, } @@ -1659,7 +1815,9 @@ def test_get_jumpstart_benchmark_stats_single_stat( ) == { "neuron-inference-budget": { "ml.inf2.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] } } @@ -1687,6 +1845,16 @@ def test_get_jumpstart_benchmark_stats_training( ): patched_get_model_specs.side_effect = get_base_spec_with_prototype_configs + print( + utils.get_benchmark_stats( + "mock-region", + "mock-model", + "mock-model-version", + 
scope=JumpStartScriptScope.TRAINING, + config_names=["neuron-training", "gpu-training-budget"], + ) + ) + assert utils.get_benchmark_stats( "mock-region", "mock-model", @@ -1696,97 +1864,199 @@ def test_get_jumpstart_benchmark_stats_training( ) == { "neuron-training": { "ml.tr1n1.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ], "ml.tr1n1.4xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "50", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1} + ) ], }, "gpu-training-budget": { "ml.p3.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": "1"} + ) ] }, } -class TestUserAgent: - @patch("sagemaker.jumpstart.utils.os.getenv") - def test_get_jumpstart_user_agent_extra_suffix(self, mock_getenv): - mock_getenv.return_value = False - assert utils.get_jumpstart_user_agent_extra_suffix("some-id", "some-version").endswith( - "md/js_model_id#some-id md/js_model_ver#some-version" - ) - mock_getenv.return_value = None - assert utils.get_jumpstart_user_agent_extra_suffix("some-id", "some-version").endswith( - "md/js_model_id#some-id md/js_model_ver#some-version" - ) - mock_getenv.return_value = "True" - assert not utils.get_jumpstart_user_agent_extra_suffix("some-id", "some-version").endswith( - "md/js_model_id#some-id md/js_model_ver#some-version" - ) - mock_getenv.return_value = True - assert not utils.get_jumpstart_user_agent_extra_suffix("some-id", "some-version").endswith( - "md/js_model_id#some-id md/js_model_ver#some-version" - ) +def test_extract_metrics_from_deployment_configs(): + configs = get_base_deployment_configs_metadata() + configs[0].benchmark_metrics = None + configs[2].deployment_args = 
None - @patch("sagemaker.jumpstart.utils.botocore.session") - @patch("sagemaker.jumpstart.utils.botocore.config.Config") - @patch("sagemaker.jumpstart.utils.get_jumpstart_user_agent_extra_suffix") - @patch("sagemaker.jumpstart.utils.boto3.Session") - @patch("sagemaker.jumpstart.utils.boto3.client") - @patch("sagemaker.jumpstart.utils.Session") - def test_get_default_jumpstart_session_with_user_agent_suffix( - self, - mock_sm_session, - mock_boto3_client, - mock_botocore_session, - mock_get_jumpstart_user_agent_extra_suffix, - mock_botocore_config, - mock_boto3_session, - ): - utils.get_default_jumpstart_session_with_user_agent_suffix("model_id", "model_version") - mock_boto3_session.get_session.assert_called_once_with() - mock_get_jumpstart_user_agent_extra_suffix.assert_called_once_with( - "model_id", "model_version" - ) - mock_botocore_config.assert_called_once_with( - user_agent_extra=mock_get_jumpstart_user_agent_extra_suffix.return_value - ) - mock_botocore_session.assert_called_once_with( - region_name=JUMPSTART_DEFAULT_REGION_NAME, - botocore_session=mock_boto3_session.get_session.return_value, - ) - mock_boto3_client.assert_has_calls( + data = utils.get_metrics_from_deployment_configs(configs) + + for key in data: + assert len(data[key]) == (len(configs) - 2) + + +@patch("sagemaker.jumpstart.utils.get_instance_rate_per_hour") +def test_add_instance_rate_stats_to_benchmark_metrics( + mock_get_instance_rate_per_hour, +): + mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "3.76", + } + + err, out = utils.add_instance_rate_stats_to_benchmark_metrics( + "us-west-2", + { + "ml.p2.xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) + ], + "ml.gd4.xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) + ], + }, + ) + + assert err is None + for key in out: + 
assert len(out[key]) == 2 + for metric in out[key]: + if metric.name == "Instance Rate": + assert metric.to_json() == { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "3.76", + "concurrency": None, + } + + +def test__normalize_benchmark_metrics(): + rate, metrics = utils._normalize_benchmark_metrics( + [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), + JumpStartBenchmarkStat( + {"name": "Throughput", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 2} + ), + JumpStartBenchmarkStat( + {"name": "Throughput", "value": "100", "unit": "Tokens/S", "concurrency": 2} + ), + JumpStartBenchmarkStat( + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "3.76", "concurrency": None} + ), + ] + ) + + assert rate == JumpStartBenchmarkStat( + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "3.76", "concurrency": None} + ) + assert metrics == { + 1: [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), + JumpStartBenchmarkStat( + {"name": "Throughput", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), + ], + 2: [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 2} + ), + JumpStartBenchmarkStat( + {"name": "Throughput", "value": "100", "unit": "Tokens/S", "concurrency": 2} + ), + ], + } + + +@pytest.mark.parametrize( + "name, unit, expected", + [ + ("latency", "sec", "Latency, TTFT (P50 in sec)"), + ("throughput", "tokens/sec", "Throughput (P50 in tokens/sec/user)"), + ], +) +def test_normalize_benchmark_metric_column_name(name, unit, expected): + out = utils._normalize_benchmark_metric_column_name(name, unit) + + assert out == expected + + +@patch("sagemaker.jumpstart.utils.get_instance_rate_per_hour") +def test_add_instance_rate_stats_to_benchmark_metrics_client_ex( + 
mock_get_instance_rate_per_hour, +): + mock_get_instance_rate_per_hour.side_effect = ClientError( + { + "Error": { + "Message": "is not authorized to perform: pricing:GetProducts", + "Code": "AccessDenied", + }, + }, + "GetProducts", + ) + + err, out = utils.add_instance_rate_stats_to_benchmark_metrics( + "us-west-2", + { + "ml.p2.xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) + ], + }, + ) + + assert err["Message"] == "is not authorized to perform: pricing:GetProducts" + assert err["Code"] == "AccessDenied" + for key in out: + assert len(out[key]) == 1 + + +@pytest.mark.parametrize( + "stats, expected", + [ + (None, True), + ( [ - call( - "sagemaker", - region_name=JUMPSTART_DEFAULT_REGION_NAME, - config=mock_botocore_config.return_value, - ), - call( - "sagemaker-runtime", - region_name=JUMPSTART_DEFAULT_REGION_NAME, - config=mock_botocore_config.return_value, - ), + JumpStartBenchmarkStat( + { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "3.76", + "concurrency": None, + } + ) ], - any_order=True, - ) + True, + ), + ( + [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": None} + ) + ], + False, + ), + ], +) +def test_has_instance_rate_stat(stats, expected): + assert utils.has_instance_rate_stat(stats) is expected - @patch("botocore.client.BaseClient._make_request") - def test_get_default_jumpstart_session_with_user_agent_suffix_http_header( - self, - mock_make_request, - ): - session = utils.get_default_jumpstart_session_with_user_agent_suffix( - "model_id", "model_version" - ) - try: - session.sagemaker_client.list_endpoints() - except Exception: - pass - assert ( - "md/js_model_id#model_id md/js_model_ver#model_version" - in mock_make_request.call_args[0][1]["headers"]["User-Agent"] - ) +@pytest.mark.parametrize( + "data, expected", + [(None, []), ([], []), (get_base_deployment_configs_metadata(), get_base_deployment_configs())], 
+) +def test_deployment_config_response_data(data, expected): + out = utils.deployment_config_response_data(data) + assert out == expected diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index e599d4eee1..de274f0374 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -12,9 +12,10 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import import copy -from typing import List, Optional +from typing import List, Dict, Any, Optional import boto3 +from sagemaker.compute_resource_requirements import ResourceRequirements from sagemaker.jumpstart.cache import JumpStartModelsCache from sagemaker.jumpstart.constants import ( JUMPSTART_DEFAULT_REGION_NAME, @@ -27,6 +28,10 @@ JumpStartModelSpecs, JumpStartS3FileType, JumpStartModelHeader, + JumpStartModelInitKwargs, + DeploymentConfigMetadata, + JumpStartModelDeployKwargs, + JumpStartBenchmarkStat, ) from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.utils import get_formatted_manifest @@ -43,6 +48,8 @@ SPECIAL_MODEL_SPECS_DICT, TRAINING_CONFIG_RANKINGS, TRAINING_CONFIGS, + DEPLOYMENT_CONFIGS, + INIT_KWARGS, ) @@ -233,6 +240,45 @@ def get_base_spec_with_prototype_configs( return JumpStartModelSpecs(spec) +def get_base_spec_with_prototype_configs_with_missing_benchmarks( + region: str = None, + model_id: str = None, + version: str = None, + s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, +) -> JumpStartModelSpecs: + spec = copy.deepcopy(BASE_SPEC) + copy_inference_configs = copy.deepcopy(INFERENCE_CONFIGS) + copy_inference_configs["inference_configs"]["neuron-inference"]["benchmark_metrics"] = None + + inference_configs = {**copy_inference_configs, **INFERENCE_CONFIG_RANKINGS} + training_configs = {**TRAINING_CONFIGS, **TRAINING_CONFIG_RANKINGS} + + spec.update(inference_configs) + 
spec.update(training_configs) + + return JumpStartModelSpecs(spec) + + +def get_prototype_spec_with_configs( + region: str = None, + model_id: str = None, + version: str = None, + s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + hub_arn: str = None, + sagemaker_session: boto3.Session = None, +) -> JumpStartModelSpecs: + spec = copy.deepcopy(PROTOTYPICAL_MODEL_SPECS_DICT[model_id]) + inference_configs = {**INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} + training_configs = {**TRAINING_CONFIGS, **TRAINING_CONFIG_RANKINGS} + + spec.update(inference_configs) + spec.update(training_configs) + + return JumpStartModelSpecs(spec) + + def patched_retrieval_function( _modelCacheObj: JumpStartModelsCache, key: JumpStartCachedContentKey, @@ -289,3 +335,104 @@ def overwrite_dictionary( base_dictionary[key] = value return base_dictionary + + +def get_base_deployment_configs_with_acceleration_configs() -> List[Dict[str, Any]]: + configs = copy.deepcopy(DEPLOYMENT_CONFIGS) + configs[0]["AccelerationConfigs"] = [ + {"Type": "Speculative-Decoding", "Enabled": True, "Spec": {"Version": "0.1"}} + ] + return configs + + +def get_mock_init_kwargs( + model_id: str, config_name: Optional[str] = None +) -> JumpStartModelInitKwargs: + kwargs = JumpStartModelInitKwargs( + model_id=model_id, + model_type=JumpStartModelType.OPEN_WEIGHTS, + model_data=INIT_KWARGS.get("model_data"), + image_uri=INIT_KWARGS.get("image_uri"), + instance_type=INIT_KWARGS.get("instance_type"), + env=INIT_KWARGS.get("env"), + resources=ResourceRequirements(), + config_name=config_name, + ) + setattr(kwargs, "model_reference_arn", None) + setattr(kwargs, "hub_content_type", None) + return kwargs + + +def get_base_deployment_configs_metadata( + omit_benchmark_metrics: bool = False, +) -> List[DeploymentConfigMetadata]: + specs = ( + get_base_spec_with_prototype_configs_with_missing_benchmarks() + if omit_benchmark_metrics + else 
get_base_spec_with_prototype_configs() + ) + configs = [] + for config_name in specs.inference_configs.config_rankings.get("overall").rankings: + jumpstart_config = specs.inference_configs.configs.get(config_name) + benchmark_metrics = jumpstart_config.benchmark_metrics + + if benchmark_metrics: + for instance_type in benchmark_metrics: + benchmark_metrics[instance_type].append( + JumpStartBenchmarkStat( + { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "3.76", + "concurrency": None, + } + ) + ) + + configs.append( + DeploymentConfigMetadata( + config_name=config_name, + metadata_config=jumpstart_config, + init_kwargs=get_mock_init_kwargs( + get_base_spec_with_prototype_configs().model_id, config_name + ), + deploy_kwargs=JumpStartModelDeployKwargs( + model_id=get_base_spec_with_prototype_configs().model_id, + ), + ) + ) + return configs + + +def get_base_deployment_configs( + omit_benchmark_metrics: bool = False, +) -> List[Dict[str, Any]]: + configs = [] + for config in get_base_deployment_configs_metadata(omit_benchmark_metrics): + config_json = config.to_json() + if config_json["BenchmarkMetrics"]: + config_json["BenchmarkMetrics"] = { + config.deployment_args.instance_type: config_json["BenchmarkMetrics"].get( + config.deployment_args.instance_type + ) + } + configs.append(config_json) + return configs + + +def append_instance_stat_metrics( + metrics: Dict[str, List[JumpStartBenchmarkStat]] +) -> Dict[str, List[JumpStartBenchmarkStat]]: + if metrics is not None: + for key in metrics: + metrics[key].append( + JumpStartBenchmarkStat( + { + "name": "Instance Rate", + "value": "3.76", + "unit": "USD/Hrs", + "concurrency": None, + } + ) + ) + return metrics diff --git a/tests/unit/sagemaker/serve/builder/test_djl_builder.py b/tests/unit/sagemaker/serve/builder/test_djl_builder.py index 7b0c67f326..474403498c 100644 --- a/tests/unit/sagemaker/serve/builder/test_djl_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_djl_builder.py @@ -188,6 +188,7 
@@ def test_tune_for_djl_local_container_deep_ping_ex( tuned_model = model.tune() assert tuned_model.env == mock_default_configs + @patch("sagemaker.serve.builder.djl_builder._get_model_config_properties_from_hf") @patch("sagemaker.serve.builder.djl_builder._capture_telemetry", side_effect=None) @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", @@ -211,7 +212,10 @@ def test_tune_for_djl_local_container_load_ex( mock_get_ram_usage_mb, mock_is_jumpstart_model, mock_telemetry, + mock_get_model_config_properties_from_hf, ): + mock_get_model_config_properties_from_hf.return_value = {} + builder = ModelBuilder( model=mock_model_id, schema_builder=mock_schema_builder, diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index 2065e86818..248955c273 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -11,10 +11,12 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
from __future__ import absolute_import -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, Mock import unittest +from sagemaker.enums import Tag +from sagemaker.serve import SchemaBuilder from sagemaker.serve.builder.model_builder import ModelBuilder from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.utils.exceptions import ( @@ -23,6 +25,7 @@ LocalModelOutOfMemoryException, LocalModelInvocationException, ) +from tests.unit.sagemaker.serve.constants import DEPLOYMENT_CONFIGS mock_model_id = "huggingface-llm-amazon-falconlite" mock_t5_model_id = "google/flan-t5-xxl" @@ -84,6 +87,74 @@ "/artifacts/inference-prepack/v1.0.0/" ) +mock_optimization_job_response = { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:312206380606:optimization-job" + "/modelbuilderjob-c9b28846f963497ca540010b2aa2ec8d", + "OptimizationJobStatus": "COMPLETED", + "OptimizationStartTime": "", + "OptimizationEndTime": "", + "CreationTime": "", + "LastModifiedTime": "", + "OptimizationJobName": "modelbuilderjob-c9b28846f963497ca540010b2aa2ec8d", + "ModelSource": { + "S3": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/" + "meta-textgeneration-llama-3-8b-instruct/artifacts/inference-prepack/v1.1.0/" + } + }, + "OptimizationEnvironment": { + "ENDPOINT_SERVER_TIMEOUT": "3600", + "HF_MODEL_ID": "/opt/ml/model", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "OPTION_DTYPE": "fp16", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + "OPTION_N_POSITIONS": "2048", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + "SAGEMAKER_ENV": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "SAGEMAKER_PROGRAM": "inference.py", + }, + "DeploymentInstanceType": "ml.inf2.48xlarge", + "OptimizationConfigs": [ + { + "ModelCompilationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-neuronx-sdk2.18.2", + "OverrideEnvironment": { + 
"OPTION_DTYPE": "fp16", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + "OPTION_N_POSITIONS": "2048", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + }, + } + } + ], + "OutputConfig": { + "S3OutputLocation": "s3://dont-delete-ss-jarvis-integ-test-312206380606-us-west-2/" + "code/a75a061aba764f2aa014042bcdc1464b/" + }, + "OptimizationOutput": { + "RecommendedInferenceImage": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "djl-inference:0.28.0-neuronx-sdk2.18.2" + }, + "RoleArn": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20230707T131628", + "StoppingCondition": {"MaxRuntimeInSeconds": 36000}, + "ResponseMetadata": { + "RequestId": "704c7bcd-41e2-4d73-8039-262ff6a3f38b", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "704c7bcd-41e2-4d73-8039-262ff6a3f38b", + "content-type": "application/x-amz-json-1.1", + "content-length": "1787", + "date": "Thu, 04 Jul 2024 16:55:50 GMT", + }, + "RetryAttempts": 0, + }, +} + class TestJumpStartBuilder(unittest.TestCase): @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @@ -303,7 +374,7 @@ def test_tune_for_tgi_js_local_container_sharding_not_supported( ) @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalDeepPingException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalDeepPingException("mock_exception")}, ) def test_tune_for_tgi_js_local_container_deep_ping_ex( self, @@ -353,7 +424,7 @@ def test_tune_for_tgi_js_local_container_deep_ping_ex( ) @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalModelLoadException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalModelLoadException("mock_exception")}, ) def test_tune_for_tgi_js_local_container_load_ex( self, @@ -403,7 +474,7 @@ def 
test_tune_for_tgi_js_local_container_load_ex( ) @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalModelOutOfMemoryException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalModelOutOfMemoryException("mock_exception")}, ) def test_tune_for_tgi_js_local_container_oom_ex( self, @@ -453,7 +524,7 @@ def test_tune_for_tgi_js_local_container_oom_ex( ) @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalModelInvocationException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalModelInvocationException("mock_exception")}, ) def test_tune_for_tgi_js_local_container_invoke_ex( self, @@ -568,7 +639,7 @@ def test_tune_for_djl_js_local_container( ) @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalModelInvocationException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalModelInvocationException("mock_exception")}, ) def test_tune_for_djl_js_local_container_invoke_ex( self, @@ -724,3 +795,591 @@ def test_js_gated_model_ex( ValueError, lambda: builder.build(), ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_list_deployment_configs( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + 
mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.side_effect = ( + lambda: DEPLOYMENT_CONFIGS + ) + + configs = builder.list_deployment_configs() + + self.assertEqual(configs, DEPLOYMENT_CONFIGS) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_get_deployment_config( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + expected = DEPLOYMENT_CONFIGS[0] + mock_pre_trained_model.return_value.deployment_config = expected + + self.assertEqual(builder.get_deployment_config(), expected) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + 
"sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_set_deployment_config( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + builder.build() + builder.set_deployment_config("config-1", "ml.g5.24xlarge") + + mock_pre_trained_model.return_value.set_deployment_config.assert_called_with( + "config-1", "ml.g5.24xlarge" + ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_set_deployment_config_ex( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + 
self.assertRaisesRegex( + Exception, + "Cannot set deployment config to an uninitialized model.", + lambda: ModelBuilder( + model="facebook/galactica-mock-model-id", schema_builder=mock_schema_builder + ).set_deployment_config("config-2", "ml.g5.24xlarge"), + ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_display_benchmark_metrics( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.side_effect = ( + lambda: DEPLOYMENT_CONFIGS + ) + + builder.list_deployment_configs() + + builder.display_benchmark_metrics() + + mock_pre_trained_model.return_value.display_benchmark_metrics.assert_called_once() + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + 
"sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_display_benchmark_metrics_initial( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.side_effect = ( + lambda: DEPLOYMENT_CONFIGS + ) + + builder.display_benchmark_metrics() + + mock_pre_trained_model.return_value.display_benchmark_metrics.assert_called_once() + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + def test_fine_tuned_model_with_fine_tuning_model_path( + self, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + mock_pre_trained_model.return_value.image_uri = mock_djl_image_uri + mock_fine_tuning_model_path = "s3://test" + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species of turtle native to the brackish " + "coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or 
simply terrapin is a species of turtle native to the " + "brackish coastal tidal marshes of the east coast." + } + ] + builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + model_metadata={ + "FINE_TUNING_MODEL_PATH": mock_fine_tuning_model_path, + }, + ) + model = builder.build() + + model.model_data["S3DataSource"].__setitem__.assert_called_with( + "S3Uri", mock_fine_tuning_model_path + ) + mock_pre_trained_model.return_value.add_tags.assert_called_with( + {"Key": Tag.FINE_TUNING_MODEL_PATH, "Value": mock_fine_tuning_model_path} + ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + def test_fine_tuned_model_with_fine_tuning_job_name( + self, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_serve_settings, + mock_telemetry, + ): + mock_fine_tuning_model_path = "s3://test" + mock_sagemaker_session = Mock() + mock_sagemaker_session.sagemaker_client.describe_training_job.return_value = { + "ModelArtifacts": { + "S3ModelArtifacts": mock_fine_tuning_model_path, + } + } + mock_pre_trained_model.return_value.image_uri = mock_djl_image_uri + mock_fine_tuning_job_name = "mock-job" + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species of turtle native to the brackish " + "coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is 
a species of turtle native to the " + "brackish coastal tidal marshes of the east coast." + } + ] + builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + model_metadata={"FINE_TUNING_JOB_NAME": mock_fine_tuning_job_name}, + sagemaker_session=mock_sagemaker_session, + ) + model = builder.build(sagemaker_session=mock_sagemaker_session) + + mock_sagemaker_session.sagemaker_client.describe_training_job.assert_called_once_with( + TrainingJobName=mock_fine_tuning_job_name + ) + + model.model_data["S3DataSource"].__setitem__.assert_any_call( + "S3Uri", mock_fine_tuning_model_path + ) + mock_pre_trained_model.return_value.add_tags.assert_called_with( + [ + {"key": Tag.FINE_TUNING_JOB_NAME, "value": mock_fine_tuning_job_name}, + {"key": Tag.FINE_TUNING_MODEL_PATH, "value": mock_fine_tuning_model_path}, + ] + ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_quantize_for_jumpstart( + self, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = Mock() + + mock_pysdk_model = Mock() + mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"} + mock_pysdk_model.model_data = mock_model_data + mock_pysdk_model.image_uri = mock_tgi_image_uri + mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS + mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0] + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish coastal " + "tidal marshes of the east coast." 
+ } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + model_builder.pysdk_model = mock_pysdk_model + + out_put = model_builder._optimize_for_jumpstart( + accept_eula=True, + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, + }, + env_vars={ + "OPTION_TENSOR_PARALLEL_DEGREE": "1", + "OPTION_MAX_ROLLING_BATCH_SIZE": "2", + }, + output_path="s3://bucket/code/", + ) + + self.assertIsNotNone(out_put) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model") + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + def test_optimize_compile_for_jumpstart_without_neuron_env( + self, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_is_gated_model, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = Mock() + mock_sagemaker_session.wait_for_optimization_job.side_effect = ( + lambda *args: mock_optimization_job_response + ) + + mock_pre_trained_model.return_value = MagicMock() + mock_pre_trained_model.return_value.env = dict() + mock_pre_trained_model.return_value.model_data = mock_model_data + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.return_value = ( + DEPLOYMENT_CONFIGS + ) + mock_pre_trained_model.return_value.deployment_config = DEPLOYMENT_CONFIGS[0] + 
mock_pre_trained_model.return_value._metadata_configs = None + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish coastal " + "tidal marshes of the east coast." + } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + optimized_model = model_builder.optimize( + accept_eula=True, + instance_type="ml.inf2.48xlarge", + compilation_config={ + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + "OPTION_N_POSITIONS": "2048", + "OPTION_DTYPE": "fp16", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + } + }, + output_path="s3://bucket/code/", + ) + + self.assertEqual( + optimized_model.image_uri, + mock_optimization_job_response["OptimizationOutput"]["RecommendedInferenceImage"], + ) + self.assertEqual( + optimized_model.model_data["S3DataSource"]["S3Uri"], + mock_optimization_job_response["OutputConfig"]["S3OutputLocation"], + ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel") + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model") + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + 
return_value=({"model_type": "t5", "n_head": 71}, True), + ) + def test_optimize_compile_for_jumpstart_with_neuron_env( + self, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_js_model, + mock_is_gated_model, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = Mock() + mock_metadata_config = Mock() + mock_sagemaker_session.wait_for_optimization_job.side_effect = ( + lambda *args: mock_optimization_job_response + ) + + mock_metadata_config.resolved_config = { + "supported_inference_instance_types": ["ml.inf2.48xlarge"], + "hosting_neuron_model_id": "neuron_model_id", + } + + mock_js_model.return_value = MagicMock() + mock_js_model.return_value.env = dict() + + mock_pre_trained_model.return_value = MagicMock() + mock_pre_trained_model.return_value.env = dict() + mock_pre_trained_model.return_value.config_name = "config_name" + mock_pre_trained_model.return_value.model_data = mock_model_data + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.return_value = ( + DEPLOYMENT_CONFIGS + ) + mock_pre_trained_model.return_value.deployment_config = DEPLOYMENT_CONFIGS[0] + mock_pre_trained_model.return_value._metadata_configs = { + "config_name": mock_metadata_config + } + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish coastal " + "tidal marshes of the east coast." 
+ } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + optimized_model = model_builder.optimize( + accept_eula=True, + instance_type="ml.inf2.48xlarge", + compilation_config={ + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + "OPTION_N_POSITIONS": "2048", + "OPTION_DTYPE": "fp16", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + } + }, + output_path="s3://bucket/code/", + ) + + self.assertEqual( + optimized_model.image_uri, + mock_optimization_job_response["OptimizationOutput"]["RecommendedInferenceImage"], + ) + self.assertEqual( + optimized_model.model_data["S3DataSource"]["S3Uri"], + mock_optimization_job_response["OutputConfig"]["S3OutputLocation"], + ) + self.assertEqual(optimized_model.env["OPTION_TENSOR_PARALLEL_DEGREE"], "2") + self.assertEqual(optimized_model.env["OPTION_N_POSITIONS"], "2048") + self.assertEqual(optimized_model.env["OPTION_DTYPE"], "fp16") + self.assertEqual(optimized_model.env["OPTION_ROLLING_BATCH"], "auto") + self.assertEqual(optimized_model.env["OPTION_MAX_ROLLING_BATCH_SIZE"], "4") + self.assertEqual(optimized_model.env["OPTION_NEURON_OPTIMIZE_LEVEL"], "2") diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index d7e2c0aac2..81d57243ea 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -13,6 +13,8 @@ from __future__ import absolute_import from unittest.mock import MagicMock, patch, Mock, mock_open +import pytest + import unittest from pathlib import Path from copy import deepcopy @@ -45,7 +47,7 @@ mock_image_uri = "abcd/efghijk" mock_1p_dlc_image_uri = "763104351884.dkr.ecr.us-east-1.amazonaws.com" -mock_role_arn = "sample role arn" +mock_role_arn = 
"arn:aws:iam::123456789012:role/SageMakerRole" mock_s3_model_data_url = "sample s3 data url" mock_secret_key = "mock_secret_key" mock_instance_type = "mock instance type" @@ -147,7 +149,7 @@ def test_model_server_override_djl_without_model_or_mlflow(self, mock_serve_sett ) self.assertRaisesRegex( Exception, - "Missing required parameter `model` or 'ml_flow' path or inf_spec", + "Missing required parameter `model` or 'ml_flow' path", builder.build, Mode.SAGEMAKER_ENDPOINT, mock_role_arn, @@ -168,26 +170,12 @@ def test_model_server_override_torchserve_with_model( mock_build_for_ts.assert_called_once() - @patch("sagemaker.serve.builder.model_builder._ServeSettings") - @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_torchserve") - def test_model_server_override_torchserve_with_inf_spec( - self, mock_build_for_ts, mock_serve_settings - ): - mock_setting_object = mock_serve_settings.return_value - mock_setting_object.role_arn = mock_role_arn - mock_setting_object.s3_model_data_url = mock_s3_model_data_url - - builder = ModelBuilder(model_server=ModelServer.TORCHSERVE, inference_spec="some value") - builder.build(sagemaker_session=mock_session) - - mock_build_for_ts.assert_called_once() - @patch("sagemaker.serve.builder.model_builder._ServeSettings") def test_model_server_override_torchserve_without_model_or_mlflow(self, mock_serve_settings): builder = ModelBuilder(model_server=ModelServer.TORCHSERVE) self.assertRaisesRegex( Exception, - "Missing required parameter `model` or 'ml_flow' path or inf_spec", + "Missing required parameter `model` or 'ml_flow' path", builder.build, Mode.SAGEMAKER_ENDPOINT, mock_role_arn, @@ -257,6 +245,10 @@ def test_model_server_override_transformers_with_model( mock_build_for_ts.assert_called_once() @patch("os.makedirs", Mock()) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=False, + ) @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") 
@patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") @patch("sagemaker.serve.builder.model_builder.save_pkl") @@ -275,6 +267,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc( mock_save_pkl, mock_prepare_for_torchserve, mock_detect_fw_version, + mock_is_jumpstart_model_id, ): # setup mocks mock_detect_container.side_effect = lambda model, region, instance_type: ( @@ -311,7 +304,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc( mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -361,6 +354,10 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc( self.assertEqual(build_result.serve_settings, mock_setting_object) @patch("os.makedirs", Mock()) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=False, + ) @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") @patch("sagemaker.serve.builder.model_builder.save_pkl") @@ -379,6 +376,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc( mock_save_pkl, mock_prepare_for_torchserve, mock_detect_fw_version, + mock_is_jumpstart_model_id, ): # setup mocks mock_detect_container.side_effect = lambda model, region, instance_type: ( @@ -414,7 +412,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc( mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( mock_mode if inference_spec is None and model_server == 
ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -521,7 +519,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec( if inference_spec == mock_inference_spec and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -563,6 +561,10 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec( self.assertEqual(build_result.serve_settings, mock_setting_object) @patch("os.makedirs", Mock()) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=False, + ) @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") @patch("sagemaker.serve.builder.model_builder.save_pkl") @@ -581,6 +583,7 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model( mock_save_pkl, mock_prepare_for_torchserve, mock_detect_fw_version, + mock_is_jumpstart_model_id, ): # setup mocks mock_detect_container.side_effect = lambda model, region, instance_type: ( @@ -617,7 +620,7 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model( mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, 
sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -665,6 +668,10 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model( self.assertEqual("sample agent ModelBuilder", user_agent) @patch("os.makedirs", Mock()) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=False, + ) @patch("sagemaker.serve.builder.model_builder.save_xgboost") @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") @@ -685,6 +692,7 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model( mock_prepare_for_torchserve, mock_detect_fw_version, mock_save_xgb, + mock_is_jumpstart_model_id, ): # setup mocks mock_detect_container.side_effect = lambda model, region, instance_type: ( @@ -721,7 +729,7 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model( mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -947,7 +955,7 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo if inference_spec == mock_inference_spec and model_server == ModelServer.TORCHSERVE else None ) - mock_sagemaker_endpoint_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_sagemaker_endpoint_mode.prepare.side_effect 
= lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -1013,6 +1021,10 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo ) @patch("os.makedirs", Mock()) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=False, + ) @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") @patch("sagemaker.serve.builder.model_builder.save_pkl") @@ -1033,6 +1045,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co mock_save_pkl, mock_prepare_for_torchserve, mock_detect_fw_version, + mock_is_jumpstart_model_id, ): # setup mocks mock_detect_fw_version.return_value = framework, version @@ -1069,7 +1082,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -2233,6 +2246,10 @@ def test_build_mlflow_model_s3_input_tensorflow_serving_local_mode_happy( assert isinstance(predictor, TensorflowServingLocalPredictor) @patch("os.makedirs", Mock()) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=False, + ) @patch("sagemaker.serve.builder.tf_serving_builder.prepare_for_tf_serving") @patch("sagemaker.serve.builder.model_builder.S3Downloader.list") @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @@ 
-2261,6 +2278,7 @@ def test_build_tensorflow_serving_non_mlflow_case( mock_detect_fw_version, mock_s3_downloader, mock_prepare_for_tf_serving, + mock_is_jumpstart_model_id, ): mock_s3_downloader.return_value = [] mock_detect_container.return_value = mock_image_uri @@ -2310,6 +2328,88 @@ def test_build_tensorflow_serving_non_mlflow_case( mock_session, ) + @pytest.mark.skip(reason="Implementation not completed") + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_optimize(self, mock_send_telemetry, mock_get_serve_setting): + mock_sagemaker_session = Mock() + + mock_settings = Mock() + mock_settings.telemetry_opt_out = False + mock_get_serve_setting.return_value = mock_settings + + builder = ModelBuilder( + model_path=MODEL_PATH, + schema_builder=schema_builder, + model=mock_fw_model, + sagemaker_session=mock_sagemaker_session, + ) + + job_name = "my-optimization-job" + instance_type = "ml.inf1.xlarge" + output_path = "s3://my-bucket/output" + quantization_config = { + "Image": "quantization-image-uri", + "OverrideEnvironment": {"ENV_VAR": "value"}, + } + compilation_config = { + "Image": "compilation-image-uri", + "OverrideEnvironment": {"ENV_VAR": "value"}, + } + env_vars = {"Var1": "value", "Var2": "value"} + kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id" + max_runtime_in_sec = 3600 + tags = [ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ] + vpc_config = { + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + } + + expected_create_optimization_job_args = { + "ModelSource": {"S3": {"S3Uri": MODEL_PATH, "ModelAccessConfig": {"AcceptEula": True}}}, + "DeploymentInstanceType": instance_type, + "OptimizationEnvironment": env_vars, + "OptimizationConfigs": [ + {"ModelQuantizationConfig": quantization_config}, + {"ModelCompilationConfig": 
compilation_config}, + ], + "OutputConfig": {"S3OutputLocation": output_path, "KmsKeyId": kms_key}, + "RoleArn": mock_role_arn, + "OptimizationJobName": job_name, + "StoppingCondition": {"MaxRuntimeInSeconds": max_runtime_in_sec}, + "Tags": [ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ], + "VpcConfig": vpc_config, + } + + mock_sagemaker_session.sagemaker_client.create_optimization_job.return_value = { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job" + } + + builder.optimize( + instance_type=instance_type, + output_path=output_path, + role=mock_role_arn, + job_name=job_name, + quantization_config=quantization_config, + compilation_config=compilation_config, + env_vars=env_vars, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + tags=tags, + vpc_config=vpc_config, + ) + + mock_send_telemetry.assert_called_once() + mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( + **expected_create_optimization_job_args + ) + def test_handle_mlflow_input_without_mlflow_model_path(self): builder = ModelBuilder(model_metadata={}) assert not builder._has_mlflow_arguments() @@ -2493,3 +2593,155 @@ def test_set_tracking_arn_mlflow_not_installed(self): builder.set_tracking_arn, tracking_arn, ) + + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_local_mode(self, mock_get_serve_setting): + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", mode=Mode.LOCAL_CONTAINER + ) + + self.assertRaisesRegex( + ValueError, + "Model optimization is only supported in Sagemaker Endpoint Mode.", + lambda: model_builder.optimize( + quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}} + ), + ) + + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_exclusive_args(self, mock_get_serve_setting): + mock_sagemaker_session = Mock() + model_builder = 
ModelBuilder( + model="meta-textgeneration-llama-3-70b", + sagemaker_session=mock_sagemaker_session, + ) + + self.assertRaisesRegex( + ValueError, + "Quantization config and compilation config are mutually exclusive.", + lambda: model_builder.optimize( + quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, + compilation_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, + ), + ) + + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_for_hf_with_custom_s3_path( + self, + mock_get_serve_setting, + mock_prepare_for_mode, + ): + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/code/code/", + } + }, + {"DTYPE": "bfloat16"}, + ) + + mock_pysdk_model = Mock() + mock_pysdk_model.model_data = None + mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-8B-Instruc"} + + model_builder = ModelBuilder( + model="meta-llama/Meta-Llama-3-8B-Instruct", + env_vars={"HUGGING_FACE_HUB_TOKEN": "token"}, + model_metadata={ + "CUSTOM_MODEL_PATH": "s3://bucket/path/", + }, + ) + + model_builder.pysdk_model = mock_pysdk_model + + out_put = model_builder._optimize_for_hf( + job_name="job_name-123", + instance_type="ml.g5.2xlarge", + role_arn="role-arn", + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, + }, + output_path="s3://bucket/code/", + ) + + print(out_put) + + self.assertEqual(model_builder.role_arn, "role-arn") + self.assertEqual(model_builder.instance_type, "ml.g5.2xlarge") + self.assertEqual(model_builder.pysdk_model.env["OPTION_QUANTIZE"], "awq") + self.assertEqual( + out_put, + { + "OptimizationJobName": "job_name-123", + "DeploymentInstanceType": "ml.g5.2xlarge", + "RoleArn": "role-arn", + "ModelSource": {"S3": {"S3Uri": "s3://bucket/code/code/"}}, + "OptimizationConfigs": [ + {"ModelQuantizationConfig": 
{"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}} + ], + "OutputConfig": {"S3OutputLocation": "s3://bucket/code/"}, + }, + ) + + @patch( + "sagemaker.serve.builder.model_builder.download_huggingface_model_metadata", autospec=True + ) + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_for_hf_without_custom_s3_path( + self, + mock_get_serve_setting, + mock_prepare_for_mode, + mock_download_huggingface_model_metadata, + ): + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/code/code/", + } + }, + {"DTYPE": "bfloat16"}, + ) + + mock_pysdk_model = Mock() + mock_pysdk_model.model_data = None + mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-8B-Instruc"} + + model_builder = ModelBuilder( + model="meta-llama/Meta-Llama-3-8B-Instruct", + env_vars={"HUGGING_FACE_HUB_TOKEN": "token"}, + ) + + model_builder.pysdk_model = mock_pysdk_model + + out_put = model_builder._optimize_for_hf( + job_name="job_name-123", + instance_type="ml.g5.2xlarge", + role_arn="role-arn", + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, + }, + output_path="s3://bucket/code/", + ) + + self.assertEqual(model_builder.role_arn, "role-arn") + self.assertEqual(model_builder.instance_type, "ml.g5.2xlarge") + self.assertEqual(model_builder.pysdk_model.env["OPTION_QUANTIZE"], "awq") + self.assertEqual( + out_put, + { + "OptimizationJobName": "job_name-123", + "DeploymentInstanceType": "ml.g5.2xlarge", + "RoleArn": "role-arn", + "ModelSource": {"S3": {"S3Uri": "s3://bucket/code/code/"}}, + "OptimizationConfigs": [ + {"ModelQuantizationConfig": {"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}} + ], + "OutputConfig": {"S3OutputLocation": "s3://bucket/code/"}, + }, + ) diff --git a/tests/unit/sagemaker/serve/constants.py b/tests/unit/sagemaker/serve/constants.py 
index db9dd623d8..5a4679747b 100644 --- a/tests/unit/sagemaker/serve/constants.py +++ b/tests/unit/sagemaker/serve/constants.py @@ -15,3 +15,153 @@ MOCK_IMAGE_CONFIG = {"RepositoryAccessMode": "Vpc"} MOCK_VPC_CONFIG = {"Subnets": ["subnet-1234"], "SecurityGroupIds": ["sg123"]} +DEPLOYMENT_CONFIGS = [ + { + "ConfigName": "neuron-inference", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentArgs": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, + { + "ConfigName": "neuron-inference-budget", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentArgs": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": 
"s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, + { + "ConfigName": "gpu-inference-budget", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentArgs": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, + { + "ConfigName": "gpu-inference", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": 
"Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentArgs": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, +] diff --git a/tests/unit/sagemaker/serve/model_server/tei/test_server.py b/tests/unit/sagemaker/serve/model_server/tei/test_server.py index 2344a61fbc..cc1226702f 100644 --- a/tests/unit/sagemaker/serve/model_server/tei/test_server.py +++ b/tests/unit/sagemaker/serve/model_server/tei/test_server.py @@ -135,6 +135,7 @@ def test_upload_artifacts_sagemaker_tei_server(self, mock_uploader): sagemaker_session=mock_session, s3_model_data_url=S3_URI, image=TEI_IMAGE, + should_upload_artifacts=True, ) mock_uploader.upload.assert_called_once() diff --git a/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_server.py b/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_server.py index 3d3bac0935..b9cce13dbb 100644 --- a/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_server.py +++ b/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_server.py @@ -92,6 +92,7 @@ def 
test_upload_artifacts_sagemaker_triton_server(self, mock_upload, mock_platfo secret_key=SECRET_KEY, s3_model_data_url=S3_URI, image=CPU_TF_IMAGE, + should_upload_artifacts=True, ) mock_upload.assert_called_once_with(mock_session, MODEL_PATH, "mock_model_data_uri", ANY) diff --git a/tests/unit/sagemaker/serve/model_server/triton/test_server.py b/tests/unit/sagemaker/serve/model_server/triton/test_server.py index c80c4296e7..3f571424ed 100644 --- a/tests/unit/sagemaker/serve/model_server/triton/test_server.py +++ b/tests/unit/sagemaker/serve/model_server/triton/test_server.py @@ -172,6 +172,7 @@ def test_upload_artifacts_sagemaker_triton_server(self, mock_upload, mock_platfo secret_key=SECRET_KEY, s3_model_data_url=S3_URI, image=GPU_TRITON_IMAGE, + should_upload_artifacts=True, ) mock_upload.assert_called_once_with(mock_session, MODEL_REPO, "mock_model_data_uri", ANY) diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py new file mode 100644 index 0000000000..712382f068 --- /dev/null +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -0,0 +1,403 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import unittest +from unittest.mock import Mock, patch + +import pytest + +from sagemaker.enums import Tag +from sagemaker.serve.utils.optimize_utils import ( + _generate_optimized_model, + _update_environment_variables, + _is_image_compatible_with_optimization_job, + _extract_speculative_draft_model_provider, + _extracts_and_validates_speculative_model_source, + _is_s3_uri, + _generate_additional_model_data_sources, + _generate_channel_name, + _extract_optimization_config_and_env, + _normalize_local_model_path, + _is_optimized, + _custom_speculative_decoding, + _is_inferentia_or_trainium, +) + +mock_optimization_job_output = { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:312206380606:optimization-job/" + "modelbuilderjob-3cbf9c40b63c455d85b60033f9a01691", + "OptimizationJobStatus": "COMPLETED", + "OptimizationJobName": "modelbuilderjob-3cbf9c40b63c455d85b60033f9a01691", + "ModelSource": { + "S3": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/" + "meta-textgeneration-llama-3-8b/artifacts/inference-prepack/v1.0.1/" + } + }, + "OptimizationEnvironment": { + "ENDPOINT_SERVER_TIMEOUT": "3600", + "HF_MODEL_ID": "/opt/ml/model", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "SAGEMAKER_PROGRAM": "inference.py", + }, + "DeploymentInstanceType": "ml.g5.2xlarge", + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124", + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, + } + } + ], + "OutputConfig": {"S3OutputLocation": "s3://quicksilver-model-data/llama-3-8b/quantized-1/"}, + "OptimizationOutput": { + "RecommendedInferenceImage": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124" + }, + "RoleArn": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20240116T151132", + 
"StoppingCondition": {"MaxRuntimeInSeconds": 36000}, + "ResponseMetadata": { + "RequestId": "a95253d5-c045-4708-8aac-9f0d327515f7", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "a95253d5-c045-4708-8aac-9f0d327515f7", + "content-type": "application/x-amz-json-1.1", + "content-length": "1371", + "date": "Fri, 21 Jun 2024 04:27:42 GMT", + }, + "RetryAttempts": 0, + }, +} + + +@pytest.mark.parametrize( + "instance, expected", + [ + ("ml.trn1.2xlarge", True), + ("ml.inf2.xlarge", True), + ("ml.c7gd.4xlarge", False), + ], +) +def test_is_inferentia_or_trainium(instance, expected): + assert _is_inferentia_or_trainium(instance) == expected + + +@pytest.mark.parametrize( + "image_uri, expected", + [ + ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124", + True, + ), + ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-neuronx-sdk2.18.2", + True, + ), + ( + None, + True, + ), + ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:" + "2.1.1-tgi2.0.0-gpu-py310-cu121-ubuntu22.04", + False, + ), + (None, True), + ], +) +def test_is_image_compatible_with_optimization_job(image_uri, expected): + assert _is_image_compatible_with_optimization_job(image_uri) == expected + + +def test_generate_optimized_model(): + pysdk_model = Mock() + pysdk_model.model_data = { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/" + "meta-textgeneration-llama-3-8b/artifacts/inference-prepack/v1.0.1/" + } + } + + optimized_model = _generate_optimized_model(pysdk_model, mock_optimization_job_output) + + assert ( + optimized_model.image_uri + == mock_optimization_job_output["OptimizationOutput"]["RecommendedInferenceImage"] + ) + assert ( + optimized_model.model_data["S3DataSource"]["S3Uri"] + == mock_optimization_job_output["OutputConfig"]["S3OutputLocation"] + ) + assert optimized_model.instance_type == 
mock_optimization_job_output["DeploymentInstanceType"] + pysdk_model.add_tags.assert_called_once_with( + { + "Key": Tag.OPTIMIZATION_JOB_NAME, + "Value": mock_optimization_job_output["OptimizationJobName"], + } + ) + + +def test_is_optimized(): + model = Mock() + + model._tags = {"Key": Tag.OPTIMIZATION_JOB_NAME} + assert _is_optimized(model) is True + + model._tags = [{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER}] + assert _is_optimized(model) is True + + model._tags = [{"Key": Tag.FINE_TUNING_MODEL_PATH}] + assert _is_optimized(model) is False + + +@pytest.mark.parametrize( + "env, new_env, output_env", + [ + ({"a": "1"}, {"b": "2"}, {"a": "1", "b": "2"}), + (None, {"b": "2"}, {"b": "2"}), + ({"a": "1"}, None, {"a": "1"}), + (None, None, None), + ], +) +def test_update_environment_variables(env, new_env, output_env): + assert _update_environment_variables(env, new_env) == output_env + + +@pytest.mark.parametrize( + "speculative_decoding_config, expected_model_provider", + [ + ({"ModelProvider": "SageMaker"}, "sagemaker"), + ({"ModelProvider": "Custom"}, "custom"), + ({"ModelSource": "s3://"}, "custom"), + (None, None), + ], +) +def test_extract_speculative_draft_model_provider( + speculative_decoding_config, expected_model_provider +): + assert ( + _extract_speculative_draft_model_provider(speculative_decoding_config) + == expected_model_provider + ) + + +def test_extract_speculative_draft_model_s3_uri(): + res = _extracts_and_validates_speculative_model_source({"ModelSource": "s3://"}) + assert res == "s3://" + + +def test_extract_speculative_draft_model_s3_uri_ex(): + with pytest.raises(ValueError): + _extracts_and_validates_speculative_model_source({"ModelSource": None}) + + +def test_generate_channel_name(): + assert _generate_channel_name(None) is not None + + additional_model_data_sources = _generate_additional_model_data_sources( + "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", "channel_name", True + ) + + assert 
_generate_channel_name(additional_model_data_sources) == "channel_name" + + +def test_generate_additional_model_data_sources(): + model_source = _generate_additional_model_data_sources( + "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", "channel_name", True + ) + + assert model_source == [ + { + "ChannelName": "channel_name", + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + "ModelAccessConfig": {"ACCEPT_EULA": True}, + }, + } + ] + + model_source = _generate_additional_model_data_sources( + "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", "channel_name", False + ) + + assert model_source == [ + { + "ChannelName": "channel_name", + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + }, + } + ] + + +@pytest.mark.parametrize( + "s3_uri, expected", + [ + ( + "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/" + "meta-textgeneration-llama-3-8b/artifacts/inference-prepack/v1.0.1/", + True, + ), + ("invalid://", False), + ], +) +def test_is_s3_uri(s3_uri, expected): + assert _is_s3_uri(s3_uri) == expected + + +@pytest.mark.parametrize( + "quantization_config, compilation_config, expected_config, expected_env", + [ + ( + None, + { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + { + "ModelCompilationConfig": { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + }, + { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + }, + ), + ( + { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + None, + { + "ModelQuantizationConfig": { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + }, + { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + }, + ), + (None, None, None, None), + ], +) +def 
test_extract_optimization_config_and_env( + quantization_config, compilation_config, expected_config, expected_env +): + assert _extract_optimization_config_and_env(quantization_config, compilation_config) == ( + expected_config, + expected_env, + ) + + +@pytest.mark.parametrize( + "my_path, expected_path", + [ + ("local/path/llama/code", "local/path/llama"), + ("local/path/llama/code/", "local/path/llama"), + ("local/path/llama/", "local/path/llama/"), + ("local/path/llama", "local/path/llama"), + ], +) +def test_normalize_local_model_path(my_path, expected_path): + assert _normalize_local_model_path(my_path) == expected_path + + +class TestCustomSpeculativeDecodingConfig(unittest.TestCase): + + @patch("sagemaker.model.Model") + def test_with_s3_hf(self, mock_model): + mock_model.env = {} + mock_model.additional_model_data_sources = None + speculative_decoding_config = { + "ModelSource": "s3://bucket/djl-inference-2024-07-02-00-03-32-127/code" + } + + res_model = _custom_speculative_decoding(mock_model, speculative_decoding_config) + + mock_model.add_tags.assert_called_once_with( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "custom"} + ) + + self.assertEqual( + res_model.env, + {"OPTION_SPECULATIVE_DRAFT_MODEL": "/opt/ml/additional-model-data-sources/draft_model"}, + ) + self.assertEqual( + res_model.additional_model_data_sources, + [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "S3Uri": "s3://bucket/djl-inference-2024-07-02-00-03-32-127/code", + "S3DataType": "S3Prefix", + "CompressionType": "None", + }, + } + ], + ) + + @patch("sagemaker.model.Model") + def test_with_s3_js(self, mock_model): + mock_model.env = {} + mock_model.additional_model_data_sources = None + speculative_decoding_config = { + "ModelSource": "s3://bucket/huggingface-pytorch-tgi-inference" + } + + res_model = _custom_speculative_decoding(mock_model, speculative_decoding_config, True) + + self.assertEqual( + res_model.additional_model_data_sources, + [ + { + 
"ChannelName": "draft_model", + "S3DataSource": { + "S3Uri": "s3://bucket/huggingface-pytorch-tgi-inference", + "S3DataType": "S3Prefix", + "CompressionType": "None", + "ModelAccessConfig": {"ACCEPT_EULA": True}, + }, + } + ], + ) + + @patch("sagemaker.model.Model") + def test_with_non_s3(self, mock_model): + mock_model.env = {} + mock_model.additional_model_data_sources = None + speculative_decoding_config = {"ModelSource": "huggingface-pytorch-tgi-inference"} + + res_model = _custom_speculative_decoding(mock_model, speculative_decoding_config, False) + + self.assertIsNone(res_model.additional_model_data_sources) + self.assertEqual( + res_model.env, + {"OPTION_SPECULATIVE_DRAFT_MODEL": "huggingface-pytorch-tgi-inference"}, + ) + + mock_model.add_tags.assert_called_once_with( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "custom"} + ) diff --git a/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py b/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py index 33af575e8f..4729efbda4 100644 --- a/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py +++ b/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import import unittest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock from sagemaker.serve import Mode, ModelServer from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH from sagemaker.serve.utils.telemetry_logger import ( @@ -25,7 +25,8 @@ from sagemaker.user_agent import SDK_VERSION MOCK_SESSION = Mock() -MOCK_FUNC_NAME = "Mock.deploy" +MOCK_DEPLOY_FUNC_NAME = "Mock.deploy" +MOCK_OPTIMIZE_FUNC_NAME = "Mock.optimize" MOCK_DJL_CONTAINER = ( "763104351884.dkr.ecr.us-west-2.amazonaws.com/" "djl-inference:0.25.0-deepspeed0.11.0-cu118" ) @@ -47,11 +48,15 @@ def __init__(self): self.serve_settings = Mock() self.sagemaker_session = MOCK_SESSION - @_capture_telemetry(MOCK_FUNC_NAME) + @_capture_telemetry(MOCK_DEPLOY_FUNC_NAME) def mock_deploy(self, mock_exception_func=None): if mock_exception_func: mock_exception_func() + @_capture_telemetry(MOCK_OPTIMIZE_FUNC_NAME) + def mock_optimize(self, *args, **kwargs): + pass + class TestTelemetryLogger(unittest.TestCase): @patch("sagemaker.serve.utils.telemetry_logger._requests_helper") @@ -88,7 +93,7 @@ def test_capture_telemetry_decorator_djl_success(self, mock_send_telemetry): args = mock_send_telemetry.call_args.args latency = str(args[5]).split("latency=")[1] expected_extra_str = ( - f"{MOCK_FUNC_NAME}" + f"{MOCK_DEPLOY_FUNC_NAME}" "&x-modelServer=4" "&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118" f"&x-sdkVersion={SDK_VERSION}" @@ -118,7 +123,7 @@ def test_capture_telemetry_decorator_djl_success_with_custom_image(self, mock_se args = mock_send_telemetry.call_args.args latency = str(args[5]).split("latency=")[1] expected_extra_str = ( - f"{MOCK_FUNC_NAME}" + f"{MOCK_DEPLOY_FUNC_NAME}" "&x-modelServer=4" "&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118" f"&x-sdkVersion={SDK_VERSION}" @@ -148,7 +153,7 @@ def test_capture_telemetry_decorator_tgi_success(self, mock_send_telemetry): args = 
mock_send_telemetry.call_args.args latency = str(args[5]).split("latency=")[1] expected_extra_str = ( - f"{MOCK_FUNC_NAME}" + f"{MOCK_DEPLOY_FUNC_NAME}" "&x-modelServer=6" "&x-imageTag=huggingface-pytorch-inference:2.0.0-transformers4.28.1-cpu-py310-ubuntu20.04" f"&x-sdkVersion={SDK_VERSION}" @@ -196,7 +201,7 @@ def test_capture_telemetry_decorator_handle_exception_success(self, mock_send_te args = mock_send_telemetry.call_args.args latency = str(args[5]).split("latency=")[1] expected_extra_str = ( - f"{MOCK_FUNC_NAME}" + f"{MOCK_DEPLOY_FUNC_NAME}" "&x-modelServer=4" "&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118" f"&x-sdkVersion={SDK_VERSION}" @@ -243,7 +248,7 @@ def test_construct_url_with_failure_reason_and_extra_info(self): f"&x-failureType={mock_failure_type}" f"&x-extra={mock_extra_info}" ) - self.assertEquals(ret_url, expected_base_url) + self.assertEqual(ret_url, expected_base_url) @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") def test_capture_telemetry_decorator_mlflow_success(self, mock_send_telemetry): @@ -262,7 +267,7 @@ def test_capture_telemetry_decorator_mlflow_success(self, mock_send_telemetry): args = mock_send_telemetry.call_args.args latency = str(args[5]).split("latency=")[1] expected_extra_str = ( - f"{MOCK_FUNC_NAME}" + f"{MOCK_DEPLOY_FUNC_NAME}" "&x-modelServer=1" "&x-imageTag=pytorch-inference:2.0.1-cpu-py310" f"&x-sdkVersion={SDK_VERSION}" @@ -275,3 +280,63 @@ def test_capture_telemetry_decorator_mlflow_success(self, mock_send_telemetry): mock_send_telemetry.assert_called_once_with( "1", 3, MOCK_SESSION, None, None, expected_extra_str ) + + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_capture_telemetry_decorator_optimize_with_default_configs(self, mock_send_telemetry): + mock_model_builder = ModelBuilderMock() + mock_model_builder.serve_settings.telemetry_opt_out = False + mock_model_builder.image_uri = None + mock_model_builder.mode = Mode.SAGEMAKER_ENDPOINT + 
mock_model_builder.model_server = ModelServer.TORCHSERVE + mock_model_builder.sagemaker_session.endpoint_arn = None + + mock_model_builder.mock_optimize() + + args = mock_send_telemetry.call_args.args + latency = str(args[5]).split("latency=")[1] + expected_extra_str = ( + f"{MOCK_OPTIMIZE_FUNC_NAME}" + "&x-modelServer=1" + f"&x-sdkVersion={SDK_VERSION}" + f"&x-latency={latency}" + ) + + mock_send_telemetry.assert_called_once_with( + "1", 3, MOCK_SESSION, None, None, expected_extra_str + ) + + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_capture_telemetry_decorator_optimize_with_custom_configs(self, mock_send_telemetry): + mock_model_builder = ModelBuilderMock() + mock_model_builder.serve_settings.telemetry_opt_out = False + mock_model_builder.image_uri = None + mock_model_builder.mode = Mode.SAGEMAKER_ENDPOINT + mock_model_builder.model_server = ModelServer.TORCHSERVE + mock_model_builder.sagemaker_session.endpoint_arn = None + mock_model_builder.is_fine_tuned = True + mock_model_builder.is_compiled = True + mock_model_builder.is_quantized = True + mock_model_builder.speculative_decoding_draft_model_source = "sagemaker" + + mock_speculative_decoding_config = MagicMock() + mock_config = {"ModelProvider": "sagemaker"} + mock_speculative_decoding_config.__getitem__.side_effect = mock_config.__getitem__ + + mock_model_builder.mock_optimize() + + args = mock_send_telemetry.call_args.args + latency = str(args[5]).split("latency=")[1] + expected_extra_str = ( + f"{MOCK_OPTIMIZE_FUNC_NAME}" + "&x-modelServer=1" + f"&x-sdkVersion={SDK_VERSION}" + f"&x-fineTuned=1" + f"&x-compiled=1" + f"&x-quantized=1" + f"&x-sdDraftModelSource=1" + f"&x-latency={latency}" + ) + + mock_send_telemetry.assert_called_once_with( + "1", 3, MOCK_SESSION, None, None, expected_extra_str + ) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index d81984e81f..deb295e6e1 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -34,6 +34,7 
@@ from sagemaker.experiments._run_context import _RunContext from sagemaker.session_settings import SessionSettings from sagemaker.utils import ( + camel_case_to_pascal_case, deep_override_dict, flatten_dict, get_instance_type_family, @@ -51,7 +52,12 @@ _is_bad_link, custom_extractall_tarfile, can_model_package_source_uri_autopopulate, + get_instance_rate_per_hour, + extract_instance_rate_per_hour, _resolve_routing_config, + tag_exists, + _validate_new_tags, + remove_tag_with_key, ) from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -1819,7 +1825,13 @@ def test_can_model_package_source_uri_autopopulate(): class TestDeepMergeDict(TestCase): def test_flatten_dict_basic(self): nested_dict = {"a": 1, "b": {"x": 2, "y": {"p": 3, "q": 4}}, "c": 5} - flattened_dict = {"a": 1, "b.x": 2, "b.y.p": 3, "b.y.q": 4, "c": 5} + flattened_dict = { + ("a",): 1, + ("b", "x"): 2, + ("b", "y", "p"): 3, + ("b", "y", "q"): 4, + ("c",): 5, + } self.assertDictEqual(flatten_dict(nested_dict), flattened_dict) self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict) @@ -1831,13 +1843,19 @@ def test_flatten_dict_empty(self): def test_flatten_dict_no_nested(self): nested_dict = {"a": 1, "b": 2, "c": 3} - flattened_dict = {"a": 1, "b": 2, "c": 3} + flattened_dict = {("a",): 1, ("b",): 2, ("c",): 3} self.assertDictEqual(flatten_dict(nested_dict), flattened_dict) self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict) def test_flatten_dict_with_various_types(self): nested_dict = {"a": [1, 2, 3], "b": {"x": None, "y": {"p": [], "q": ""}}, "c": 9} - flattened_dict = {"a": [1, 2, 3], "b.x": None, "b.y.p": [], "b.y.q": "", "c": 9} + flattened_dict = { + ("a",): [1, 2, 3], + ("b", "x"): None, + ("b", "y", "p"): [], + ("b", "y", "q"): "", + ("c",): 9, + } self.assertDictEqual(flatten_dict(nested_dict), flattened_dict) self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict) @@ -1859,6 
+1877,18 @@ def test_deep_override_nested_lists(self): expected_merged = {"a": [5], "b": {"c": [6, 7], "d": [8]}} self.assertDictEqual(deep_override_dict(dict1, dict2), expected_merged) + def test_deep_override_nested_lists_overriding_none(self): + dict1 = {"a": [{"c": "d"}, {"e": "f"}], "t": None} + dict2 = { + "a": [{"1": "2"}, {"3": "4"}, {"5": "6"}, "7"], + "t": {"g": [{"1": "2"}, {"3": "4"}, {"5": "6"}, "7"]}, + } + expected_merged = { + "a": [{"1": "2"}, {"3": "4"}, {"5": "6"}, "7"], + "t": {"g": [{"1": "2"}, {"3": "4"}, {"5": "6"}, "7"]}, + } + self.assertDictEqual(deep_override_dict(dict1, dict2), expected_merged) + def test_deep_override_skip_keys(self): dict1 = {"a": 1, "b": {"x": 2, "y": 3}, "c": [4, 5]} dict2 = { @@ -1870,6 +1900,140 @@ def test_deep_override_skip_keys(self): self.assertEqual(deep_override_dict(dict1, dict2, skip_keys=["c", "d"]), expected_result) +@pytest.mark.parametrize( + "instance, region, amazon_sagemaker_price_result, expected", + [ + ( + "ml.t4g.nano", + "us-west-2", + { + "PriceList": [ + { + "terms": { + "OnDemand": { + "3WK7G7WSYVS3K492.JRTCKXETXF": { + "priceDimensions": { + "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7": { + "unit": "Hrs", + "endRange": "Inf", + "description": "$0.9 per Unused Reservation Linux p2.xlarge Instance Hour", + "appliesTo": [], + "rateCode": "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7", + "beginRange": "0", + "pricePerUnit": {"USD": "0.9000000000"}, + } + } + } + } + }, + } + ] + }, + {"name": "On-demand Instance Rate", "unit": "USD/Hr", "value": "0.9"}, + ), + ( + "ml.t4g.nano", + "eu-central-1", + { + "PriceList": [ + '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": ' + '"$0.0083 per' + "On" + 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": ' + '"0.0083000000"}}},' + 
'"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",' + '"termAttributes": {}}}}}' + ] + }, + {"name": "On-demand Instance Rate", "unit": "USD/Hr", "value": "0.008"}, + ), + ( + "ml.t4g.nano", + "af-south-1", + { + "PriceList": [ + '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": ' + '"$0.0083 per' + "On" + 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": ' + '"0.0083000000"}}},' + '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",' + '"termAttributes": {}}}}}' + ] + }, + {"name": "On-demand Instance Rate", "unit": "USD/Hr", "value": "0.008"}, + ), + ( + "ml.t4g.nano", + "ap-northeast-2", + { + "PriceList": [ + '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": ' + '"$0.0083 per' + "On" + 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": ' + '"0.0083000000"}}},' + '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",' + '"termAttributes": {}}}}}' + ] + }, + {"name": "On-demand Instance Rate", "unit": "USD/Hr", "value": "0.008"}, + ), + ], +) +@patch("boto3.client") +def test_get_instance_rate_per_hour( + mock_client, instance, region, amazon_sagemaker_price_result, expected +): + + mock_client.return_value.get_products.side_effect = ( + lambda *args, **kwargs: amazon_sagemaker_price_result + ) + instance_rate = get_instance_rate_per_hour(instance_type=instance, region=region) + + assert instance_rate == expected + + +@pytest.mark.parametrize( + "price_data, 
expected_result", + [ + (None, None), + ( + { + "terms": { + "OnDemand": { + "3WK7G7WSYVS3K492.JRTCKXETXF": { + "priceDimensions": { + "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7": { + "unit": "Hrs", + "endRange": "Inf", + "description": "$0.9 per Unused Reservation Linux p2.xlarge Instance Hour", + "appliesTo": [], + "rateCode": "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7", + "beginRange": "0", + "pricePerUnit": {"USD": "0.9000000000"}, + } + } + } + } + } + }, + {"name": "On-demand Instance Rate", "unit": "USD/Hr", "value": "0.9"}, + ), + ], +) +def test_extract_instance_rate_per_hour(price_data, expected_result): + out = extract_instance_rate_per_hour(price_data) + + assert out == expected_result + + @pytest.mark.parametrize( "routing_config, expected", [ @@ -1895,3 +2059,90 @@ def test_resolve_routing_config(routing_config, expected): def test_resolve_routing_config_ex(): pytest.raises(ValueError, lambda: _resolve_routing_config({"RoutingStrategy": "Invalid"})) + + +class TestConvertToPascalCase(TestCase): + def test_simple_dict(self): + input_dict = {"first_name": "John", "last_name": "Doe"} + expected_output = {"FirstName": "John", "LastName": "Doe"} + self.assertEqual(camel_case_to_pascal_case(input_dict), expected_output) + + def camel_case_to_pascal_case_nested(self): + input_dict = { + "model_name": "my-model", + "primary_container": { + "image": "my-docker-image:latest", + "model_data_url": "s3://my-bucket/model.tar.gz", + "environment": {"env_var_1": "value1", "env_var_2": "value2"}, + }, + "execution_role_arn": "arn:aws:iam::123456789012:role/my-sagemaker-role", + "tags": [ + {"key": "project", "value": "my-project"}, + {"key": "environment", "value": "development"}, + ], + } + expected_output = { + "ModelName": "my-model", + "PrimaryContainer": { + "Image": "my-docker-image:latest", + "ModelDataUrl": "s3://my-bucket/model.tar.gz", + "Environment": {"EnvVar1": "value1", "EnvVar2": "value2"}, + }, + "ExecutionRoleArn": 
"arn:aws:iam::123456789012:role/my-sagemaker-role", + "Tags": [ + {"Key": "project", "Value": "my-project"}, + {"Key": "environment", "Value": "development"}, + ], + } + self.assertEqual(camel_case_to_pascal_case(input_dict), expected_output) + + def test_empty_input(self): + self.assertEqual(camel_case_to_pascal_case({}), {}) + + +class TestTags(TestCase): + def test_tag_exists(self): + curr_tags = [{"Key": "project", "Value": "my-project"}] + self.assertTrue(tag_exists({"Key": "project", "Value": "my-project"}, curr_tags=curr_tags)) + + def test_does_not_tag_exists(self): + curr_tags = [{"Key": "project", "Value": "my-project"}] + self.assertFalse( + tag_exists({"Key": "project-2", "Value": "my-project-2"}, curr_tags=curr_tags) + ) + + def test_add_tags(self): + curr_tags = [{"Key": "project", "Value": "my-project"}] + new_tag = {"Key": "project-2", "Value": "my-project-2"} + expected = [ + {"Key": "project", "Value": "my-project"}, + {"Key": "project-2", "Value": "my-project-2"}, + ] + + self.assertEqual(_validate_new_tags(new_tag, curr_tags), expected) + + def test_new_add_tags(self): + new_tag = {"Key": "project-2", "Value": "my-project-2"} + + self.assertEqual(_validate_new_tags(new_tag, None), new_tag) + + def test_remove_existing_tag(self): + original_tags = [ + {"Key": "Tag1", "Value": "Value1"}, + {"Key": "Tag2", "Value": "Value2"}, + {"Key": "Tag3", "Value": "Value3"}, + ] + expected_output = [{"Key": "Tag1", "Value": "Value1"}, {"Key": "Tag3", "Value": "Value3"}] + self.assertEqual(remove_tag_with_key("Tag2", original_tags), expected_output) + + def test_remove_non_existent_tag(self): + original_tags = [ + {"Key": "Tag1", "Value": "Value1"}, + {"Key": "Tag2", "Value": "Value2"}, + {"Key": "Tag3", "Value": "Value3"}, + ] + self.assertEqual(remove_tag_with_key("NonExistentTag", original_tags), original_tags) + + def test_remove_only_tag(self): + original_tags = [{"Key": "Tag1", "Value": "Value1"}] + self.assertIsNone(remove_tag_with_key("Tag1", 
original_tags))