From 0092ff4fa9e5efac6e15af05cba9902068a5e500 Mon Sep 17 00:00:00 2001 From: chrstfu Date: Wed, 30 Oct 2024 21:31:12 +0000 Subject: [PATCH 1/8] feat: Marketplace model support in HubService --- src/sagemaker/jumpstart/factory/model.py | 2 +- src/sagemaker/jumpstart/hub/interfaces.py | 29 ++-- src/sagemaker/jumpstart/hub/parser_utils.py | 13 +- src/sagemaker/jumpstart/hub/parsers.py | 32 +++-- src/sagemaker/jumpstart/hub/utils.py | 69 +++++++++- src/sagemaker/jumpstart/types.py | 30 +++-- src/sagemaker/jumpstart/utils.py | 42 +++++- .../jumpstart/model/test_jumpstart_model.py | 2 + tests/integ/sagemaker/jumpstart/utils.py | 29 ++-- .../hub/test_marketplace_hub_content.py | 125 ++++++++++++++++++ .../jumpstart/hub/test_parser_utils.py | 34 +++++ .../sagemaker/jumpstart/hub/test_utils.py | 4 +- 12 files changed, 343 insertions(+), 68 deletions(-) create mode 100644 tests/unit/sagemaker/jumpstart/hub/test_marketplace_hub_content.py create mode 100644 tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index ccafed844d..512a06cf64 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -41,7 +41,7 @@ from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines -from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType, HubContentCapability from sagemaker.jumpstart.types import ( HubContentType, JumpStartModelDeployKwargs, diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py index d987216872..69a468a0b4 100644 --- a/src/sagemaker/jumpstart/hub/interfaces.py +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -451,12 +451,14 @@ def from_json(self, json_obj: str) -> None: class HubModelDocument(HubDataHolderType): """Data class for model type HubContentDocument from session.describe_hub_content().""" - SCHEMA_VERSION = "2.2.0" + SCHEMA_VERSION = "2.3.0" __slots__ = [ "url", "min_sdk_version", "training_supported", + "model_types", + "capabilities", "incremental_training_supported", "dynamic_container_deployment_supported", "hosting_ecr_uri", @@ -469,6 +471,7 @@ class HubModelDocument(HubDataHolderType): "hosting_use_script_uri", "hosting_eula_uri", "hosting_model_package_arn", + "model_subscription_link", "inference_configs", "inference_config_components", "inference_config_rankings", @@ -549,18 +552,22 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Args: json_obj (Dict[str, Any]): Dictionary representation of hub model document. """ - self.url: str = json_obj["Url"] - self.min_sdk_version: str = json_obj["MinSdkVersion"] - self.hosting_ecr_uri: Optional[str] = json_obj["HostingEcrUri"] - self.hosting_artifact_uri = json_obj["HostingArtifactUri"] - self.hosting_script_uri = json_obj["HostingScriptUri"] - self.inference_dependencies: List[str] = json_obj["InferenceDependencies"] + self.url: str = json_obj.get("Url") + self.min_sdk_version: str = json_obj.get("MinSdkVersion") + self.hosting_ecr_uri: Optional[str] = json_obj.get("HostingEcrUri") + self.hosting_artifact_uri = json_obj.get("HostingArtifactUri") + self.hosting_script_uri = json_obj.get("HostingScriptUri") + self.inference_dependencies: List[str] = json_obj.get("InferenceDependencies") self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [ JumpStartEnvironmentVariable(env_variable, is_hub_content=True) - for env_variable in json_obj["InferenceEnvironmentVariables"] + for env_variable in json_obj.get("InferenceEnvironmentVariables", []) ] - self.training_supported: bool = bool(json_obj["TrainingSupported"]) - self.incremental_training_supported: bool = bool(json_obj["IncrementalTrainingSupported"]) + self.model_types: Optional[List[str]] = json_obj.get("ModelTypes") + self.capabilities: Optional[List[str]] = json_obj.get("Capabilities") + self.training_supported: bool = bool(json_obj.get("TrainingSupported")) + self.incremental_training_supported: bool = bool( + json_obj.get("IncrementalTrainingSupported") + ) self.dynamic_container_deployment_supported: Optional[bool] = ( bool(json_obj.get("DynamicContainerDeploymentSupported")) if json_obj.get("DynamicContainerDeploymentSupported") @@ -586,6 +593,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri") self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn") + self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink") + self.inference_config_rankings = self._get_config_rankings(json_obj) self.inference_config_components = self._get_config_components(json_obj) self.inference_configs = self._get_configs(json_obj) diff --git a/src/sagemaker/jumpstart/hub/parser_utils.py b/src/sagemaker/jumpstart/hub/parser_utils.py index 54147dd8e6..53a1f54b0a 100644 --- a/src/sagemaker/jumpstart/hub/parser_utils.py +++ b/src/sagemaker/jumpstart/hub/parser_utils.py @@ -18,13 +18,12 @@ from typing import Any, Dict, List, Optional -def camel_to_snake(camel_case_string: str) -> str: - """Converts camelCaseString or UpperCamelCaseString to snake_case_string.""" - snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string) - if "-" in snake_case_string: - # remove any hyphen from the string for accurate conversion. - snake_case_string = snake_case_string.replace("-", "") - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower() +def pascal_to_snake(camel_case_string: str) -> str: + """Converts PascalCase to snake_case_string using a regex. + + This regex cannot handle whitespace ("PascalString TwoWords") + """ + return re.sub(r"(? str: diff --git a/src/sagemaker/jumpstart/hub/parsers.py b/src/sagemaker/jumpstart/hub/parsers.py index 28c2d9b32d..d36ca1270d 100644 --- a/src/sagemaker/jumpstart/hub/parsers.py +++ b/src/sagemaker/jumpstart/hub/parsers.py @@ -27,7 +27,7 @@ HubModelDocument, ) from sagemaker.jumpstart.hub.parser_utils import ( - camel_to_snake, + pascal_to_snake, snake_to_upper_camel, walk_and_apply_json, ) @@ -86,7 +86,7 @@ def get_model_spec_arg_keys( arg_keys = [] if naming_convention == NamingConventionType.SNAKE_CASE: - arg_keys = [camel_to_snake(key) for key in arg_keys] + arg_keys = [pascal_to_snake(key) for key in arg_keys] elif naming_convention == NamingConventionType.UPPER_CAMEL_CASE: return arg_keys else: @@ -137,6 +137,8 @@ def make_model_specs_from_describe_hub_content_response( hub_model_document: HubModelDocument = response.hub_content_document specs["url"] = hub_model_document.url specs["min_sdk_version"] = hub_model_document.min_sdk_version + specs["model_types"] = hub_model_document.model_types + specs["capabilities"] = hub_model_document.capabilities specs["training_supported"] = bool(hub_model_document.training_supported) specs["incremental_training_supported"] = bool( hub_model_document.incremental_training_supported @@ -146,15 +148,19 @@ def make_model_specs_from_describe_hub_content_response( specs["inference_config_components"] = hub_model_document.inference_config_components specs["inference_config_rankings"] = hub_model_document.inference_config_rankings - hosting_artifact_bucket, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable - hub_model_document.hosting_artifact_uri - ) - specs["hosting_artifact_key"] = hosting_artifact_key - specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri - hosting_script_bucket, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable - hub_model_document.hosting_script_uri - ) - specs["hosting_script_key"] = hosting_script_key + if hub_model_document.hosting_artifact_uri: + _, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_artifact_uri + ) + specs["hosting_artifact_key"] = hosting_artifact_key + specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri + + if hub_model_document.hosting_script_uri: + _, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_script_uri + ) + specs["hosting_script_key"] = hosting_script_key + specs["inference_environment_variables"] = hub_model_document.inference_environment_variables specs["inference_vulnerable"] = False specs["inference_dependencies"] = hub_model_document.inference_dependencies @@ -201,7 +207,7 @@ def make_model_specs_from_describe_hub_content_response( default_payloads: Dict[str, Any] = {} if hub_model_document.default_payloads is not None: for alias, payload in hub_model_document.default_payloads.items(): - default_payloads[alias] = walk_and_apply_json(payload.to_json(), camel_to_snake) + default_payloads[alias] = walk_and_apply_json(payload.to_json(), pascal_to_snake) specs["default_payloads"] = default_payloads specs["gated_bucket"] = hub_model_document.gated_bucket specs["inference_volume_size"] = hub_model_document.inference_volume_size @@ -220,6 +226,8 @@ def make_model_specs_from_describe_hub_content_response( if hub_model_document.hosting_model_package_arn: specs["hosting_model_package_arns"] = {region: hub_model_document.hosting_model_package_arn} + specs["model_subscription_link"] = hub_model_document.model_subscription_link + specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri specs["hosting_instance_type_variants"] = hub_model_document.hosting_instance_type_variants diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index 535bc5e9be..2acbad5135 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -14,7 +14,7 @@ """This module contains utilities related to SageMaker JumpStart Hub.""" from __future__ import absolute_import import re -from typing import Optional +from typing import Optional, List, Any from sagemaker.jumpstart.hub.types import S3ObjectLocation from sagemaker.s3_utils import parse_s3_url from sagemaker.session import Session @@ -23,6 +23,14 @@ from sagemaker.jumpstart import constants from packaging.specifiers import SpecifierSet, InvalidSpecifier +PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:" + + +def _convert_str_to_optional(string: str) -> Optional[str]: + if string == "None": + string = None + return string + def get_info_from_hub_resource_arn( arn: str, @@ -37,7 +45,7 @@ def get_info_from_hub_resource_arn( hub_name = match.group(4) hub_content_type = match.group(5) hub_content_name = match.group(6) - hub_content_version = match.group(7) + hub_content_version = _convert_str_to_optional(match.group(7)) return HubArnExtractedInfo( partition=partition, @@ -194,10 +202,14 @@ def get_hub_model_version( hub_model_version: Optional[str] = None, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: - """Returns available Jumpstart hub model version + """Returns available Jumpstart hub model version. + + It will attempt both a semantic HubContent version search and Marketplace version search. + If the Marketplace version is also semantic, this function will default to HubContent version. Raises: ClientError: If the specified model is not found in the hub. + KeyError: If the specified model version is not found. """ try: @@ -207,6 +219,23 @@ def get_hub_model_version( except Exception as ex: raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") + marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version( + hub_content_summaries, hub_model_version + ) + + try: + return _get_hub_model_version_for_open_weight_version( + hub_content_summaries, hub_model_version + ) + except KeyError as e: + if marketplace_hub_content_version: + return marketplace_hub_content_version + raise e + + +def _get_hub_model_version_for_open_weight_version( + hub_content_summaries: List[Any], hub_model_version: Optional[str] = None +) -> str: available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries] if hub_model_version == "*" or hub_model_version is None: @@ -222,3 +251,37 @@ def get_hub_model_version( hub_model_version = str(max(available_versions_filtered)) return hub_model_version + + +def _get_hub_model_version_for_marketplace_version( + hub_content_summaries: List[Any], marketplace_version: str +) -> Optional[str]: + """Returns the HubContent version associated with the Marketplace version. + + This function will check within the HubContentSearchKeywords for the proprietary version. + """ + for model in hub_content_summaries: + model_search_keywords = model.get("HubContentSearchKeywords", []) + if _hub_search_keywords_contains_marketplace_version( + model_search_keywords, marketplace_version + ): + return model.get("HubContentVersion") + + return None + + +def _hub_search_keywords_contains_marketplace_version( + model_search_keywords: List[str], marketplace_version: str +) -> bool: + proprietary_version_keyword = next( + filter(lambda s: s.startswith(PROPRIETARY_VERSION_KEYWORD), model_search_keywords), None + ) + + if not proprietary_version_keyword: + return False + + proprietary_version = proprietary_version_keyword.lstrip(PROPRIETARY_VERSION_KEYWORD) + if proprietary_version == marketplace_version: + return True + + return False diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index f3313b3862..b716cac057 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -39,7 +39,7 @@ from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.enums import EndpointType from sagemaker.jumpstart.hub.parser_utils import ( - camel_to_snake, + pascal_to_snake, walk_and_apply_json, ) @@ -239,7 +239,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: return if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.framework = json_obj.get("framework") self.framework_version = json_obj.get("framework_version") @@ -293,7 +293,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: """ if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -361,7 +361,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Args: json_obj (Dict[str, Any]): Dictionary representation of environment variable. """ - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -411,7 +411,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: return if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.default_content_type = json_obj["default_content_type"] self.supported_content_types = json_obj["supported_content_types"] self.default_accept_type = json_obj["default_accept_type"] @@ -465,7 +465,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: return if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.raw_payload = json_obj self.content_type = json_obj["content_type"] self.body = json_obj.get("body") @@ -538,7 +538,7 @@ def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]]) if response is None: return - response = walk_and_apply_json(response, camel_to_snake) + response = walk_and_apply_json(response, pascal_to_snake) self.aliases: Optional[dict] = response.get("aliases") self.regional_aliases = None self.variants: Optional[dict] = response.get("variants") @@ -1174,7 +1174,7 @@ def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False): spec (Dict[str, Any]): Dictionary representation of training config ranking. """ if is_hub_content: - spec = walk_and_apply_json(spec, camel_to_snake) + spec = walk_and_apply_json(spec, pascal_to_snake) self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -1200,6 +1200,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "url", "version", "min_sdk_version", + "model_types", + "capabilities", "incremental_training_supported", "hosting_ecr_specs", "hosting_ecr_uri", @@ -1278,7 +1280,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of spec. """ if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.model_id: str = json_obj.get("model_id") self.url: str = json_obj.get("url") self.version: str = json_obj.get("version") @@ -1287,6 +1289,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj.get("incremental_training_supported", False) ) if self._is_hub_content: + self.capabilities: Optional[List[str]] = json_obj.get("capabilities") + self.model_types: Optional[List[str]] = json_obj.get("model_types") self.hosting_ecr_uri: Optional[str] = json_obj.get("hosting_ecr_uri") self._non_serializable_slots.append("hosting_ecr_specs") else: @@ -1505,7 +1509,7 @@ def __init__( ValueError: If the component field is invalid. """ if is_hub_content: - component = walk_and_apply_json(component, camel_to_snake) + component = walk_and_apply_json(component, pascal_to_snake) self.component_name = component_name super().__init__(component, is_hub_content) self.from_json(component) @@ -1558,8 +1562,8 @@ def __init__( The list of components that are used to construct the resolved config. """ if is_hub_content: - config = walk_and_apply_json(config, camel_to_snake) - base_fields = walk_and_apply_json(base_fields, camel_to_snake) + config = walk_and_apply_json(config, pascal_to_snake) + base_fields = walk_and_apply_json(base_fields, pascal_to_snake) self.base_fields = base_fields self.config_components: Dict[str, JumpStartConfigComponent] = config_components self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = ( @@ -1725,7 +1729,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: """ super().from_json(json_obj) if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.inference_config_components: Optional[Dict[str, JumpStartConfigComponent]] = ( { component_name: JumpStartConfigComponent(component_name, component) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 4d784c8275..245c31fb4c 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -33,7 +33,7 @@ from sagemaker.jumpstart import constants, enums from sagemaker.jumpstart import accessors -from sagemaker.jumpstart.hub.parser_utils import camel_to_snake, snake_to_upper_camel +from sagemaker.jumpstart.hub.parser_utils import pascal_to_snake, snake_to_upper_camel from sagemaker.s3 import parse_s3_url from sagemaker.jumpstart.exceptions import ( DeprecatedJumpStartModelError, @@ -856,7 +856,14 @@ def validate_model_id_and_get_type( if not isinstance(model_id, str): return None if hub_arn: - return None + model_types = _validate_hub_service_model_id_and_get_type( + model_id=model_id, + hub_arn=hub_arn, + region=region, + model_version=model_version, + sagemaker_session=sagemaker_session, + ) + return model_types[0] # Currently this function only supports one model type s3_client = sagemaker_session.s3_client if sagemaker_session else None region = region or constants.JUMPSTART_DEFAULT_REGION_NAME @@ -881,6 +888,35 @@ def validate_model_id_and_get_type( return None +def _validate_hub_service_model_id_and_get_type( + model_id: Optional[str], + hub_arn: str, + region: Optional[str] = None, + model_version: Optional[str] = None, + sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> List[enums.JumpStartModelType]: + """Returns a list of JumpStartModelType based off the HubContent. + + Only returns valid JumpStartModelType. Returns an empty array if none are found. + """ + hub_model_specs = accessors.JumpStartModelsAccessor.get_model_specs( + region=region, + model_id=model_id, + version=model_version, + hub_arn=hub_arn, + sagemaker_session=sagemaker_session, + ) + + hub_content_model_types = [] + for model_type in getattr(hub_model_specs, "model_types", []): + try: + hub_content_model_types.append(enums.JumpStartModelType[model_type]) + except ValueError: + continue + + return hub_content_model_types + + def _extract_value_from_list_of_tags( tag_keys: List[str], list_tags_result: List[str], @@ -1113,7 +1149,7 @@ def get_jumpstart_configs( return ( { config_name: metadata_configs.configs[ - camel_to_snake(snake_to_upper_camel(config_name)) + pascal_to_snake(snake_to_upper_camel(config_name)) ] for config_name in config_names } diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index ec98786da4..c3dd9c96fb 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -53,6 +53,8 @@ "ap-southeast-2", } +TEST_HUB_WITH_REFERENCE = "mock-hub-name" + def test_non_prepacked_jumpstart_model(setup): diff --git a/tests/integ/sagemaker/jumpstart/utils.py b/tests/integ/sagemaker/jumpstart/utils.py index 47dc1f45d3..d439ef7e95 100644 --- a/tests/integ/sagemaker/jumpstart/utils.py +++ b/tests/integ/sagemaker/jumpstart/utils.py @@ -53,23 +53,18 @@ def get_sm_session() -> Session: return Session(boto_session=boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)) -# def get_sm_session_with_override() -> Session: -# # [TODO]: Remove service endpoint override before GA -# # boto3.set_stream_logger(name='botocore', level=logging.DEBUG) -# boto_session = boto3.Session(region_name="us-west-2") -# sagemaker = boto3.client( -# service_name="sagemaker-internal", -# endpoint_url="https://sagemaker.beta.us-west-2.ml-platform.aws.a2z.com", -# ) -# sagemaker_runtime = boto3.client( -# service_name="runtime.maeve", -# endpoint_url="https://maeveruntime.beta.us-west-2.ml-platform.aws.a2z.com", -# ) -# return Session( -# boto_session=boto_session, -# sagemaker_client=sagemaker, -# sagemaker_runtime_client=sagemaker_runtime, -# ) +def get_sm_session_with_override() -> Session: + # [TODO]: Remove service endpoint override before GA + # boto3.set_stream_logger(name='botocore', level=logging.DEBUG) + boto_session = boto3.Session(region_name="us-west-2") + sagemaker = boto3.client( + service_name="sagemaker", + endpoint_url="https://sagemaker.gamma.us-west-2.ml-platform.aws.a2z.com", + ) + return Session( + boto_session=boto_session, + sagemaker_client=sagemaker, + ) def get_training_dataset_for_model_and_version(model_id: str, version: str) -> dict: diff --git a/tests/unit/sagemaker/jumpstart/hub/test_marketplace_hub_content.py b/tests/unit/sagemaker/jumpstart/hub/test_marketplace_hub_content.py new file mode 100644 index 0000000000..50bd08a62b --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_marketplace_hub_content.py @@ -0,0 +1,125 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import pytest +from unittest.mock import patch, MagicMock +from mock import Mock +from sagemaker.jumpstart.hub import utils as hub_utils +from sagemaker.jumpstart.enums import JumpStartModelType +from sagemaker.jumpstart.utils import _validate_hub_service_model_id_and_get_type + +REGION = "us-east-1" +ACCOUNT_ID = "123456789123" +HUB_NAME = "mock-hub-name" + +MOCK_MODEL_ID = "test-model-id" + + +@pytest.fixture() +def sagemaker_session(): + boto_mock = Mock(name="boto_session") + sagemaker_session_mock = Mock( + name="sagemaker_session", boto_session=boto_mock, boto_region_name=REGION + ) + sagemaker_session_mock._client_config.user_agent = ( + "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" + ) + sagemaker_session_mock.account_id.return_value = ACCOUNT_ID + return sagemaker_session_mock + + +@pytest.mark.parametrize( + "input_version, expected_version, expected_exception, expected_message", + [ + ("1.0.0", "1.0.0", None, None), + ("*", "3.2.0", None, None), + (None, "3.2.0", None, None), + ("1.*", "1.1.0", None, None), + ("240612.4", "2.0.0", None, None), + ("3.0.0", "3.0.0", None, None), + ("4.0.0", "3.2.0", None, None), + ("5.0.0", None, KeyError, "Model version not available in the Hub"), + ("Blah", None, KeyError, "Bad semantic version"), + ], +) +def test_proprietary_model( + input_version, expected_version, expected_exception, expected_message, sagemaker_session +): + sagemaker_session.list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0.0", "HubContentSearchKeywords": []}, + {"HubContentVersion": "1.1.0", "HubContentSearchKeywords": []}, + { + "HubContentVersion": "2.0.0", + "HubContentSearchKeywords": ["@marketplace-version:240612.4"], + }, + { + "HubContentVersion": "3.0.0", + "HubContentSearchKeywords": ["@marketplace-version:240612.5"], + }, + { + "HubContentVersion": "3.1.0", + "HubContentSearchKeywords": ["@marketplace-version:3.0.0"], + }, + { + "HubContentVersion": "3.2.0", + "HubContentSearchKeywords": ["@marketplace-version:4.0.0"], + }, + ] + } + + if expected_exception: + with pytest.raises(expected_exception, match=expected_message): + _test_proprietary_model(input_version, expected_version, sagemaker_session) + else: + _test_proprietary_model(input_version, expected_version, sagemaker_session) + + +def _test_proprietary_model(input_version, expected_version, sagemaker_session): + result = hub_utils.get_hub_model_version( + hub_model_name=MOCK_MODEL_ID, + hub_model_type="Model", + hub_name="blah", + sagemaker_session=sagemaker_session, + hub_model_version=input_version, + ) + + assert result == expected_version + + +@pytest.mark.parametrize( + "get_model_specs_response, expected, expected_exception, expected_message", + [ + (None, [], None, None), + ([], [], None, None), + (["OPEN_WEIGHTS"], [JumpStartModelType.OPEN_WEIGHTS], None, None), + ( + ["OPEN_WEIGHTS", "PROPRIETARY"], + [JumpStartModelType.OPEN_WEIGHTS, JumpStartModelType.PROPRIETARY], + None, + None, + ), + ], +) +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_validate_hub_service_model_id_and_get_type( + mock_get_model_specs, get_model_specs_response, expected, expected_exception, expected_message +): + mock_object = MagicMock() + if get_model_specs_response: + mock_object.model_types = get_model_specs_response + mock_get_model_specs.return_value = mock_object + + result = _validate_hub_service_model_id_and_get_type(model_id="blah", hub_arn="blah") + + assert result == expected diff --git a/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py new file mode 100644 index 0000000000..a4f9486f6d --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py @@ -0,0 +1,34 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import pytest +from sagemaker.jumpstart.hub.parser_utils import pascal_to_snake + +REGION = "us-east-1" +ACCOUNT_ID = "123456789123" +HUB_NAME = "mock-hub-name" + + +@pytest.mark.parametrize( + "input_string, expected", + [ + ("PascalCase", "pascal_case"), + ("already_snake", "already_snake"), + ("", ""), + ("A", "a"), + ("PascalCase123", "pascal_case123"), + ("123StartWithNumber", "123_start_with_number"), + ], +) +def test_parse_(input_string, expected): + assert expected == pascal_to_snake(input_string) diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py index 22bc527b18..0ef740398c 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -268,7 +268,7 @@ def test_walk_and_apply_json(): } result = parser_utils.walk_and_apply_json( - test_json, parser_utils.camel_to_snake, ["ignore_my_children"] + test_json, parser_utils.pascal_to_snake, ["ignore_my_children"] ) assert result == { "camel_case_key": "value", @@ -294,7 +294,7 @@ def test_walk_and_apply_json_no_stop(): "CamelCaseObjectListKey": {"instance.ml.type.xlarge": [{"ShouldChangeMe": "string"}]}, } - result = parser_utils.walk_and_apply_json(test_json, parser_utils.camel_to_snake) + result = parser_utils.walk_and_apply_json(test_json, parser_utils.pascal_to_snake) assert result == { "camel_case_key": "value", "camel_case_object_key": { From 8b0ec90f9bb2fa4715e8e973d33b67d4d1178a56 Mon Sep 17 00:00:00 2001 From: chrstfu Date: Wed, 30 Oct 2024 21:46:50 +0000 Subject: [PATCH 2/8] fix: removing field --- src/sagemaker/jumpstart/factory/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 512a06cf64..ccafed844d 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -41,7 +41,7 @@ from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines -from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType, HubContentCapability +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.types import ( HubContentType, JumpStartModelDeployKwargs, From 273449c7b016f3140037ad11ff8189ab16567680 Mon Sep 17 00:00:00 2001 From: chrstfu Date: Wed, 30 Oct 2024 22:19:43 +0000 Subject: [PATCH 3/8] fix: Reverting name change for code coverage --- src/sagemaker/jumpstart/hub/parser_utils.py | 2 +- src/sagemaker/jumpstart/hub/parsers.py | 6 ++--- src/sagemaker/jumpstart/types.py | 26 +++++++++---------- src/sagemaker/jumpstart/utils.py | 4 +-- .../jumpstart/hub/test_parser_utils.py | 4 +-- .../sagemaker/jumpstart/hub/test_utils.py | 4 +-- 6 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/parser_utils.py b/src/sagemaker/jumpstart/hub/parser_utils.py index 53a1f54b0a..a720b02a17 100644 --- a/src/sagemaker/jumpstart/hub/parser_utils.py +++ b/src/sagemaker/jumpstart/hub/parser_utils.py @@ -18,7 +18,7 @@ from typing import Any, Dict, List, Optional -def pascal_to_snake(camel_case_string: str) -> str: +def camel_to_snake(camel_case_string: str) -> str: """Converts PascalCase to snake_case_string using a regex. This regex cannot handle whitespace ("PascalString TwoWords") diff --git a/src/sagemaker/jumpstart/hub/parsers.py b/src/sagemaker/jumpstart/hub/parsers.py index d36ca1270d..51da974217 100644 --- a/src/sagemaker/jumpstart/hub/parsers.py +++ b/src/sagemaker/jumpstart/hub/parsers.py @@ -27,7 +27,7 @@ HubModelDocument, ) from sagemaker.jumpstart.hub.parser_utils import ( - pascal_to_snake, + camel_to_snake, snake_to_upper_camel, walk_and_apply_json, ) @@ -86,7 +86,7 @@ def get_model_spec_arg_keys( arg_keys = [] if naming_convention == NamingConventionType.SNAKE_CASE: - arg_keys = [pascal_to_snake(key) for key in arg_keys] + arg_keys = [camel_to_snake(key) for key in arg_keys] elif naming_convention == NamingConventionType.UPPER_CAMEL_CASE: return arg_keys else: @@ -207,7 +207,7 @@ def make_model_specs_from_describe_hub_content_response( default_payloads: Dict[str, Any] = {} if hub_model_document.default_payloads is not None: for alias, payload in hub_model_document.default_payloads.items(): - default_payloads[alias] = walk_and_apply_json(payload.to_json(), pascal_to_snake) + default_payloads[alias] = walk_and_apply_json(payload.to_json(), camel_to_snake) specs["default_payloads"] = default_payloads specs["gated_bucket"] = hub_model_document.gated_bucket specs["inference_volume_size"] = hub_model_document.inference_volume_size diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index b716cac057..7e075e6b8a 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -39,7 +39,7 @@ from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.enums import EndpointType from sagemaker.jumpstart.hub.parser_utils import ( - pascal_to_snake, + camel_to_snake, walk_and_apply_json, ) @@ -239,7 +239,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: return if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, pascal_to_snake) + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.framework = json_obj.get("framework") self.framework_version = json_obj.get("framework_version") @@ -293,7 +293,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: """ if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, pascal_to_snake) + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -361,7 +361,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Args: json_obj (Dict[str, Any]): Dictionary representation of environment variable. """ - json_obj = walk_and_apply_json(json_obj, pascal_to_snake) + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -411,7 +411,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: return if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, pascal_to_snake) + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.default_content_type = json_obj["default_content_type"] self.supported_content_types = json_obj["supported_content_types"] self.default_accept_type = json_obj["default_accept_type"] @@ -465,7 +465,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: return if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, pascal_to_snake) + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.raw_payload = json_obj self.content_type = json_obj["content_type"] self.body = json_obj.get("body") @@ -538,7 +538,7 @@ def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]]) if response is None: return - response = walk_and_apply_json(response, pascal_to_snake) + response = walk_and_apply_json(response, camel_to_snake) self.aliases: Optional[dict] = response.get("aliases") self.regional_aliases = None self.variants: Optional[dict] = response.get("variants") @@ -1174,7 +1174,7 @@ def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False): spec (Dict[str, Any]): Dictionary representation of training config ranking. """ if is_hub_content: - spec = walk_and_apply_json(spec, pascal_to_snake) + spec = walk_and_apply_json(spec, camel_to_snake) self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -1280,7 +1280,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of spec. """ if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, pascal_to_snake) + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.model_id: str = json_obj.get("model_id") self.url: str = json_obj.get("url") self.version: str = json_obj.get("version") @@ -1509,7 +1509,7 @@ def __init__( ValueError: If the component field is invalid. """ if is_hub_content: - component = walk_and_apply_json(component, pascal_to_snake) + component = walk_and_apply_json(component, camel_to_snake) self.component_name = component_name super().__init__(component, is_hub_content) self.from_json(component) @@ -1562,8 +1562,8 @@ def __init__( The list of components that are used to construct the resolved config. """ if is_hub_content: - config = walk_and_apply_json(config, pascal_to_snake) - base_fields = walk_and_apply_json(base_fields, pascal_to_snake) + config = walk_and_apply_json(config, camel_to_snake) + base_fields = walk_and_apply_json(base_fields, camel_to_snake) self.base_fields = base_fields self.config_components: Dict[str, JumpStartConfigComponent] = config_components self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = ( @@ -1729,7 +1729,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: """ super().from_json(json_obj) if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, pascal_to_snake) + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.inference_config_components: Optional[Dict[str, JumpStartConfigComponent]] = ( { component_name: JumpStartConfigComponent(component_name, component) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 245c31fb4c..da6d1445da 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -33,7 +33,7 @@ from sagemaker.jumpstart import constants, enums from sagemaker.jumpstart import accessors -from sagemaker.jumpstart.hub.parser_utils import pascal_to_snake, snake_to_upper_camel +from sagemaker.jumpstart.hub.parser_utils import camel_to_snake, snake_to_upper_camel from sagemaker.s3 import parse_s3_url from sagemaker.jumpstart.exceptions import ( DeprecatedJumpStartModelError, @@ -1149,7 +1149,7 @@ def get_jumpstart_configs( return ( { config_name: metadata_configs.configs[ - pascal_to_snake(snake_to_upper_camel(config_name)) + camel_to_snake(snake_to_upper_camel(config_name)) ] for config_name in config_names } diff --git a/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py index a4f9486f6d..4d840584e1 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import import pytest -from sagemaker.jumpstart.hub.parser_utils import pascal_to_snake +from sagemaker.jumpstart.hub.parser_utils import camel_to_snake REGION = "us-east-1" ACCOUNT_ID = "123456789123" @@ -31,4 +31,4 @@ ], ) def test_parse_(input_string, expected): - assert expected == pascal_to_snake(input_string) + assert expected == camel_to_snake(input_string) diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py index 0ef740398c..22bc527b18 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -268,7 +268,7 @@ def test_walk_and_apply_json(): } result = parser_utils.walk_and_apply_json( - test_json, parser_utils.pascal_to_snake, ["ignore_my_children"] + test_json, parser_utils.camel_to_snake, ["ignore_my_children"] ) assert result == { "camel_case_key": "value", @@ -294,7 +294,7 @@ def test_walk_and_apply_json_no_stop(): "CamelCaseObjectListKey": {"instance.ml.type.xlarge": [{"ShouldChangeMe": "string"}]}, } - result = parser_utils.walk_and_apply_json(test_json, parser_utils.pascal_to_snake) + result = parser_utils.walk_and_apply_json(test_json, parser_utils.camel_to_snake) assert result == { "camel_case_key": "value", "camel_case_object_key": { From a8a245370b4bdd66a1bbfb2456c16b61c031fcd8 Mon Sep 17 00:00:00 2001 From: chrstfu Date: Wed, 30 Oct 2024 23:07:07 +0000 Subject: [PATCH 4/8] fix: Adding more code coverage --- src/sagemaker/jumpstart/hub/utils.py | 9 ++++----- tests/unit/sagemaker/jumpstart/constants.py | 1 + .../sagemaker/jumpstart/hub/test_parser_utils.py | 15 +++++++++++++++ 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index 2acbad5135..46d15dfa84 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -219,18 +219,17 @@ def get_hub_model_version( except Exception as ex: raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") - marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version( - hub_content_summaries, hub_model_version - ) - try: return _get_hub_model_version_for_open_weight_version( hub_content_summaries, hub_model_version ) except KeyError as e: + marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version( + hub_content_summaries, hub_model_version + ) if marketplace_hub_content_version: return marketplace_hub_content_version - raise e + raise def _get_hub_model_version_for_open_weight_version( diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 9117b2d26d..d22428f4f0 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -9178,6 +9178,7 @@ "TrainingArtifactS3DataType": "S3Prefix", "TrainingArtifactCompressionType": "None", "TrainingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "ModelTypes": ["OPEN_WEIGHTS", "PROPRIETARY"], "Hyperparameters": [ { "Name": "peft_type", diff --git a/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py index 4d840584e1..6ee508b57d 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py @@ -13,6 +13,10 @@ from __future__ import absolute_import import pytest from sagemaker.jumpstart.hub.parser_utils import camel_to_snake +from sagemaker.jumpstart.hub.parsers import make_model_specs_from_describe_hub_content_response +from sagemaker.jumpstart.hub.interfaces import HubModelDocument +from tests.unit.sagemaker.jumpstart.constants import HUB_MODEL_DOCUMENT_DICTS +from unittest.mock import MagicMock REGION = "us-east-1" ACCOUNT_ID = "123456789123" @@ -32,3 +36,14 @@ ) def test_parse_(input_string, expected): assert expected == camel_to_snake(input_string) + + +def test_make_model_specs_from_describe_hub_content_response(): + mock_describe_response = MagicMock() + region = "us-west-2" + mock_describe_response.get_hub_region.return_value = region + mock_describe_response.hub_content_version = "1.0.0" + json_obj = HUB_MODEL_DOCUMENT_DICTS["huggingface-llm-gemma-2b-instruct"] + mock_describe_response.hub_content_document = HubModelDocument(json_obj=json_obj, region=region) + + make_model_specs_from_describe_hub_content_response(mock_describe_response) From d4430e27205c90c12410c7f393feafc85187c2ab Mon Sep 17 00:00:00 2001 From: chrstfu Date: Wed, 30 Oct 2024 23:14:09 +0000 Subject: [PATCH 5/8] fix: linting --- src/sagemaker/jumpstart/hub/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index 46d15dfa84..77540926c6 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -223,7 +223,7 @@ def get_hub_model_version( return _get_hub_model_version_for_open_weight_version( hub_content_summaries, hub_model_version ) - except KeyError as e: + except KeyError: marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version( hub_content_summaries, hub_model_version ) From cd82335b3b4993389afe01daa9967fda0a3aff75 Mon Sep 17 00:00:00 2001 From: chrstfu Date: Wed, 30 Oct 2024 23:45:48 +0000 Subject: [PATCH 6/8] fix: Fixing coverage tests --- tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py index 6ee508b57d..4412ad467e 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py @@ -17,6 +17,8 @@ from sagemaker.jumpstart.hub.interfaces import HubModelDocument from tests.unit.sagemaker.jumpstart.constants import HUB_MODEL_DOCUMENT_DICTS from unittest.mock import MagicMock +from sagemaker.jumpstart.types import HubContentType + REGION = "us-east-1" ACCOUNT_ID = "123456789123" @@ -41,6 +43,7 @@ def test_parse_(input_string, expected): def test_make_model_specs_from_describe_hub_content_response(): mock_describe_response = MagicMock() region = "us-west-2" + mock_describe_response.hub_content_type = HubContentType.MODEL mock_describe_response.get_hub_region.return_value = region mock_describe_response.hub_content_version = "1.0.0" json_obj = HUB_MODEL_DOCUMENT_DICTS["huggingface-llm-gemma-2b-instruct"] From 3ca1deb13f1a26c7e1dc6a1ce1f0ae214bf12950 Mon Sep 17 00:00:00 2001 From: chrstfu Date: Thu, 31 Oct 2024 01:13:22 +0000 Subject: [PATCH 7/8] fix: Fixing integration tests --- src/sagemaker/jumpstart/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index da6d1445da..ece29f1927 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -908,7 +908,9 @@ def _validate_hub_service_model_id_and_get_type( ) hub_content_model_types = [] - for model_type in getattr(hub_model_specs, "model_types", []): + model_types_field = getattr(hub_model_specs, "model_types", []) + model_types = model_types_field if model_types_field is not None else [] + for model_type in model_types: try: hub_content_model_types.append(enums.JumpStartModelType[model_type]) except ValueError: From 79a116317e8ad657388662b391240e47e674f219 Mon Sep 17 00:00:00 2001 From: chrstfu Date: Thu, 31 Oct 2024 02:01:50 +0000 Subject: [PATCH 8/8] fix: Minor fixes --- src/sagemaker/jumpstart/utils.py | 8 +++++--- .../hub/test_marketplace_hub_content.py | 19 +++++++++++++------ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index ece29f1927..b33d6563e5 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -863,7 +863,9 @@ def validate_model_id_and_get_type( model_version=model_version, sagemaker_session=sagemaker_session, ) - return model_types[0] # Currently this function only supports one model type + return ( + model_types[0] if model_types else None + ) # Currently this function only supports one model type s3_client = sagemaker_session.s3_client if sagemaker_session else None region = region or constants.JUMPSTART_DEFAULT_REGION_NAME @@ -908,8 +910,8 @@ def _validate_hub_service_model_id_and_get_type( ) hub_content_model_types = [] - model_types_field = getattr(hub_model_specs, "model_types", []) - model_types = model_types_field if model_types_field is not None else [] + model_types_field: Optional[List[str]] = getattr(hub_model_specs, "model_types", []) + model_types = model_types_field if model_types_field else [] for model_type in model_types: try: hub_content_model_types.append(enums.JumpStartModelType[model_type]) diff --git a/tests/unit/sagemaker/jumpstart/hub/test_marketplace_hub_content.py b/tests/unit/sagemaker/jumpstart/hub/test_marketplace_hub_content.py index 50bd08a62b..49d97d177d 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_marketplace_hub_content.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_marketplace_hub_content.py @@ -98,12 +98,14 @@ def _test_proprietary_model(input_version, expected_version, sagemaker_session): @pytest.mark.parametrize( - "get_model_specs_response, expected, expected_exception, expected_message", + "get_model_specs_attr, get_model_specs_response, expected, expected_exception, expected_message", [ - (None, [], None, None), - ([], [], None, None), - (["OPEN_WEIGHTS"], [JumpStartModelType.OPEN_WEIGHTS], None, None), + (False, None, [], None, None), + (True, None, [], None, None), + (True, [], [], None, None), + (True, ["OPEN_WEIGHTS"], [JumpStartModelType.OPEN_WEIGHTS], None, None), ( + True, ["OPEN_WEIGHTS", "PROPRIETARY"], [JumpStartModelType.OPEN_WEIGHTS, JumpStartModelType.PROPRIETARY], None, @@ -113,10 +115,15 @@ def _test_proprietary_model(input_version, expected_version, sagemaker_session): ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_validate_hub_service_model_id_and_get_type( - mock_get_model_specs, get_model_specs_response, expected, expected_exception, expected_message + mock_get_model_specs, + get_model_specs_attr, + get_model_specs_response, + expected, + expected_exception, + expected_message, ): mock_object = MagicMock() - if get_model_specs_response: + if get_model_specs_attr: mock_object.model_types = get_model_specs_response mock_get_model_specs.return_value = mock_object