diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py index d987216872..69a468a0b4 100644 --- a/src/sagemaker/jumpstart/hub/interfaces.py +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -451,12 +451,14 @@ def from_json(self, json_obj: str) -> None: class HubModelDocument(HubDataHolderType): """Data class for model type HubContentDocument from session.describe_hub_content().""" - SCHEMA_VERSION = "2.2.0" + SCHEMA_VERSION = "2.3.0" __slots__ = [ "url", "min_sdk_version", "training_supported", + "model_types", + "capabilities", "incremental_training_supported", "dynamic_container_deployment_supported", "hosting_ecr_uri", @@ -469,6 +471,7 @@ class HubModelDocument(HubDataHolderType): "hosting_use_script_uri", "hosting_eula_uri", "hosting_model_package_arn", + "model_subscription_link", "inference_configs", "inference_config_components", "inference_config_rankings", @@ -549,18 +552,22 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Args: json_obj (Dict[str, Any]): Dictionary representation of hub model document. 
""" - self.url: str = json_obj["Url"] - self.min_sdk_version: str = json_obj["MinSdkVersion"] - self.hosting_ecr_uri: Optional[str] = json_obj["HostingEcrUri"] - self.hosting_artifact_uri = json_obj["HostingArtifactUri"] - self.hosting_script_uri = json_obj["HostingScriptUri"] - self.inference_dependencies: List[str] = json_obj["InferenceDependencies"] + self.url: str = json_obj.get("Url") + self.min_sdk_version: str = json_obj.get("MinSdkVersion") + self.hosting_ecr_uri: Optional[str] = json_obj.get("HostingEcrUri") + self.hosting_artifact_uri = json_obj.get("HostingArtifactUri") + self.hosting_script_uri = json_obj.get("HostingScriptUri") + self.inference_dependencies: List[str] = json_obj.get("InferenceDependencies") self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [ JumpStartEnvironmentVariable(env_variable, is_hub_content=True) - for env_variable in json_obj["InferenceEnvironmentVariables"] + for env_variable in json_obj.get("InferenceEnvironmentVariables", []) ] - self.training_supported: bool = bool(json_obj["TrainingSupported"]) - self.incremental_training_supported: bool = bool(json_obj["IncrementalTrainingSupported"]) + self.model_types: Optional[List[str]] = json_obj.get("ModelTypes") + self.capabilities: Optional[List[str]] = json_obj.get("Capabilities") + self.training_supported: bool = bool(json_obj.get("TrainingSupported")) + self.incremental_training_supported: bool = bool( + json_obj.get("IncrementalTrainingSupported") + ) self.dynamic_container_deployment_supported: Optional[bool] = ( bool(json_obj.get("DynamicContainerDeploymentSupported")) if json_obj.get("DynamicContainerDeploymentSupported") @@ -586,6 +593,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri") self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn") + self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink") + 
self.inference_config_rankings = self._get_config_rankings(json_obj) self.inference_config_components = self._get_config_components(json_obj) self.inference_configs = self._get_configs(json_obj) diff --git a/src/sagemaker/jumpstart/hub/parser_utils.py b/src/sagemaker/jumpstart/hub/parser_utils.py index 54147dd8e6..a720b02a17 100644 --- a/src/sagemaker/jumpstart/hub/parser_utils.py +++ b/src/sagemaker/jumpstart/hub/parser_utils.py @@ -19,12 +19,11 @@ def camel_to_snake(camel_case_string: str) -> str: - """Converts camelCaseString or UpperCamelCaseString to snake_case_string.""" - snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string) - if "-" in snake_case_string: - # remove any hyphen from the string for accurate conversion. - snake_case_string = snake_case_string.replace("-", "") - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower() + """Converts PascalCase to snake_case_string using a regex. + + This regex cannot handle whitespace ("PascalString TwoWords") + """ + return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_case_string).lower() def snake_to_upper_camel(snake_case_string: str) ->
str: diff --git a/src/sagemaker/jumpstart/hub/parsers.py b/src/sagemaker/jumpstart/hub/parsers.py index 28c2d9b32d..51da974217 100644 --- a/src/sagemaker/jumpstart/hub/parsers.py +++ b/src/sagemaker/jumpstart/hub/parsers.py @@ -137,6 +137,8 @@ def make_model_specs_from_describe_hub_content_response( hub_model_document: HubModelDocument = response.hub_content_document specs["url"] = hub_model_document.url specs["min_sdk_version"] = hub_model_document.min_sdk_version + specs["model_types"] = hub_model_document.model_types + specs["capabilities"] = hub_model_document.capabilities specs["training_supported"] = bool(hub_model_document.training_supported) specs["incremental_training_supported"] = bool( hub_model_document.incremental_training_supported @@ -146,15 +148,19 @@ def make_model_specs_from_describe_hub_content_response( specs["inference_config_components"] = hub_model_document.inference_config_components specs["inference_config_rankings"] = hub_model_document.inference_config_rankings - hosting_artifact_bucket, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable - hub_model_document.hosting_artifact_uri - ) - specs["hosting_artifact_key"] = hosting_artifact_key - specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri - hosting_script_bucket, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable - hub_model_document.hosting_script_uri - ) - specs["hosting_script_key"] = hosting_script_key + if hub_model_document.hosting_artifact_uri: + _, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_artifact_uri + ) + specs["hosting_artifact_key"] = hosting_artifact_key + specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri + + if hub_model_document.hosting_script_uri: + _, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_script_uri + ) + specs["hosting_script_key"] = hosting_script_key + 
specs["inference_environment_variables"] = hub_model_document.inference_environment_variables specs["inference_vulnerable"] = False specs["inference_dependencies"] = hub_model_document.inference_dependencies @@ -220,6 +226,8 @@ def make_model_specs_from_describe_hub_content_response( if hub_model_document.hosting_model_package_arn: specs["hosting_model_package_arns"] = {region: hub_model_document.hosting_model_package_arn} + specs["model_subscription_link"] = hub_model_document.model_subscription_link + specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri specs["hosting_instance_type_variants"] = hub_model_document.hosting_instance_type_variants diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index 535bc5e9be..77540926c6 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -14,7 +14,7 @@ """This module contains utilities related to SageMaker JumpStart Hub.""" from __future__ import absolute_import import re -from typing import Optional +from typing import Optional, List, Any from sagemaker.jumpstart.hub.types import S3ObjectLocation from sagemaker.s3_utils import parse_s3_url from sagemaker.session import Session @@ -23,6 +23,14 @@ from sagemaker.jumpstart import constants from packaging.specifiers import SpecifierSet, InvalidSpecifier +PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:" + + +def _convert_str_to_optional(string: str) -> Optional[str]: + if string == "None": + string = None + return string + def get_info_from_hub_resource_arn( arn: str, @@ -37,7 +45,7 @@ def get_info_from_hub_resource_arn( hub_name = match.group(4) hub_content_type = match.group(5) hub_content_name = match.group(6) - hub_content_version = match.group(7) + hub_content_version = _convert_str_to_optional(match.group(7)) return HubArnExtractedInfo( partition=partition, @@ -194,10 +202,14 @@ def get_hub_model_version( hub_model_version: Optional[str] = None, sagemaker_session: 
Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: - """Returns available Jumpstart hub model version + """Returns available Jumpstart hub model version. + + It will attempt both a semantic HubContent version search and Marketplace version search. + If the Marketplace version is also semantic, this function will default to HubContent version. Raises: ClientError: If the specified model is not found in the hub. + KeyError: If the specified model version is not found. """ try: @@ -207,6 +219,22 @@ def get_hub_model_version( except Exception as ex: raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") + try: + return _get_hub_model_version_for_open_weight_version( + hub_content_summaries, hub_model_version + ) + except KeyError: + marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version( + hub_content_summaries, hub_model_version + ) + if marketplace_hub_content_version: + return marketplace_hub_content_version + raise + + +def _get_hub_model_version_for_open_weight_version( + hub_content_summaries: List[Any], hub_model_version: Optional[str] = None +) -> str: available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries] if hub_model_version == "*" or hub_model_version is None: @@ -222,3 +250,37 @@ def get_hub_model_version( hub_model_version = str(max(available_versions_filtered)) return hub_model_version + + +def _get_hub_model_version_for_marketplace_version( + hub_content_summaries: List[Any], marketplace_version: str +) -> Optional[str]: + """Returns the HubContent version associated with the Marketplace version. + + This function will check within the HubContentSearchKeywords for the proprietary version. 
+ """ + for model in hub_content_summaries: + model_search_keywords = model.get("HubContentSearchKeywords", []) + if _hub_search_keywords_contains_marketplace_version( + model_search_keywords, marketplace_version + ): + return model.get("HubContentVersion") + + return None + + +def _hub_search_keywords_contains_marketplace_version( + model_search_keywords: List[str], marketplace_version: str +) -> bool: + proprietary_version_keyword = next( + filter(lambda s: s.startswith(PROPRIETARY_VERSION_KEYWORD), model_search_keywords), None + ) + + if not proprietary_version_keyword: + return False + + proprietary_version = proprietary_version_keyword[len(PROPRIETARY_VERSION_KEYWORD) :] + if proprietary_version == marketplace_version: + return True + + return False diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index f3313b3862..7e075e6b8a 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1200,6 +1200,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "url", "version", "min_sdk_version", + "model_types", + "capabilities", "incremental_training_supported", "hosting_ecr_specs", "hosting_ecr_uri", @@ -1287,6 +1289,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj.get("incremental_training_supported", False) ) if self._is_hub_content: + self.capabilities: Optional[List[str]] = json_obj.get("capabilities") + self.model_types: Optional[List[str]] = json_obj.get("model_types") self.hosting_ecr_uri: Optional[str] = json_obj.get("hosting_ecr_uri") self._non_serializable_slots.append("hosting_ecr_specs") else: diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 4d784c8275..b33d6563e5 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -856,7 +856,16 @@ def validate_model_id_and_get_type( if not isinstance(model_id, str): return None if hub_arn: - return None + model_types = 
_validate_hub_service_model_id_and_get_type( + model_id=model_id, + hub_arn=hub_arn, + region=region, + model_version=model_version, + sagemaker_session=sagemaker_session, + ) + return ( + model_types[0] if model_types else None + ) # Currently this function only supports one model type s3_client = sagemaker_session.s3_client if sagemaker_session else None region = region or constants.JUMPSTART_DEFAULT_REGION_NAME @@ -881,6 +890,37 @@ def validate_model_id_and_get_type( return None +def _validate_hub_service_model_id_and_get_type( + model_id: Optional[str], + hub_arn: str, + region: Optional[str] = None, + model_version: Optional[str] = None, + sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> List[enums.JumpStartModelType]: + """Returns a list of JumpStartModelType based off the HubContent. + + Only returns valid JumpStartModelType. Returns an empty array if none are found. + """ + hub_model_specs = accessors.JumpStartModelsAccessor.get_model_specs( + region=region, + model_id=model_id, + version=model_version, + hub_arn=hub_arn, + sagemaker_session=sagemaker_session, + ) + + hub_content_model_types = [] + model_types_field: Optional[List[str]] = getattr(hub_model_specs, "model_types", []) + model_types = model_types_field if model_types_field else [] + for model_type in model_types: + try: + hub_content_model_types.append(enums.JumpStartModelType[model_type]) + except KeyError: + continue + + return hub_content_model_types + + def _extract_value_from_list_of_tags( tag_keys: List[str], list_tags_result: List[str], diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index ec98786da4..c3dd9c96fb 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -53,6 +53,8 @@ "ap-southeast-2", } +TEST_HUB_WITH_REFERENCE = "mock-hub-name" + def 
test_non_prepacked_jumpstart_model(setup): diff --git a/tests/integ/sagemaker/jumpstart/utils.py b/tests/integ/sagemaker/jumpstart/utils.py index 47dc1f45d3..d439ef7e95 100644 --- a/tests/integ/sagemaker/jumpstart/utils.py +++ b/tests/integ/sagemaker/jumpstart/utils.py @@ -53,23 +53,18 @@ def get_sm_session() -> Session: return Session(boto_session=boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)) -# def get_sm_session_with_override() -> Session: -# # [TODO]: Remove service endpoint override before GA -# # boto3.set_stream_logger(name='botocore', level=logging.DEBUG) -# boto_session = boto3.Session(region_name="us-west-2") -# sagemaker = boto3.client( -# service_name="sagemaker-internal", -# endpoint_url="https://sagemaker.beta.us-west-2.ml-platform.aws.a2z.com", -# ) -# sagemaker_runtime = boto3.client( -# service_name="runtime.maeve", -# endpoint_url="https://maeveruntime.beta.us-west-2.ml-platform.aws.a2z.com", -# ) -# return Session( -# boto_session=boto_session, -# sagemaker_client=sagemaker, -# sagemaker_runtime_client=sagemaker_runtime, -# ) +def get_sm_session_with_override() -> Session: + # [TODO]: Remove service endpoint override before GA + # boto3.set_stream_logger(name='botocore', level=logging.DEBUG) + boto_session = boto3.Session(region_name="us-west-2") + sagemaker = boto3.client( + service_name="sagemaker", + endpoint_url="https://sagemaker.gamma.us-west-2.ml-platform.aws.a2z.com", + ) + return Session( + boto_session=boto_session, + sagemaker_client=sagemaker, + ) def get_training_dataset_for_model_and_version(model_id: str, version: str) -> dict: diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 9117b2d26d..d22428f4f0 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -9178,6 +9178,7 @@ "TrainingArtifactS3DataType": "S3Prefix", "TrainingArtifactCompressionType": "None", "TrainingArtifactUri": 
"s3://jumpstart-cache-prod-us-west-2/huggingface-training/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "ModelTypes": ["OPEN_WEIGHTS", "PROPRIETARY"], "Hyperparameters": [ { "Name": "peft_type", diff --git a/tests/unit/sagemaker/jumpstart/hub/test_marketplace_hub_content.py b/tests/unit/sagemaker/jumpstart/hub/test_marketplace_hub_content.py new file mode 100644 index 0000000000..49d97d177d --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_marketplace_hub_content.py @@ -0,0 +1,132 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import +import pytest +from unittest.mock import patch, MagicMock +from mock import Mock +from sagemaker.jumpstart.hub import utils as hub_utils +from sagemaker.jumpstart.enums import JumpStartModelType +from sagemaker.jumpstart.utils import _validate_hub_service_model_id_and_get_type + +REGION = "us-east-1" +ACCOUNT_ID = "123456789123" +HUB_NAME = "mock-hub-name" + +MOCK_MODEL_ID = "test-model-id" + + +@pytest.fixture() +def sagemaker_session(): + boto_mock = Mock(name="boto_session") + sagemaker_session_mock = Mock( + name="sagemaker_session", boto_session=boto_mock, boto_region_name=REGION + ) + sagemaker_session_mock._client_config.user_agent = ( + "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" + ) + sagemaker_session_mock.account_id.return_value = ACCOUNT_ID + return sagemaker_session_mock + + +@pytest.mark.parametrize( + "input_version, expected_version, expected_exception, expected_message", + [ + ("1.0.0", "1.0.0", None, None), + ("*", "3.2.0", None, None), + (None, "3.2.0", None, None), + ("1.*", "1.1.0", None, None), + ("240612.4", "2.0.0", None, None), + ("3.0.0", "3.0.0", None, None), + ("4.0.0", "3.2.0", None, None), + ("5.0.0", None, KeyError, "Model version not available in the Hub"), + ("Blah", None, KeyError, "Bad semantic version"), + ], +) +def test_proprietary_model( + input_version, expected_version, expected_exception, expected_message, sagemaker_session +): + sagemaker_session.list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0.0", "HubContentSearchKeywords": []}, + {"HubContentVersion": "1.1.0", "HubContentSearchKeywords": []}, + { + "HubContentVersion": "2.0.0", + "HubContentSearchKeywords": ["@marketplace-version:240612.4"], + }, + { + "HubContentVersion": "3.0.0", + "HubContentSearchKeywords": ["@marketplace-version:240612.5"], + }, + { + "HubContentVersion": "3.1.0", + "HubContentSearchKeywords": 
["@marketplace-version:3.0.0"], + }, + { + "HubContentVersion": "3.2.0", + "HubContentSearchKeywords": ["@marketplace-version:4.0.0"], + }, + ] + } + + if expected_exception: + with pytest.raises(expected_exception, match=expected_message): + _test_proprietary_model(input_version, expected_version, sagemaker_session) + else: + _test_proprietary_model(input_version, expected_version, sagemaker_session) + + +def _test_proprietary_model(input_version, expected_version, sagemaker_session): + result = hub_utils.get_hub_model_version( + hub_model_name=MOCK_MODEL_ID, + hub_model_type="Model", + hub_name="blah", + sagemaker_session=sagemaker_session, + hub_model_version=input_version, + ) + + assert result == expected_version + + +@pytest.mark.parametrize( + "get_model_specs_attr, get_model_specs_response, expected, expected_exception, expected_message", + [ + (False, None, [], None, None), + (True, None, [], None, None), + (True, [], [], None, None), + (True, ["OPEN_WEIGHTS"], [JumpStartModelType.OPEN_WEIGHTS], None, None), + ( + True, + ["OPEN_WEIGHTS", "PROPRIETARY"], + [JumpStartModelType.OPEN_WEIGHTS, JumpStartModelType.PROPRIETARY], + None, + None, + ), + ], +) +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_validate_hub_service_model_id_and_get_type( + mock_get_model_specs, + get_model_specs_attr, + get_model_specs_response, + expected, + expected_exception, + expected_message, +): + mock_object = MagicMock() + if get_model_specs_attr: + mock_object.model_types = get_model_specs_response + mock_get_model_specs.return_value = mock_object + + result = _validate_hub_service_model_id_and_get_type(model_id="blah", hub_arn="blah") + + assert result == expected diff --git a/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py new file mode 100644 index 0000000000..4412ad467e --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py @@ -0,0 +1,52 @@ +# 
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import pytest +from sagemaker.jumpstart.hub.parser_utils import camel_to_snake +from sagemaker.jumpstart.hub.parsers import make_model_specs_from_describe_hub_content_response +from sagemaker.jumpstart.hub.interfaces import HubModelDocument +from tests.unit.sagemaker.jumpstart.constants import HUB_MODEL_DOCUMENT_DICTS +from unittest.mock import MagicMock +from sagemaker.jumpstart.types import HubContentType + + +REGION = "us-east-1" +ACCOUNT_ID = "123456789123" +HUB_NAME = "mock-hub-name" + + +@pytest.mark.parametrize( + "input_string, expected", + [ + ("PascalCase", "pascal_case"), + ("already_snake", "already_snake"), + ("", ""), + ("A", "a"), + ("PascalCase123", "pascal_case123"), + ("123StartWithNumber", "123_start_with_number"), + ], +) +def test_parse_(input_string, expected): + assert expected == camel_to_snake(input_string) + + +def test_make_model_specs_from_describe_hub_content_response(): + mock_describe_response = MagicMock() + region = "us-west-2" + mock_describe_response.hub_content_type = HubContentType.MODEL + mock_describe_response.get_hub_region.return_value = region + mock_describe_response.hub_content_version = "1.0.0" + json_obj = HUB_MODEL_DOCUMENT_DICTS["huggingface-llm-gemma-2b-instruct"] + mock_describe_response.hub_content_document = HubModelDocument(json_obj=json_obj, region=region) + + 
make_model_specs_from_describe_hub_content_response(mock_describe_response)