diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 66003c9f03..20a2d16c15 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -14,6 +14,7 @@ """This module contains accessors related to SageMaker JumpStart.""" from __future__ import absolute_import import functools +import logging from typing import Any, Dict, List, Optional import boto3 @@ -289,15 +290,6 @@ def get_model_specs( if hub_arn: try: - hub_model_arn = construct_hub_model_arn_from_inputs( - hub_arn=hub_arn, model_name=model_id, version=version - ) - model_specs = JumpStartModelsAccessor._cache.get_hub_model( - hub_model_arn=hub_model_arn - ) - model_specs.set_hub_content_type(HubContentType.MODEL) - return model_specs - except: # noqa: E722 hub_model_arn = construct_hub_model_reference_arn_from_inputs( hub_arn=hub_arn, model_name=model_id, version=version ) @@ -307,6 +299,21 @@ def get_model_specs( model_specs.set_hub_content_type(HubContentType.MODEL_REFERENCE) return model_specs + except Exception as ex: + logging.info( + "Received exeption while calling APIs for ContentType ModelReference, \ + retrying with ContentType Model: " + + str(ex) + ) + hub_model_arn = construct_hub_model_arn_from_inputs( + hub_arn=hub_arn, model_name=model_id, version=version + ) + model_specs = JumpStartModelsAccessor._cache.get_hub_model( + hub_model_arn=hub_model_arn + ) + model_specs.set_hub_content_type(HubContentType.MODEL) + return model_specs + return JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, version_str=version, model_type=model_type ) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 8540f53ca4..0d156c415f 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -29,6 +29,10 @@ _retrieve_model_package_model_artifact_s3_uri, ) from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base +from sagemaker.jumpstart.hub.utils import ( + construct_hub_model_arn_from_inputs, + construct_hub_model_reference_arn_from_inputs, +) from sagemaker.session import Session from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer @@ -52,6 +56,7 @@ from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.factory import model from sagemaker.jumpstart.types import ( + HubContentType, JumpStartEstimatorDeployKwargs, JumpStartEstimatorFitKwargs, JumpStartEstimatorInitKwargs, @@ -203,6 +208,11 @@ def get_init_kwargs( estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_image_uri_to_kwargs(estimator_init_kwargs) + if hub_arn: + estimator_init_kwargs = _add_model_reference_arn_to_kwargs(kwargs=estimator_init_kwargs) + else: + estimator_init_kwargs.model_reference_arn = None + estimator_init_kwargs.hub_content_type = None estimator_init_kwargs = _add_model_uri_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_source_dir_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_entry_point_to_kwargs(estimator_init_kwargs) @@ -433,7 +443,7 @@ def _add_sagemaker_session_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs kwargs.sagemaker_session = ( kwargs.sagemaker_session or get_default_jumpstart_session_with_user_agent_suffix( - kwargs.model_id, kwargs.model_version + kwargs.model_id, kwargs.model_version, kwargs.hub_arn ) ) return kwargs @@ -528,7 +538,15 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima ) if kwargs.hub_arn: - kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn) + if kwargs.model_reference_arn: + hub_content_arn = construct_hub_model_reference_arn_from_inputs( + kwargs.hub_arn, kwargs.model_id, kwargs.model_version + ) + else: + hub_content_arn = construct_hub_model_arn_from_inputs( + kwargs.hub_arn, kwargs.model_id, kwargs.model_version + ) + kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn) return kwargs @@ -553,6 +571,33 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE return kwargs +def _add_model_reference_arn_to_kwargs( + kwargs: JumpStartEstimatorInitKwargs, +) -> JumpStartEstimatorInitKwargs: + """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs.""" + + hub_content_type = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + hub_arn=kwargs.hub_arn, + scope=JumpStartScriptScope.TRAINING, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + ).hub_content_type + kwargs.hub_content_type = hub_content_type if kwargs.hub_arn else None + + if hub_content_type == HubContentType.MODEL_REFERENCE: + kwargs.model_reference_arn = construct_hub_model_reference_arn_from_inputs( + hub_arn=kwargs.hub_arn, model_name=kwargs.model_id, version=kwargs.model_version + ) + else: + kwargs.model_reference_arn = None + return kwargs + + def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs: """Sets model uri in kwargs based on default or override, returns full kwargs.""" diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 61fcff242f..f4e13de6d7 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -34,7 +34,10 @@ JUMPSTART_LOGGER, ) from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard -from sagemaker.jumpstart.hub.utils import construct_hub_model_reference_arn_from_inputs +from sagemaker.jumpstart.hub.utils import ( + construct_hub_model_arn_from_inputs, + construct_hub_model_reference_arn_from_inputs, +) from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines @@ -156,12 +159,14 @@ def _add_sagemaker_session_to_kwargs( kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs] ) -> JumpStartModelInitKwargs: """Sets session in kwargs based on default or override, returns full kwargs.""" + kwargs.sagemaker_session = ( kwargs.sagemaker_session or get_default_jumpstart_session_with_user_agent_suffix( - kwargs.model_id, kwargs.model_version + kwargs.model_id, kwargs.model_version, kwargs.hub_arn ) ) + return kwargs @@ -273,6 +278,7 @@ def _add_model_reference_arn_to_kwargs( kwargs: JumpStartModelInitKwargs, ) -> JumpStartModelInitKwargs: """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs.""" + hub_content_type = verify_model_region_and_return_specs( model_id=kwargs.model_id, version=kwargs.model_version, @@ -573,7 +579,15 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: ) if kwargs.hub_arn: - kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn) + if kwargs.model_reference_arn: + hub_content_arn = construct_hub_model_reference_arn_from_inputs( + kwargs.hub_arn, kwargs.model_id, kwargs.model_version + ) + else: + hub_content_arn = construct_hub_model_arn_from_inputs( + kwargs.hub_arn, kwargs.model_id, kwargs.model_version + ) + kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn) return kwargs diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index d208220965..69d1dbb5c1 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -23,7 +23,6 @@ from sagemaker.session import Session from sagemaker.jumpstart.constants import ( - DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER, ) from sagemaker.jumpstart.types import ( @@ -68,7 +67,7 @@ def __init__( self, hub_name: str, bucket_name: Optional[str] = None, - sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + sagemaker_session: Optional[Session] = None, ) -> None: """Instantiates a SageMaker ``Hub``. @@ -79,7 +78,10 @@ def __init__( """ self.hub_name = hub_name self.region = sagemaker_session.boto_region_name - self._sagemaker_session = sagemaker_session + self._sagemaker_session = ( + sagemaker_session + or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True) + ) self.hub_storage_location = self._generate_hub_storage_location(bucket_name) def _fetch_hub_bucket_name(self) -> str: @@ -274,8 +276,8 @@ def describe_model( try: model_version = get_hub_model_version( hub_model_name=model_name, - hub_model_type=HubContentType.MODEL.value, - hub_name=self.hub_name, + hub_model_type=HubContentType.MODEL_REFERENCE.value, + hub_name=self.hub_name if not hub_name else hub_name, sagemaker_session=self._sagemaker_session, hub_model_version=model_version, ) @@ -284,24 +286,27 @@ def describe_model( hub_name=self.hub_name if not hub_name else hub_name, hub_content_name=model_name, hub_content_version=model_version, - hub_content_type=HubContentType.MODEL.value, + hub_content_type=HubContentType.MODEL_REFERENCE.value, ) except Exception as ex: - logging.info("Recieved expection while calling APIs for ContentType Model: " + str(ex)) + logging.info( + "Received exeption while calling APIs for ContentType ModelReference, retrying with ContentType Model: " + + str(ex) + ) model_version = get_hub_model_version( hub_model_name=model_name, - hub_model_type=HubContentType.MODEL_REFERENCE.value, - hub_name=self.hub_name, + hub_model_type=HubContentType.MODEL.value, + hub_name=self.hub_name if not hub_name else hub_name, sagemaker_session=self._sagemaker_session, hub_model_version=model_version, ) hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( - hub_name=self.hub_name, + hub_name=self.hub_name if not hub_name else hub_name, hub_content_name=model_name, hub_content_version=model_version, - hub_content_type=HubContentType.MODEL_REFERENCE.value, + hub_content_type=HubContentType.MODEL.value, ) return DescribeHubContentResponse(hub_content_description) diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index 3dfe99a8c4..2624796b3f 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -193,7 +193,11 @@ def get_hub_model_version( hub_model_version: Optional[str] = None, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: - """Returns available Jumpstart hub model version""" + """Returns available Jumpstart hub model version + + Raises: + ClientError: If the specified model is not found in the hub. + """ try: hub_content_summaries = sagemaker_session.list_hub_content_versions( diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 15cfea5c86..0cb8bbd55a 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -527,6 +527,7 @@ def attach( model_id: Optional[str] = None, model_version: Optional[str] = None, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + hub_name: Optional[str] = None, ) -> "JumpStartModel": """Attaches a JumpStartModel object to an existing SageMaker Endpoint. @@ -552,6 +553,7 @@ def attach( model_id=model_id, model_version=model_version, sagemaker_session=sagemaker_session, + hub_name=hub_name, ) model.endpoint_name = endpoint_name model.inference_component_name = inference_component_name diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index fb4c157a67..ae54bc72b8 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1708,6 +1708,7 @@ def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False) Args: spec (Dict[str, Any]): Dictionary representation of spec. + is_hub_content (Optional[bool]): Whether the model is from a private hub. """ super().__init__(spec, is_hub_content) self.from_json(spec) @@ -2335,6 +2336,8 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "enable_remote_debug", "config_name", "enable_session_tag_chaining", + "hub_content_type", + "model_reference_arn", ] SERIALIZATION_EXCLUSION_SET = { @@ -2345,6 +2348,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_content_type", "config_name", } diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 83425d62b3..f521dbcc5a 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -433,12 +433,12 @@ def add_jumpstart_model_info_tags( def add_hub_content_arn_tags( tags: Optional[List[TagsDict]], - hub_arn: str, + hub_content_arn: str, ) -> Optional[List[TagsDict]]: """Adds custom Hub arn tag to JumpStart related resources.""" tags = add_single_jumpstart_tag( - hub_arn, + hub_content_arn, enums.JumpStartTag.HUB_CONTENT_ARN, tags, is_uri=False, @@ -1108,24 +1108,40 @@ def get_jumpstart_configs( ) -def get_jumpstart_user_agent_extra_suffix(model_id: str, model_version: str) -> str: +def get_jumpstart_user_agent_extra_suffix( + model_id: Optional[str], model_version: Optional[str], is_hub_content: Optional[bool] +) -> str: """Returns the model-specific user agent string to be added to requests.""" sagemaker_python_sdk_headers = get_user_agent_extra_suffix() jumpstart_specific_suffix = f"md/js_model_id#{model_id} md/js_model_ver#{model_version}" - return ( - sagemaker_python_sdk_headers - if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None) - else f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix}" - ) + hub_specific_suffix = f"md/js_is_hub_content#{is_hub_content}" + + if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None): + headers = sagemaker_python_sdk_headers + elif is_hub_content is True: + if model_id is None and model_version is None: + headers = f"{sagemaker_python_sdk_headers} {hub_specific_suffix}" + else: + headers = ( + f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix} {hub_specific_suffix}" + ) + else: + headers = f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix}" + + return headers def get_default_jumpstart_session_with_user_agent_suffix( - model_id: str, model_version: str + model_id: Optional[str] = None, + model_version: Optional[str] = None, + is_hub_content: Optional[bool] = False, ) -> Session: """Returns default JumpStart SageMaker Session with model-specific user agent suffix.""" botocore_session = botocore.session.get_session() botocore_config = botocore.config.Config( - user_agent_extra=get_jumpstart_user_agent_extra_suffix(model_id, model_version), + user_agent_extra=get_jumpstart_user_agent_extra_suffix( + model_id, model_version, is_hub_content + ), ) botocore_session.set_default_client_config(botocore_config) # shallow copy to not affect default session constant diff --git a/tests/unit/sagemaker/jumpstart/hub/test_hub.py b/tests/unit/sagemaker/jumpstart/hub/test_hub.py index e2085e5ab9..8522b33bc3 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_hub.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_hub.py @@ -182,13 +182,13 @@ def test_describe_model_success(mock_describe_hub_content_response, sagemaker_se hub.describe_model("test-model") mock_list_hub_content_versions.assert_called_with( - hub_name=HUB_NAME, hub_content_name="test-model", hub_content_type="Model" + hub_name=HUB_NAME, hub_content_name="test-model", hub_content_type="ModelReference" ) sagemaker_session.describe_hub_content.assert_called_with( hub_name=HUB_NAME, hub_content_name="test-model", hub_content_version="3.0", - hub_content_type="Model", + hub_content_type="ModelReference", ) diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 56eaa0b660..baf9d19a54 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -1465,6 +1465,7 @@ def test_attach( model_id="model-id", model_version="model-version", sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + hub_name=None, ) assert isinstance(val, JumpStartModel) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 533483a497..07c49a308c 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import os from unittest import TestCase +from unittest.mock import call from botocore.exceptions import ClientError from mock.mock import Mock, patch @@ -1884,6 +1885,88 @@ def test_get_jumpstart_benchmark_stats_training( } +class TestUserAgent: + @patch("sagemaker.jumpstart.utils.os.getenv") + def test_get_jumpstart_user_agent_extra_suffix(self, mock_getenv): + mock_getenv.return_value = False + assert utils.get_jumpstart_user_agent_extra_suffix( + "some-id", "some-version", "False" + ).endswith("md/js_model_id#some-id md/js_model_ver#some-version") + mock_getenv.return_value = None + assert utils.get_jumpstart_user_agent_extra_suffix( + "some-id", "some-version", "False" + ).endswith("md/js_model_id#some-id md/js_model_ver#some-version") + mock_getenv.return_value = "True" + assert not utils.get_jumpstart_user_agent_extra_suffix( + "some-id", "some-version", "True" + ).endswith("md/js_model_id#some-id md/js_model_ver#some-version md/js_is_hub_content#True") + mock_getenv.return_value = True + assert not utils.get_jumpstart_user_agent_extra_suffix( + "some-id", "some-version", "True" + ).endswith("md/js_model_id#some-id md/js_model_ver#some-version md/js_is_hub_content#True") + + @patch("sagemaker.jumpstart.utils.botocore.session") + @patch("sagemaker.jumpstart.utils.botocore.config.Config") + @patch("sagemaker.jumpstart.utils.get_jumpstart_user_agent_extra_suffix") + @patch("sagemaker.jumpstart.utils.boto3.Session") + @patch("sagemaker.jumpstart.utils.boto3.client") + @patch("sagemaker.jumpstart.utils.Session") + def test_get_default_jumpstart_session_with_user_agent_suffix( + self, + mock_sm_session, + mock_boto3_client, + mock_botocore_session, + mock_get_jumpstart_user_agent_extra_suffix, + mock_botocore_config, + mock_boto3_session, + ): + utils.get_default_jumpstart_session_with_user_agent_suffix("model_id", "model_version") + mock_boto3_session.get_session.assert_called_once_with() + mock_get_jumpstart_user_agent_extra_suffix.assert_called_once_with( + "model_id", "model_version", False + ) + mock_botocore_config.assert_called_once_with( + user_agent_extra=mock_get_jumpstart_user_agent_extra_suffix.return_value + ) + mock_botocore_session.assert_called_once_with( + region_name=JUMPSTART_DEFAULT_REGION_NAME, + botocore_session=mock_boto3_session.get_session.return_value, + ) + mock_boto3_client.assert_has_calls( + [ + call( + "sagemaker", + region_name=JUMPSTART_DEFAULT_REGION_NAME, + config=mock_botocore_config.return_value, + ), + call( + "sagemaker-runtime", + region_name=JUMPSTART_DEFAULT_REGION_NAME, + config=mock_botocore_config.return_value, + ), + ], + any_order=True, + ) + + @patch("botocore.client.BaseClient._make_request") + def test_get_default_jumpstart_session_with_user_agent_suffix_http_header( + self, + mock_make_request, + ): + session = utils.get_default_jumpstart_session_with_user_agent_suffix( + "model_id", "model_version" + ) + try: + session.sagemaker_client.list_endpoints() + except Exception: + pass + + assert ( + "md/js_model_id#model_id md/js_model_ver#model_version" + in mock_make_request.call_args[0][1]["headers"]["User-Agent"] + ) + + def test_extract_metrics_from_deployment_configs(): configs = get_base_deployment_configs_metadata() configs[0].benchmark_metrics = None