diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index 173266e1a1..a22890e873 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -20,7 +20,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.enums import JumpStartModelType, JumpStartScriptScope from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -38,6 +38,7 @@ def retrieve_default( instance_type: Optional[str] = None, script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE, config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Dict[str, str]: """Retrieves the default container environment variables for the model matching the arguments. @@ -70,6 +71,8 @@ def retrieve_default( script (JumpStartScriptScope): The JumpStart script for which to retrieve environment variables. config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: dict: The variables to use for the model. @@ -94,4 +97,5 @@ def retrieve_default( instance_type=instance_type, script=script, config_name=config_name, + model_type=model_type, ) diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index f1353cc8ff..86208858de 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -20,7 +20,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import HyperparameterValidationMode +from sagemaker.jumpstart.enums import HyperparameterValidationMode, JumpStartModelType from sagemaker.jumpstart.validators import validate_hyperparameters from sagemaker.session import Session @@ -38,6 +38,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Dict[str, str]: """Retrieves the default training hyperparameters for the model matching the given arguments. @@ -71,6 +72,8 @@ def retrieve_default( specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: dict: The hyperparameters to use for the model. @@ -93,6 +96,7 @@ def retrieve_default( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 65497927e9..95080b8406 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -22,6 +22,7 @@ from sagemaker import utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.utils import is_jumpstart_model_input from sagemaker.spark import defaults from sagemaker.jumpstart import artifacts @@ -72,6 +73,7 @@ def retrieve( serverless_inference_config=None, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, config_name=None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the ECR URI for the Docker image matching the given arguments. @@ -128,6 +130,8 @@ def retrieve( specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: str: The ECR URI for the corresponding SageMaker Docker image. @@ -169,6 +173,7 @@ def retrieve( tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]): diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index f10bfe4a5d..48775542e6 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -19,6 +19,7 @@ SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, ) from sagemaker.jumpstart.enums import ( + JumpStartModelType, JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( @@ -41,6 +42,7 @@ def _retrieve_default_environment_variables( instance_type: Optional[str] = None, script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE, config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Dict[str, str]: """Retrieves the inference environment variables for the model matching the given arguments. @@ -73,6 +75,8 @@ def _retrieve_default_environment_variables( script (JumpStartScriptScope): The JumpStart script for which to retrieve environment variables. config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: dict: the inference environment variables to use for the model. """ @@ -91,6 +95,7 @@ def _retrieve_default_environment_variables( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) default_environment_variables: Dict[str, str] = {} @@ -130,6 +135,7 @@ def _retrieve_default_environment_variables( sagemaker_session=sagemaker_session, instance_type=instance_type, config_name=config_name, + model_type=model_type, ) ) @@ -178,6 +184,7 @@ def _retrieve_gated_model_uri_env_var_value( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[str]: """Retrieves the gated model env var URI matching the given arguments. @@ -204,7 +211,8 @@ def _retrieve_gated_model_uri_env_var_value( instance_type (str): An instance type to optionally supply in order to get environment variables specific for the instance type. config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). - + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: Optional[str]: the s3 URI to use for the environment variable, or None if the model does not have gated training artifacts. @@ -227,6 +235,7 @@ def _retrieve_gated_model_uri_env_var_value( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) s3_key: Optional[str] = ( diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index 4383a17bf9..4bfe1732be 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -17,6 +17,7 @@ DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) from sagemaker.jumpstart.enums import ( + JumpStartModelType, JumpStartScriptScope, VariableScope, ) @@ -38,6 +39,7 @@ def _retrieve_default_hyperparameters( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ): """Retrieves the training hyperparameters for the model matching the given arguments. @@ -71,6 +73,8 @@ def _retrieve_default_hyperparameters( instance_type (str): An instance type to optionally supply in order to get hyperparameters specific for the instance type. config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: dict: the hyperparameters to use for the model. """ @@ -89,6 +93,7 @@ def _retrieve_default_hyperparameters( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) default_hyperparameters: Dict[str, str] = {} diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 0d4a61d112..9079b094a4 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -19,6 +19,7 @@ DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) from sagemaker.jumpstart.enums import ( + JumpStartModelType, JumpStartScriptScope, ModelFramework, ) @@ -48,6 +49,7 @@ def _retrieve_image_uri( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ): """Retrieves the container image URI for JumpStart models. @@ -100,6 +102,8 @@ def _retrieve_image_uri( specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: str: the ECR URI for the corresponding SageMaker Docker image. @@ -123,6 +127,7 @@ def _retrieve_image_uri( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) if image_scope == JumpStartScriptScope.INFERENCE: diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 80b5aa8ef5..f3b44524e7 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -17,6 +17,7 @@ DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) from sagemaker.jumpstart.enums import ( + JumpStartModelType, JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( @@ -35,6 +36,7 @@ def _model_supports_incremental_training( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> bool: """Returns True if the model supports incremental training. @@ -59,6 +61,8 @@ def _model_supports_incremental_training( specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: bool: the support status for incremental training. """ @@ -77,6 +81,7 @@ def _model_supports_incremental_training( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) return model_specs.supports_incremental_training() diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index eb7980b88f..6f2f7f38b5 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -167,6 +167,7 @@ def _retrieve_estimator_init_kwargs( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> dict: """Retrieves kwargs for `Estimator`. @@ -193,6 +194,8 @@ def _retrieve_estimator_init_kwargs( specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: dict: the kwargs to use for the use case. """ @@ -211,6 +214,7 @@ def _retrieve_estimator_init_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) kwargs = deepcopy(model_specs.estimator_kwargs) @@ -233,6 +237,7 @@ def _retrieve_estimator_fit_kwargs( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> dict: """Retrieves kwargs for `Estimator.fit`. @@ -257,6 +262,8 @@ def _retrieve_estimator_fit_kwargs( specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: dict: the kwargs to use for the use case. @@ -276,6 +283,7 @@ def _retrieve_estimator_fit_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) return model_specs.fit_kwargs diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index 16e81b2785..d4a0386c08 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -18,6 +18,7 @@ DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) from sagemaker.jumpstart.enums import ( + JumpStartModelType, JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( @@ -37,6 +38,7 @@ def _retrieve_default_training_metric_definitions( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model. @@ -63,6 +65,8 @@ def _retrieve_default_training_metric_definitions( instance_type (str): An instance type to optionally supply in order to get metric definitions specific for the instance type. config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: list: the default training metric definitions to use for the model or None. """ @@ -81,6 +85,7 @@ def _retrieve_default_training_metric_definitions( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) default_metric_definitions = ( diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index 7aa5be7507..c3b967d83e 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -130,6 +130,7 @@ def _retrieve_model_package_model_artifact_s3_uri( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[str]: """Retrieves s3 artifact uri associated with model package. @@ -156,6 +157,8 @@ def _retrieve_model_package_model_artifact_s3_uri( specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: str: the model package artifact uri to use for the model or None. @@ -179,6 +182,7 @@ def _retrieve_model_package_model_artifact_s3_uri( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) if model_specs.training_model_package_artifact_uris is None: diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 5fac979b14..90ee7dea8d 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -20,6 +20,7 @@ ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE, ) from sagemaker.jumpstart.enums import ( + JumpStartModelType, JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( @@ -97,6 +98,7 @@ def _retrieve_model_uri( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ): """Retrieves the model artifact S3 URI for the model matching the given arguments. @@ -125,6 +127,8 @@ def _retrieve_model_uri( specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: str: the model artifact S3 URI for the corresponding model. @@ -149,6 +153,7 @@ def _retrieve_model_uri( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) model_artifact_key: str @@ -195,6 +200,7 @@ def _model_supports_training_model_uri( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> bool: """Returns True if the model supports training with model uri field. @@ -219,6 +225,8 @@ def _model_supports_training_model_uri( specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: bool: the support status for model uri with training. """ @@ -237,6 +245,7 @@ def _model_supports_training_model_uri( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) return model_specs.use_training_model_artifact() diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index 5029f53cfb..e9b58debc3 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -19,6 +19,7 @@ ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE, ) from sagemaker.jumpstart.enums import ( + JumpStartModelType, JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( @@ -39,6 +40,7 @@ def _retrieve_script_uri( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ): """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -67,6 +69,8 @@ def _retrieve_script_uri( specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: str: the model script URI for the corresponding model. @@ -90,6 +94,7 @@ def _retrieve_script_uri( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) if script_scope == JumpStartScriptScope.INFERENCE: @@ -117,6 +122,7 @@ def _model_supports_inference_script_uri( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, config_name: Optional[str] = None, + model_type: Optional[str] = None, ) -> bool: """Returns True if the model supports inference with script uri field. @@ -140,6 +146,8 @@ def _model_supports_inference_script_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: bool: the support status for script uri with inference. """ @@ -158,6 +166,7 @@ def _model_supports_inference_script_uri( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) return model_specs.use_inference_script_uri() diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index d6c26b0429..84c9d09c3d 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -29,6 +29,10 @@ _retrieve_model_package_model_artifact_s3_uri, ) from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base +from sagemaker.jumpstart.factory.utils import ( + _set_temp_sagemaker_session_if_not_set, + get_model_info_default_kwargs, +) from sagemaker.jumpstart.hub.utils import ( construct_hub_model_arn_from_inputs, construct_hub_model_reference_arn_from_inputs, @@ -203,10 +207,25 @@ def get_init_kwargs( enable_session_tag_chaining=enable_session_tag_chaining, ) + estimator_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set( + kwargs=estimator_init_kwargs + ) + estimator_init_kwargs.specs = verify_model_region_and_return_specs( + **get_model_info_default_kwargs( + estimator_init_kwargs, include_model_version=False, include_tolerate_flags=False + ), + version=estimator_init_kwargs.model_version or "*", + scope=JumpStartScriptScope.TRAINING, + # We set these flags to True to retrieve the json specs. + # Exceptions will be thrown later if these are not tolerated. + tolerate_deprecated_model=True, + tolerate_vulnerable_model=True, + ) + estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs( - estimator_init_kwargs + estimator_init_kwargs, orig_session ) estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs) @@ -263,6 +282,19 @@ def get_fit_kwargs( config_name=config_name, ) + estimator_fit_kwargs, _ = _set_temp_sagemaker_session_if_not_set(kwargs=estimator_fit_kwargs) + estimator_fit_kwargs.specs = verify_model_region_and_return_specs( + **get_model_info_default_kwargs( + estimator_fit_kwargs, include_model_version=False, include_tolerate_flags=False + ), + version=estimator_fit_kwargs.model_version or "*", + scope=JumpStartScriptScope.TRAINING, + # We set these flags to True to retrieve the json specs. + # Exceptions will be thrown later if these are not tolerated. + tolerate_deprecated_model=True, + tolerate_vulnerable_model=True, + ) + estimator_fit_kwargs = _add_model_version_to_kwargs(estimator_fit_kwargs) estimator_fit_kwargs = _add_region_to_kwargs(estimator_fit_kwargs) estimator_fit_kwargs = _add_training_job_name_to_kwargs(estimator_fit_kwargs) @@ -442,17 +474,14 @@ def _add_region_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: def _add_sagemaker_session_with_custom_user_agent_to_kwargs( - kwargs: JumpStartKwargs, + kwargs: JumpStartKwargs, orig_session: Optional[Session] ) -> JumpStartKwargs: """Sets session in kwargs based on default or override, returns full kwargs.""" - kwargs.sagemaker_session = ( - kwargs.sagemaker_session - or get_default_jumpstart_session_with_user_agent_suffix( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - config_name=None, - is_hub_content=kwargs.hub_arn is not None, - ) + kwargs.sagemaker_session = orig_session or get_default_jumpstart_session_with_user_agent_suffix( + model_id=kwargs.model_id, + model_version=kwargs.model_version, + config_name=None, + is_hub_content=kwargs.hub_arn is not None, ) return kwargs @@ -463,17 +492,7 @@ def _add_model_version_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: kwargs.model_version = kwargs.model_version or "*" if kwargs.hub_arn: - hub_content_version = verify_model_region_and_return_specs( - model_id=kwargs.model_id, - version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - scope=JumpStartScriptScope.TRAINING, - region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - ).version + hub_content_version = kwargs.specs.version kwargs.model_version = hub_content_version return kwargs @@ -500,15 +519,7 @@ def _add_instance_type_and_count_to_kwargs( orig_instance_type = kwargs.instance_type kwargs.instance_type = kwargs.instance_type or instance_types.retrieve_default( - region=kwargs.region, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - scope=JumpStartScriptScope.TRAINING, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, + **get_model_info_default_kwargs(kwargs), scope=JumpStartScriptScope.TRAINING ) kwargs.instance_count = kwargs.instance_count or 1 @@ -524,17 +535,7 @@ def _add_instance_type_and_count_to_kwargs( def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs: """Sets tags in kwargs based on default or override, returns full kwargs.""" - full_model_version = verify_model_region_and_return_specs( - model_id=kwargs.model_id, - version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - scope=JumpStartScriptScope.TRAINING, - region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, - ).version + full_model_version = kwargs.specs.version if kwargs.sagemaker_session.settings.include_jumpstart_tags: kwargs.tags = add_jumpstart_model_info_tags( @@ -563,17 +564,10 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE """Sets image uri in kwargs based on default or override, returns full kwargs.""" kwargs.image_uri = kwargs.image_uri or image_uris.retrieve( - region=kwargs.region, + **get_model_info_default_kwargs(kwargs), + instance_type=kwargs.instance_type, framework=None, image_scope=JumpStartScriptScope.TRAINING, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - instance_type=kwargs.instance_type, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, ) return kwargs @@ -584,17 +578,7 @@ def _add_model_reference_arn_to_kwargs( ) -> JumpStartEstimatorInitKwargs: """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs.""" - hub_content_type = verify_model_region_and_return_specs( - model_id=kwargs.model_id, - version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - scope=JumpStartScriptScope.TRAINING, - region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - ).hub_content_type + hub_content_type = kwargs.specs.hub_content_type kwargs.hub_content_type = hub_content_type if kwargs.hub_arn else None if hub_content_type == HubContentType.MODEL_REFERENCE: @@ -609,40 +593,17 @@ def _add_model_reference_arn_to_kwargs( def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs: """Sets model uri in kwargs based on default or override, returns full kwargs.""" - if _model_supports_training_model_uri( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - ): + if _model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs)): default_model_uri = model_uris.retrieve( model_scope=JumpStartScriptScope.TRAINING, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - region=kwargs.region, instance_type=kwargs.instance_type, - config_name=kwargs.config_name, + **get_model_info_default_kwargs(kwargs), ) if ( kwargs.model_uri is not None and kwargs.model_uri != default_model_uri - and not _model_supports_incremental_training( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, - ) + and not _model_supports_incremental_training(**get_model_info_default_kwargs(kwargs)) ): JUMPSTART_LOGGER.warning( "'%s' does not support incremental training but is being trained with" @@ -670,15 +631,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart """Sets source dir in kwargs based on default or override, returns full kwargs.""" kwargs.source_dir = kwargs.source_dir or script_uris.retrieve( - script_scope=JumpStartScriptScope.TRAINING, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - region=kwargs.region, - sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, + script_scope=JumpStartScriptScope.TRAINING, **get_model_info_default_kwargs(kwargs) ) return kwargs @@ -690,29 +643,15 @@ def _add_env_to_kwargs( """Sets environment in kwargs based on default or override, returns full kwargs.""" extra_env_vars = environment_variables.retrieve_default( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - region=kwargs.region, - include_aws_sdk_env_vars=False, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, + **get_model_info_default_kwargs(kwargs), script=JumpStartScriptScope.TRAINING, instance_type=kwargs.instance_type, - config_name=kwargs.config_name, + include_aws_sdk_env_vars=False, ) model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - region=kwargs.region, + **get_model_info_default_kwargs(kwargs), scope=JumpStartScriptScope.TRAINING, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, ) if model_package_artifact_uri: @@ -732,17 +671,7 @@ def _add_env_to_kwargs( environment.get(SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY) and str(environment.get("accept_eula", "")).lower() != "true" ): - model_specs = verify_model_region_and_return_specs( - model_id=kwargs.model_id, - version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - region=kwargs.region, - scope=JumpStartScriptScope.TRAINING, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, - ) + model_specs = kwargs.specs if model_specs.is_gated_model(): raise ValueError( "Need to define ‘accept_eula'='true' within Environment. " @@ -768,15 +697,8 @@ def _add_training_job_name_to_kwargs( """Sets resource name based on default or override, returns full kwargs.""" default_training_job_name = _retrieve_resource_name_base( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - region=kwargs.region, + **get_model_info_default_kwargs(kwargs), scope=JumpStartScriptScope.TRAINING, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, ) kwargs.job_name = kwargs.job_name or ( @@ -796,15 +718,8 @@ def _add_hyperparameters_to_kwargs( ) default_hyperparameters = hyperparameters_utils.retrieve_default( - region=kwargs.region, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, + **get_model_info_default_kwargs(kwargs), instance_type=kwargs.instance_type, - config_name=kwargs.config_name, ) for key, value in default_hyperparameters.items(): @@ -831,15 +746,8 @@ def _add_metric_definitions_to_kwargs( default_metric_definitions = ( metric_definitions_utils.retrieve_default( - region=kwargs.region, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, + **get_model_info_default_kwargs(kwargs), instance_type=kwargs.instance_type, - config_name=kwargs.config_name, ) or [] ) @@ -862,15 +770,7 @@ def _add_estimator_extra_kwargs( """Sets extra kwargs based on default or override, returns full kwargs.""" estimator_kwargs_to_add = _retrieve_estimator_init_kwargs( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - instance_type=kwargs.instance_type, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, + **get_model_info_default_kwargs(kwargs), instance_type=kwargs.instance_type ) for key, value in estimator_kwargs_to_add.items(): @@ -888,16 +788,7 @@ def _add_estimator_extra_kwargs( def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstimatorFitKwargs: """Sets extra kwargs based on default or override, returns full kwargs.""" - fit_kwargs_to_add = _retrieve_estimator_fit_kwargs( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, - ) + fit_kwargs_to_add = _retrieve_estimator_fit_kwargs(**get_model_info_default_kwargs(kwargs)) for key, value in fit_kwargs_to_add.items(): if getattr(kwargs, key) is None: @@ -912,15 +803,8 @@ def _add_config_name_to_kwargs( """Sets tags in kwargs based on default or override, returns full kwargs.""" kwargs.config_name = kwargs.config_name or get_top_ranked_config_name( - region=kwargs.region, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - sagemaker_session=kwargs.sagemaker_session, scope=JumpStartScriptScope.TRAINING, - model_type=kwargs.model_type, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - hub_arn=kwargs.hub_arn, + **get_model_info_default_kwargs(kwargs, include_config_name=False), ) return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index a193732ca1..ccafed844d 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -29,7 +29,6 @@ ) from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base from sagemaker.jumpstart.constants import ( - DEFAULT_JUMPSTART_SAGEMAKER_SESSION, INFERENCE_ENTRY_POINT_SCRIPT_NAME, JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, @@ -61,6 +60,10 @@ verify_model_region_and_return_specs, ) +from sagemaker.jumpstart.factory.utils import ( + _set_temp_sagemaker_session_if_not_set, + get_model_info_default_kwargs, +) from sagemaker.model_monitor.data_capture_config import DataCaptureConfig from sagemaker.base_predictor import Predictor from sagemaker import accept_types, content_types, serializers, deserializers @@ -158,18 +161,16 @@ def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelIni def _add_sagemaker_session_with_custom_user_agent_to_kwargs( - kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs] + kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs], + orig_session: Optional[Session], ) -> JumpStartModelInitKwargs: """Sets session in kwargs based on default or override, returns full kwargs.""" - kwargs.sagemaker_session = ( - kwargs.sagemaker_session - or get_default_jumpstart_session_with_user_agent_suffix( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - config_name=kwargs.config_name, - is_hub_content=kwargs.hub_arn is not None, - ) + kwargs.sagemaker_session = orig_session or get_default_jumpstart_session_with_user_agent_suffix( + model_id=kwargs.model_id, + model_version=kwargs.model_version, + config_name=kwargs.config_name, + is_hub_content=kwargs.hub_arn is not None, ) return kwargs @@ -196,17 +197,7 @@ def _add_model_version_to_kwargs( kwargs.model_version = kwargs.model_version or "*" if kwargs.hub_arn: - hub_content_version = verify_model_region_and_return_specs( - model_id=kwargs.model_id, - version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - scope=JumpStartScriptScope.INFERENCE, - region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - ).version + hub_content_version = kwargs.specs.version kwargs.model_version = hub_content_version return kwargs @@ -230,17 +221,9 @@ def _add_instance_type_to_kwargs( orig_instance_type = kwargs.instance_type kwargs.instance_type = kwargs.instance_type or instance_types.retrieve_default( - region=kwargs.region, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, + **get_model_info_default_kwargs(kwargs), scope=JumpStartScriptScope.INFERENCE, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, training_instance_type=kwargs.training_instance_type, - model_type=kwargs.model_type, - config_name=kwargs.config_name, ) if not disable_instance_type_logging and orig_instance_type is None: @@ -249,18 +232,7 @@ def _add_instance_type_to_kwargs( kwargs.instance_type, ) - specs = verify_model_region_and_return_specs( - model_id=kwargs.model_id, - version=kwargs.model_version, - scope=JumpStartScriptScope.INFERENCE, - region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - config_name=kwargs.config_name, - hub_arn=kwargs.hub_arn, - ) + specs = kwargs.specs if specs.inference_configs and kwargs.config_name not in specs.inference_configs.configs: return kwargs @@ -290,17 +262,10 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel return kwargs kwargs.image_uri = kwargs.image_uri or image_uris.retrieve( - region=kwargs.region, + **get_model_info_default_kwargs(kwargs), framework=None, image_scope=JumpStartScriptScope.INFERENCE, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, ) return kwargs @@ -311,17 +276,7 @@ def _add_model_reference_arn_to_kwargs( ) -> JumpStartModelInitKwargs: """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs.""" - hub_content_type = verify_model_region_and_return_specs( - model_id=kwargs.model_id, - version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - scope=JumpStartScriptScope.INFERENCE, - region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - ).hub_content_type + hub_content_type = kwargs.specs.hub_content_type kwargs.hub_content_type = hub_content_type if kwargs.hub_arn else None if hub_content_type == HubContentType.MODEL_REFERENCE: @@ -340,17 +295,11 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode kwargs.model_data = None return kwargs + model_info_kwargs = get_model_info_default_kwargs(kwargs) model_data: Union[str, dict] = kwargs.model_data or model_uris.retrieve( + **model_info_kwargs, model_scope=JumpStartScriptScope.INFERENCE, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, instance_type=kwargs.instance_type, - config_name=kwargs.config_name, ) if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"): @@ -384,26 +333,9 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode source_dir = kwargs.source_dir - if _model_supports_inference_script_uri( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, - ): + if _model_supports_inference_script_uri(**get_model_info_default_kwargs(kwargs)): source_dir = source_dir or script_uris.retrieve( - script_scope=JumpStartScriptScope.INFERENCE, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, + **get_model_info_default_kwargs(kwargs), script_scope=JumpStartScriptScope.INFERENCE ) kwargs.source_dir = source_dir @@ -420,16 +352,7 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod entry_point = kwargs.entry_point - if _model_supports_inference_script_uri( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, - ): + if _model_supports_inference_script_uri(**get_model_info_default_kwargs(kwargs)): entry_point = entry_point or INFERENCE_ENTRY_POINT_SCRIPT_NAME @@ -451,17 +374,10 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw env = {} extra_env_vars = environment_variables.retrieve_default( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - region=kwargs.region, + **get_model_info_default_kwargs(kwargs), include_aws_sdk_env_vars=False, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, script=JumpStartScriptScope.INFERENCE, instance_type=kwargs.instance_type, - config_name=kwargs.config_name, ) for key, value in extra_env_vars.items(): @@ -483,17 +399,9 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt """Sets model package arn based on default or override, returns full kwargs.""" model_package_arn = kwargs.model_package_arn or _retrieve_model_package_arn( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, + **get_model_info_default_kwargs(kwargs), instance_type=kwargs.instance_type, scope=JumpStartScriptScope.INFERENCE, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - config_name=kwargs.config_name, ) kwargs.model_package_arn = model_package_arn @@ -503,17 +411,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets extra kwargs based on default or override, returns full kwargs.""" - model_kwargs_to_add = _retrieve_model_init_kwargs( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - config_name=kwargs.config_name, - ) + model_kwargs_to_add = _retrieve_model_init_kwargs(**get_model_info_default_kwargs(kwargs)) for key, value in model_kwargs_to_add.items(): if getattr(kwargs, key) is None: @@ -541,17 +439,7 @@ def _add_endpoint_name_to_kwargs( ) -> JumpStartModelDeployKwargs: """Sets resource name based on default or override, returns full kwargs.""" - default_endpoint_name = _retrieve_resource_name_base( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - config_name=kwargs.config_name, - ) + default_endpoint_name = _retrieve_resource_name_base(**get_model_info_default_kwargs(kwargs)) kwargs.endpoint_name = kwargs.endpoint_name or ( name_from_base(default_endpoint_name) if default_endpoint_name is not None else None @@ -565,17 +453,7 @@ def _add_model_name_to_kwargs( ) -> JumpStartModelInitKwargs: """Sets resource name based on default or override, returns full kwargs.""" - default_model_name = _retrieve_resource_name_base( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - config_name=kwargs.config_name, - ) + default_model_name = _retrieve_resource_name_base(**get_model_info_default_kwargs(kwargs)) kwargs.name = kwargs.name or ( name_from_base(default_model_name) if default_model_name is not None else None @@ -587,18 +465,7 @@ def _add_model_name_to_kwargs( def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: """Sets tags based on default or override, returns full kwargs.""" - full_model_version = verify_model_region_and_return_specs( - model_id=kwargs.model_id, - version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - scope=JumpStartScriptScope.INFERENCE, - region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - config_name=kwargs.config_name, - ).version + full_model_version = kwargs.specs.version if kwargs.sagemaker_session.settings.include_jumpstart_tags: kwargs.tags = add_jumpstart_model_info_tags( @@ -628,16 +495,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any] """Sets extra kwargs based on default or override, returns full kwargs.""" deploy_kwargs_to_add = _retrieve_model_deploy_kwargs( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, - instance_type=kwargs.instance_type, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - config_name=kwargs.config_name, + **get_model_info_default_kwargs(kwargs), instance_type=kwargs.instance_type ) for key, value in deploy_kwargs_to_add.items(): @@ -651,17 +509,9 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel """Sets the resource requirements based on the default or an override. Returns full kwargs.""" kwargs.resources = kwargs.resources or resource_requirements.retrieve_default( - region=kwargs.region, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - hub_arn=kwargs.hub_arn, + **get_model_info_default_kwargs(kwargs), scope=JumpStartScriptScope.INFERENCE, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, instance_type=kwargs.instance_type, - config_name=kwargs.config_name, ) return kwargs @@ -694,20 +544,9 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta ValueError: If the instance_type is not supported with the current config. """ - # we need to create a default JS session (without custom user agent) - # in order to retrieve config name info - temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION - kwargs.config_name = kwargs.config_name or get_top_ranked_config_name( - region=kwargs.region, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - sagemaker_session=temp_session, + **get_model_info_default_kwargs(kwargs, include_config_name=False), scope=JumpStartScriptScope.INFERENCE, - model_type=kwargs.model_type, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - hub_arn=kwargs.hub_arn, ) if kwargs.config_name is None: @@ -721,18 +560,7 @@ def _add_additional_model_data_sources_to_kwargs( ) -> JumpStartModelInitKwargs: """Sets default additional model data sources to init kwargs""" - specs = verify_model_region_and_return_specs( - model_id=kwargs.model_id, - version=kwargs.model_version, - scope=JumpStartScriptScope.INFERENCE, - region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - config_name=kwargs.config_name, - hub_arn=kwargs.hub_arn, - ) + specs = kwargs.specs # Append speculative decoding data source from metadata speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources() for data_source in speculative_decoding_data_sources: @@ -765,39 +593,17 @@ def _add_config_name_to_deploy_kwargs( ValueError: If the instance_type is not supported with the current config. """ - # we need to create a default JS session (without custom user agent) - # in order to retrieve config name info - temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION - if training_config_name: - specs = verify_model_region_and_return_specs( - model_id=kwargs.model_id, - version=kwargs.model_version, - scope=JumpStartScriptScope.INFERENCE, - region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=temp_session, - model_type=kwargs.model_type, - config_name=kwargs.config_name, - hub_arn=kwargs.hub_arn, - ) + specs = kwargs.specs default_config_name = _select_inference_config_from_training_config( specs=specs, training_config_name=training_config_name ) else: - default_config_name = get_top_ranked_config_name( - region=kwargs.region, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - sagemaker_session=temp_session, + default_config_name = kwargs.config_name or get_top_ranked_config_name( + **get_model_info_default_kwargs(kwargs, include_config_name=False), scope=JumpStartScriptScope.INFERENCE, - model_type=kwargs.model_type, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - hub_arn=kwargs.hub_arn, ) kwargs.config_name = kwargs.config_name or default_config_name @@ -878,6 +684,18 @@ def get_deploy_kwargs( config_name=config_name, routing_config=routing_config, ) + deploy_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(kwargs=deploy_kwargs) + deploy_kwargs.specs = verify_model_region_and_return_specs( + **get_model_info_default_kwargs( + deploy_kwargs, include_model_version=False, include_tolerate_flags=False + ), + version=deploy_kwargs.model_version or "*", + scope=JumpStartScriptScope.INFERENCE, + # We set these flags to True to retrieve the json specs. + # Exceptions will be thrown later if these are not tolerated. + tolerate_deprecated_model=True, + tolerate_vulnerable_model=True, + ) deploy_kwargs = _add_config_name_to_deploy_kwargs( kwargs=deploy_kwargs, training_config_name=training_config_name @@ -885,7 +703,9 @@ def get_deploy_kwargs( deploy_kwargs = _add_model_version_to_kwargs(kwargs=deploy_kwargs) - deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(kwargs=deploy_kwargs) + deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs( + kwargs=deploy_kwargs, orig_session=orig_session + ) deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs) @@ -945,6 +765,7 @@ def get_register_kwargs( register_kwargs = JumpStartModelRegisterKwargs( model_id=model_id, model_version=model_version, + config_name=config_name, hub_arn=hub_arn, model_type=model_type, region=region, @@ -977,24 +798,25 @@ def get_register_kwargs( accept_eula=accept_eula, ) - model_specs = verify_model_region_and_return_specs( - model_id=model_id, - version=model_version, - hub_arn=hub_arn, - model_type=model_type, - region=region, + register_kwargs.specs = verify_model_region_and_return_specs( + **get_model_info_default_kwargs( + register_kwargs, include_model_version=False, include_tolerate_flags=False + ), + version=register_kwargs.model_version or "*", scope=JumpStartScriptScope.INFERENCE, - sagemaker_session=sagemaker_session, - tolerate_deprecated_model=tolerate_deprecated_model, - tolerate_vulnerable_model=tolerate_vulnerable_model, - config_name=config_name, + # We set these flags to True to retrieve the json specs. + # Exceptions will be thrown later if these are not tolerated. + tolerate_deprecated_model=True, + tolerate_vulnerable_model=True, ) register_kwargs.content_types = ( - register_kwargs.content_types or model_specs.predictor_specs.supported_content_types + register_kwargs.content_types + or register_kwargs.specs.predictor_specs.supported_content_types ) register_kwargs.response_types = ( - register_kwargs.response_types or model_specs.predictor_specs.supported_accept_types + register_kwargs.response_types + or register_kwargs.specs.predictor_specs.supported_accept_types ) return register_kwargs @@ -1068,13 +890,27 @@ def get_init_kwargs( config_name=config_name, additional_model_data_sources=additional_model_data_sources, ) + model_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set( + kwargs=model_init_kwargs + ) + model_init_kwargs.specs = verify_model_region_and_return_specs( + **get_model_info_default_kwargs( + model_init_kwargs, include_model_version=False, include_tolerate_flags=False + ), + version=model_init_kwargs.model_version or "*", + scope=JumpStartScriptScope.INFERENCE, + # We set these flags to True to retrieve the json specs. + # Exceptions will be thrown later if these are not tolerated. + tolerate_deprecated_model=True, + tolerate_vulnerable_model=True, + ) model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs( - kwargs=model_init_kwargs + kwargs=model_init_kwargs, orig_session=orig_session ) model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs) diff --git a/src/sagemaker/jumpstart/factory/utils.py b/src/sagemaker/jumpstart/factory/utils.py new file mode 100644 index 0000000000..faf1f8886f --- /dev/null +++ b/src/sagemaker/jumpstart/factory/utils.py @@ -0,0 +1,79 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores JumpStart factory utilities.""" + +from __future__ import absolute_import +from typing import Tuple, Union + +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.types import ( + JumpStartEstimatorDeployKwargs, + JumpStartEstimatorFitKwargs, + JumpStartEstimatorInitKwargs, + JumpStartModelDeployKwargs, + JumpStartModelInitKwargs, +) +from sagemaker.session import Session + +KwargsType = Union[ + JumpStartModelDeployKwargs, + JumpStartModelInitKwargs, + JumpStartEstimatorFitKwargs, + JumpStartEstimatorInitKwargs, + JumpStartEstimatorDeployKwargs, +] + + +def get_model_info_default_kwargs( + kwargs: KwargsType, + include_config_name: bool = True, + include_model_version: bool = True, + include_tolerate_flags: bool = True, +) -> dict: + """Returns a dictionary of model info kwargs to use with JumpStart APIs.""" + + kwargs_dict = { + "model_id": kwargs.model_id, + "hub_arn": kwargs.hub_arn, + "region": kwargs.region, + "sagemaker_session": kwargs.sagemaker_session, + "model_type": kwargs.model_type, + } + if include_config_name: + kwargs_dict.update({"config_name": kwargs.config_name}) + + if include_model_version: + kwargs_dict.update({"model_version": kwargs.model_version}) + + if include_tolerate_flags: + kwargs_dict.update( + { + "tolerate_deprecated_model": kwargs.tolerate_deprecated_model, + "tolerate_vulnerable_model": kwargs.tolerate_vulnerable_model, + } + ) + + return kwargs_dict + + +def _set_temp_sagemaker_session_if_not_set(kwargs: KwargsType) -> Tuple[KwargsType, Session]: + """Sets a temporary sagemaker session if one is not set, and returns original session. + + We need to create a default JS session (without custom user agent) + in order to retrieve config name info. + """ + + orig_session = kwargs.sagemaker_session + if kwargs.sagemaker_session is None: + kwargs.sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION + return kwargs, orig_session diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 9d5acf6c6e..f3313b3862 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -2055,14 +2055,20 @@ def __init__( class JumpStartKwargs(JumpStartDataHolderType): """Data class for JumpStart object kwargs.""" + BASE_SERIALIZATION_EXCLUSION_SET: Set[str] = ["specs"] SERIALIZATION_EXCLUSION_SET: Set[str] = set() def to_kwargs_dict(self, exclude_keys: bool = True): """Serializes object to dictionary to be used for kwargs for method arguments.""" kwargs_dict = {} for field in self.__slots__: - if exclude_keys and field not in self.SERIALIZATION_EXCLUSION_SET or not exclude_keys: - att_value = getattr(self, field) + if ( + exclude_keys + and field + not in self.SERIALIZATION_EXCLUSION_SET.union(self.BASE_SERIALIZATION_EXCLUSION_SET) + or not exclude_keys + ): + att_value = getattr(self, field, None) if att_value is not None: kwargs_dict[field] = getattr(self, field) return kwargs_dict @@ -2104,6 +2110,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "additional_model_data_sources", "hub_content_type", "model_reference_arn", + "specs", ] SERIALIZATION_EXCLUSION_SET = { @@ -2226,6 +2233,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "endpoint_type", "config_name", "routing_config", + "specs", ] SERIALIZATION_EXCLUSION_SET = { @@ -2379,6 +2387,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "enable_session_tag_chaining", "hub_content_type", "model_reference_arn", + "specs", ] SERIALIZATION_EXCLUSION_SET = { @@ -2534,6 +2543,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "sagemaker_session", "config_name", + "specs", ] SERIALIZATION_EXCLUSION_SET = { @@ -2628,6 +2638,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): "model_name", "use_compiled_model", "config_name", + "specs", ] SERIALIZATION_EXCLUSION_SET = { @@ -2767,6 +2778,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "config_name", "model_card", "accept_eula", + "specs", ] SERIALIZATION_EXCLUSION_SET = { diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py index dbf7ef7650..8b7d80b48d 100644 --- a/src/sagemaker/metric_definitions.py +++ b/src/sagemaker/metric_definitions.py @@ -20,6 +20,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -35,6 +36,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model matching the given arguments. @@ -61,6 +63,8 @@ def retrieve_default( specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: list: The default metric definitions to use for the model or None. @@ -83,4 +87,5 @@ def retrieve_default( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 2949fbaf5f..6f788eb8b9 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -19,6 +19,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -36,6 +37,7 @@ def retrieve( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the model artifact Amazon S3 URI for the model matching the given arguments. @@ -62,6 +64,8 @@ def retrieve( specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: str: The model artifact S3 URI for the corresponding model. @@ -89,4 +93,5 @@ def retrieve( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index d60095b521..f280a627d2 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -20,6 +20,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -35,6 +36,7 @@ def retrieve( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -60,6 +62,8 @@ def retrieve( specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: str: The model script URI for the corresponding model. @@ -85,4 +89,5 @@ def retrieve( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, config_name=config_name, + model_type=model_type, ) diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index fbf76d1c98..f6666e68ae 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1425,6 +1425,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( sagemaker_session=sagemaker_session, config_name=None, hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @@ -1482,6 +1483,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( sagemaker_session=sagemaker_session, config_name=None, hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) @mock.patch("sagemaker.utils.sagemaker_timestamp")