diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 4c29c3f625..d6c26b0429 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -69,6 +69,7 @@ add_jumpstart_model_info_tags, get_eula_message, get_default_jumpstart_session_with_user_agent_suffix, + get_top_ranked_config_name, update_dict_if_key_not_present, resolve_estimator_sagemaker_config_field, verify_model_region_and_return_specs, @@ -204,7 +205,9 @@ def get_init_kwargs( estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(estimator_init_kwargs) - estimator_init_kwargs = _add_sagemaker_session_to_kwargs(estimator_init_kwargs) + estimator_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs( + estimator_init_kwargs + ) estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_image_uri_to_kwargs(estimator_init_kwargs) @@ -438,12 +441,17 @@ def _add_region_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: return kwargs -def _add_sagemaker_session_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: +def _add_sagemaker_session_with_custom_user_agent_to_kwargs( + kwargs: JumpStartKwargs, +) -> JumpStartKwargs: """Sets session in kwargs based on default or override, returns full kwargs.""" kwargs.sagemaker_session = ( kwargs.sagemaker_session or get_default_jumpstart_session_with_user_agent_suffix( - kwargs.model_id, kwargs.model_version, kwargs.hub_arn + model_id=kwargs.model_id, + model_version=kwargs.model_version, + config_name=None, + is_hub_content=kwargs.hub_arn is not None, ) ) return kwargs @@ -903,21 +911,16 @@ def _add_config_name_to_kwargs( ) -> JumpStartEstimatorInitKwargs: """Sets tags in kwargs based on default or override, returns full kwargs.""" - specs = verify_model_region_and_return_specs( + kwargs.config_name = kwargs.config_name or get_top_ranked_config_name( + region=kwargs.region, model_id=kwargs.model_id, - version=kwargs.model_version, + model_version=kwargs.model_version, + sagemaker_session=kwargs.sagemaker_session, scope=JumpStartScriptScope.TRAINING, - region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + model_type=kwargs.model_type, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=kwargs.sagemaker_session, - config_name=kwargs.config_name, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, hub_arn=kwargs.hub_arn, ) - if specs.training_configs and specs.training_configs.get_top_config_from_ranking(): - kwargs.config_name = ( - kwargs.config_name or specs.training_configs.get_top_config_from_ranking().config_name - ) - return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 3f0b627d5d..49d9c93dd4 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -29,6 +29,7 @@ ) from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, INFERENCE_ENTRY_POINT_SCRIPT_NAME, JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, @@ -54,6 +55,7 @@ add_jumpstart_model_info_tags, get_default_jumpstart_session_with_user_agent_suffix, get_neo_content_bucket, + get_top_ranked_config_name, update_dict_if_key_not_present, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, @@ -155,7 +157,7 @@ def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelIni return kwargs -def _add_sagemaker_session_to_kwargs( +def _add_sagemaker_session_with_custom_user_agent_to_kwargs( kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs] ) -> JumpStartModelInitKwargs: """Sets session in kwargs based on default or override, returns full kwargs.""" @@ -163,7 +165,10 @@ def _add_sagemaker_session_to_kwargs( kwargs.sagemaker_session = ( kwargs.sagemaker_session or get_default_jumpstart_session_with_user_agent_suffix( - kwargs.model_id, kwargs.model_version, kwargs.hub_arn + model_id=kwargs.model_id, + model_version=kwargs.model_version, + config_name=kwargs.config_name, + is_hub_content=kwargs.hub_arn is not None, ) ) @@ -244,6 +249,32 @@ def _add_instance_type_to_kwargs( kwargs.instance_type, ) + specs = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + config_name=kwargs.config_name, + ) + + if specs.inference_configs and kwargs.config_name not in specs.inference_configs.configs: + return kwargs + + resolved_config = ( + specs.inference_configs.configs[kwargs.config_name].resolved_config + if specs.inference_configs + else None + ) + if resolved_config is None: + return kwargs + supported_instance_types = resolved_config.get("supported_inference_instance_types", []) + if kwargs.instance_type not in supported_instance_types: + JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type) + return kwargs @@ -662,38 +693,25 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta ValueError: If the instance_type is not supported with the current config. """ - specs = verify_model_region_and_return_specs( + # we need to create a default JS session (without custom user agent) + # in order to retrieve config name info + temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION + + kwargs.config_name = kwargs.config_name or get_top_ranked_config_name( + region=kwargs.region, model_id=kwargs.model_id, - version=kwargs.model_version, + model_version=kwargs.model_version, + sagemaker_session=temp_session, scope=JumpStartScriptScope.INFERENCE, - region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, - config_name=kwargs.config_name, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, hub_arn=kwargs.hub_arn, ) - if specs.inference_configs: - default_config_name = ( - specs.inference_configs.get_top_config_from_ranking().config_name - if specs.inference_configs.get_top_config_from_ranking() - else None - ) - kwargs.config_name = kwargs.config_name or default_config_name - - if not kwargs.config_name: - return kwargs - if kwargs.config_name not in set(specs.inference_configs.configs.keys()): - raise ValueError( - f"Config {kwargs.config_name} is not supported for model {kwargs.model_id}." - ) + if kwargs.config_name is None: + return kwargs - resolved_config = specs.inference_configs.configs[kwargs.config_name].resolved_config - supported_instance_types = resolved_config.get("supported_inference_instance_types", []) - if kwargs.instance_type not in supported_instance_types: - JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type) return kwargs @@ -746,32 +764,41 @@ def _add_config_name_to_deploy_kwargs( ValueError: If the instance_type is not supported with the current config. """ - specs = verify_model_region_and_return_specs( - model_id=kwargs.model_id, - version=kwargs.model_version, - scope=JumpStartScriptScope.INFERENCE, - region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - config_name=kwargs.config_name, - hub_arn=kwargs.hub_arn, - ) + # we need to create a default JS session (without custom user agent) + # in order to retrieve config name info + temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION if training_config_name: - kwargs.config_name = _select_inference_config_from_training_config( + + specs = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=temp_session, + model_type=kwargs.model_type, + config_name=kwargs.config_name, + ) + default_config_name = _select_inference_config_from_training_config( specs=specs, training_config_name=training_config_name ) - return kwargs - if specs.inference_configs: - default_config_name = ( - specs.inference_configs.get_top_config_from_ranking().config_name - if specs.inference_configs.get_top_config_from_ranking() - else None + else: + default_config_name = get_top_ranked_config_name( + region=kwargs.region, + model_id=kwargs.model_id, + model_version=kwargs.model_version, + sagemaker_session=temp_session, + scope=JumpStartScriptScope.INFERENCE, + model_type=kwargs.model_type, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + hub_arn=kwargs.hub_arn, ) - kwargs.config_name = kwargs.config_name or default_config_name + + kwargs.config_name = kwargs.config_name or default_config_name return kwargs @@ -850,15 +877,15 @@ def get_deploy_kwargs( routing_config=routing_config, ) - deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs) + deploy_kwargs = _add_config_name_to_deploy_kwargs( + kwargs=deploy_kwargs, training_config_name=training_config_name + ) deploy_kwargs = _add_model_version_to_kwargs(kwargs=deploy_kwargs) - deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs) + deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(kwargs=deploy_kwargs) - deploy_kwargs = _add_config_name_to_deploy_kwargs( - kwargs=deploy_kwargs, training_config_name=training_config_name - ) + deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs) deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs) @@ -1041,11 +1068,14 @@ def get_init_kwargs( ) model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_sagemaker_session_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs( + kwargs=model_init_kwargs + ) model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_instance_type_to_kwargs( @@ -1073,8 +1103,6 @@ def get_init_kwargs( model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_additional_model_data_sources_to_kwargs(kwargs=model_init_kwargs) return model_init_kwargs diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index de1b6656b4..88f1dd59e3 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -2458,7 +2458,7 @@ def __init__( self.model_id = model_id self.model_version = model_version self.hub_arn = hub_arn - self.model_type = (model_type,) + self.model_type = model_type self.instance_type = instance_type self.instance_count = instance_count self.region = region diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 014f60ae8a..bebf14d5c0 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1111,11 +1111,15 @@ def get_jumpstart_configs( def get_jumpstart_user_agent_extra_suffix( - model_id: Optional[str], model_version: Optional[str], is_hub_content: Optional[bool] + model_id: Optional[str], + model_version: Optional[str], + config_name: Optional[str], + is_hub_content: Optional[bool], ) -> str: """Returns the model-specific user agent string to be added to requests.""" sagemaker_python_sdk_headers = get_user_agent_extra_suffix() jumpstart_specific_suffix = f"md/js_model_id#{model_id} md/js_model_ver#{model_version}" + config_specific_suffix = f"md/js_config#{config_name}" hub_specific_suffix = f"md/js_is_hub_content#{is_hub_content}" if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None): @@ -1130,19 +1134,74 @@ def get_jumpstart_user_agent_extra_suffix( else: headers = f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix}" + if config_name: + headers = f"{headers} {config_specific_suffix}" + return headers +def get_top_ranked_config_name( + region: str, + model_id: str, + model_version: str, + sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, + model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, + tolerate_deprecated_model: bool = False, + tolerate_vulnerable_model: bool = False, + hub_arn: Optional[str] = None, + ranking_name: enums.JumpStartConfigRankingName = enums.JumpStartConfigRankingName.DEFAULT, +) -> Optional[str]: + """Returns the top ranked config name for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. + """ + model_specs = verify_model_region_and_return_specs( + model_id=model_id, + version=model_version, + scope=scope, + region=region, + hub_arn=hub_arn, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, + model_type=model_type, + ) + + if scope == enums.JumpStartScriptScope.INFERENCE: + return ( + model_specs.inference_configs.get_top_config_from_ranking( + ranking_name=ranking_name + ).config_name + if model_specs.inference_configs + else None + ) + if scope == enums.JumpStartScriptScope.TRAINING: + return ( + model_specs.training_configs.get_top_config_from_ranking( + ranking_name=ranking_name + ).config_name + if model_specs.training_configs + else None + ) + raise ValueError(f"Unsupported script scope: {scope}.") + + def get_default_jumpstart_session_with_user_agent_suffix( model_id: Optional[str] = None, model_version: Optional[str] = None, + config_name: Optional[str] = None, is_hub_content: Optional[bool] = False, ) -> Session: """Returns default JumpStart SageMaker Session with model-specific user agent suffix.""" botocore_session = botocore.session.get_session() botocore_config = botocore.config.Config( user_agent_extra=get_jumpstart_user_agent_extra_suffix( - model_id, model_version, is_hub_content + model_id=model_id, + model_version=model_version, + config_name=config_name, + is_hub_content=is_hub_content, ), ) botocore_session.set_default_client_config(botocore_config) diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 5ee0abd41f..7733041579 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -396,3 +396,23 @@ def test_jumpstart_model_with_deployment_configs(setup): response = predictor.predict(payload, custom_attributes="accept_eula=true") assert response is not None + + +def test_jumpstart_session_with_config_name(): + model = JumpStartModel(model_id="meta-textgeneration-llama-2-7b") + assert model.config_name is not None + session = model.sagemaker_session + + # we're mocking the http request, so it's expected to raise an Exception. + # we're interested that the low-level request attaches the correct + # jumpstart-related tags. + with mock.patch("botocore.client.BaseClient._make_request") as mock_make_request: + try: + session.sagemaker_client.list_endpoints() + except Exception: + pass + + assert ( + "md/js_model_id#meta-textgeneration-llama-2-7b md/js_model_ver#* md/js_config#tgi" + in mock_make_request.call_args[0][1]["headers"]["User-Agent"] + ) diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 3678685db5..fbf76d1c98 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1637,7 +1637,7 @@ def test_training_passes_role_to_deploy( @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch( "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix", - sagemaker_session, + lambda *largs, **kwargs: sagemaker_session, ) @mock.patch( "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix", diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 07c49a308c..cbf918dee8 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1890,20 +1890,24 @@ class TestUserAgent: def test_get_jumpstart_user_agent_extra_suffix(self, mock_getenv): mock_getenv.return_value = False assert utils.get_jumpstart_user_agent_extra_suffix( - "some-id", "some-version", "False" + "some-id", "some-version", None, "False" ).endswith("md/js_model_id#some-id md/js_model_ver#some-version") mock_getenv.return_value = None assert utils.get_jumpstart_user_agent_extra_suffix( - "some-id", "some-version", "False" + "some-id", "some-version", None, "False" ).endswith("md/js_model_id#some-id md/js_model_ver#some-version") mock_getenv.return_value = "True" assert not utils.get_jumpstart_user_agent_extra_suffix( - "some-id", "some-version", "True" + "some-id", "some-version", None, "True" ).endswith("md/js_model_id#some-id md/js_model_ver#some-version md/js_is_hub_content#True") mock_getenv.return_value = True assert not utils.get_jumpstart_user_agent_extra_suffix( - "some-id", "some-version", "True" + "some-id", "some-version", None, "True" ).endswith("md/js_model_id#some-id md/js_model_ver#some-version md/js_is_hub_content#True") + mock_getenv.return_value = False + assert utils.get_jumpstart_user_agent_extra_suffix( + "some-id", "some-version", "some-config", "False" + ).endswith("md/js_model_id#some-id md/js_model_ver#some-version md/js_config#some-config") @patch("sagemaker.jumpstart.utils.botocore.session") @patch("sagemaker.jumpstart.utils.botocore.config.Config") @@ -1923,7 +1927,10 @@ def test_get_default_jumpstart_session_with_user_agent_suffix( utils.get_default_jumpstart_session_with_user_agent_suffix("model_id", "model_version") mock_boto3_session.get_session.assert_called_once_with() mock_get_jumpstart_user_agent_extra_suffix.assert_called_once_with( - "model_id", "model_version", False + model_id="model_id", + model_version="model_version", + config_name=None, + is_hub_content=False, ) mock_botocore_config.assert_called_once_with( user_agent_extra=mock_get_jumpstart_user_agent_extra_suffix.return_value