From 8a5c58c2668eb04f5fb4a56989505504743502a0 Mon Sep 17 00:00:00 2001 From: chrstfu Date: Tue, 8 Oct 2024 20:26:07 +0000 Subject: [PATCH 1/3] feat: Bedrock Store and Marketplace model support in HubService Feat: Bedrock Store and Marketplace model support in HubService fix: Adding test for tags fix: Adding more fields fix: Fixing tag generation feat: Adding BRS and marketplace model support fix: Adding better flag fix: Making fields non-required fix: Adding more parsing logic fix: Renaming fix: Adding test fix: reverting comment fix: Adding tests fix: Adding initial prop support logic fix: Adding initial prop management fix: Adding new stuff fix: linting fix: trigger github acitons fix: revertin --- src/sagemaker/jumpstart/enums.py | 8 +++ src/sagemaker/jumpstart/factory/model.py | 8 ++- src/sagemaker/jumpstart/hub/interfaces.py | 29 +++++--- src/sagemaker/jumpstart/hub/parser_utils.py | 10 +-- src/sagemaker/jumpstart/hub/parsers.py | 34 ++++++---- src/sagemaker/jumpstart/hub/utils.py | 61 +++++++++++++++-- src/sagemaker/jumpstart/types.py | 30 +++++---- src/sagemaker/jumpstart/utils.py | 67 ++++++++++++++++++- .../jumpstart/model/test_jumpstart_model.py | 51 ++++++++++++++ tests/integ/sagemaker/jumpstart/utils.py | 9 +-- .../sagemaker/jumpstart/hub/test_utils.py | 4 +- 11 files changed, 251 insertions(+), 60 deletions(-) diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index a83964e394..91f547afb6 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -82,6 +82,12 @@ class VariableTypes(str, Enum): BOOL = "bool" +class HubContentCapability(str, Enum): + """Enum class for HubContent capabilities.""" + + BEDROCK_CONSOLE = "BEDROCK_CONSOLE" + + class JumpStartTag(str, Enum): """Enum class for tag keys to apply to JumpStart models.""" @@ -99,6 +105,8 @@ class JumpStartTag(str, Enum): HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn" + BEDROCK = "sagemaker-sdk:bedrock" + class SerializerType(str, Enum): """Enum class for serializers associated with JumpStart models.""" diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index ccafed844d..d8f0b252c2 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -41,7 +41,7 @@ from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines -from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType, HubContentCapability from sagemaker.jumpstart.types import ( HubContentType, JumpStartModelDeployKwargs, @@ -51,7 +51,9 @@ ) from sagemaker.jumpstart.utils import ( add_hub_content_arn_tags, + add_bedrock_store_tags, add_jumpstart_model_info_tags, + add_bedrock_store_tags, get_default_jumpstart_session_with_user_agent_suffix, get_neo_content_bucket, get_top_ranked_config_name, @@ -488,6 +490,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: ) kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn) + if hasattr(kwargs.specs, "capabilities") and kwargs.specs.capabilities is not None: + if HubContentCapability.BEDROCK_CONSOLE in kwargs.specs.capabilities: + kwargs.tags = add_bedrock_store_tags(kwargs.tags, compatibility="compatible") + return kwargs diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py index d987216872..69a468a0b4 100644 --- a/src/sagemaker/jumpstart/hub/interfaces.py +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -451,12 +451,14 @@ def from_json(self, json_obj: str) -> None: class HubModelDocument(HubDataHolderType): """Data class for model type HubContentDocument from session.describe_hub_content().""" - SCHEMA_VERSION = "2.2.0" + SCHEMA_VERSION = "2.3.0" __slots__ = [ "url", "min_sdk_version", "training_supported", + "model_types", + "capabilities", "incremental_training_supported", "dynamic_container_deployment_supported", "hosting_ecr_uri", @@ -469,6 +471,7 @@ class HubModelDocument(HubDataHolderType): "hosting_use_script_uri", "hosting_eula_uri", "hosting_model_package_arn", + "model_subscription_link", "inference_configs", "inference_config_components", "inference_config_rankings", @@ -549,18 +552,22 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Args: json_obj (Dict[str, Any]): Dictionary representation of hub model document. """ - self.url: str = json_obj["Url"] - self.min_sdk_version: str = json_obj["MinSdkVersion"] - self.hosting_ecr_uri: Optional[str] = json_obj["HostingEcrUri"] - self.hosting_artifact_uri = json_obj["HostingArtifactUri"] - self.hosting_script_uri = json_obj["HostingScriptUri"] - self.inference_dependencies: List[str] = json_obj["InferenceDependencies"] + self.url: str = json_obj.get("Url") + self.min_sdk_version: str = json_obj.get("MinSdkVersion") + self.hosting_ecr_uri: Optional[str] = json_obj.get("HostingEcrUri") + self.hosting_artifact_uri = json_obj.get("HostingArtifactUri") + self.hosting_script_uri = json_obj.get("HostingScriptUri") + self.inference_dependencies: List[str] = json_obj.get("InferenceDependencies") self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [ JumpStartEnvironmentVariable(env_variable, is_hub_content=True) - for env_variable in json_obj["InferenceEnvironmentVariables"] + for env_variable in json_obj.get("InferenceEnvironmentVariables", []) ] - self.training_supported: bool = bool(json_obj["TrainingSupported"]) - self.incremental_training_supported: bool = bool(json_obj["IncrementalTrainingSupported"]) + self.model_types: Optional[List[str]] = json_obj.get("ModelTypes") + self.capabilities: Optional[List[str]] = json_obj.get("Capabilities") + self.training_supported: bool = bool(json_obj.get("TrainingSupported")) + self.incremental_training_supported: bool = bool( + json_obj.get("IncrementalTrainingSupported") + ) self.dynamic_container_deployment_supported: Optional[bool] = ( bool(json_obj.get("DynamicContainerDeploymentSupported")) if json_obj.get("DynamicContainerDeploymentSupported") @@ -586,6 +593,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri") self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn") + self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink") + self.inference_config_rankings = self._get_config_rankings(json_obj) self.inference_config_components = self._get_config_components(json_obj) self.inference_configs = self._get_configs(json_obj) diff --git a/src/sagemaker/jumpstart/hub/parser_utils.py b/src/sagemaker/jumpstart/hub/parser_utils.py index 54147dd8e6..cae4e08f17 100644 --- a/src/sagemaker/jumpstart/hub/parser_utils.py +++ b/src/sagemaker/jumpstart/hub/parser_utils.py @@ -18,13 +18,9 @@ from typing import Any, Dict, List, Optional -def camel_to_snake(camel_case_string: str) -> str: - """Converts camelCaseString or UpperCamelCaseString to snake_case_string.""" - snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string) - if "-" in snake_case_string: - # remove any hyphen from the string for accurate conversion. - snake_case_string = snake_case_string.replace("-", "") - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower() +def pascal_to_snake(camel_case_string: str) -> str: + """Converts PascalCase to snake_case_string.""" + return re.sub(r"(? str: diff --git a/src/sagemaker/jumpstart/hub/parsers.py b/src/sagemaker/jumpstart/hub/parsers.py index 28c2d9b32d..4e754b24e4 100644 --- a/src/sagemaker/jumpstart/hub/parsers.py +++ b/src/sagemaker/jumpstart/hub/parsers.py @@ -27,7 +27,7 @@ HubModelDocument, ) from sagemaker.jumpstart.hub.parser_utils import ( - camel_to_snake, + pascal_to_snake, snake_to_upper_camel, walk_and_apply_json, ) @@ -86,7 +86,7 @@ def get_model_spec_arg_keys( arg_keys = [] if naming_convention == NamingConventionType.SNAKE_CASE: - arg_keys = [camel_to_snake(key) for key in arg_keys] + arg_keys = [pascal_to_snake(key) for key in arg_keys] elif naming_convention == NamingConventionType.UPPER_CAMEL_CASE: return arg_keys else: @@ -137,6 +137,8 @@ def make_model_specs_from_describe_hub_content_response( hub_model_document: HubModelDocument = response.hub_content_document specs["url"] = hub_model_document.url specs["min_sdk_version"] = hub_model_document.min_sdk_version + specs["model_types"] = hub_model_document.model_types + specs["capabilities"] = hub_model_document.capabilities specs["training_supported"] = bool(hub_model_document.training_supported) specs["incremental_training_supported"] = bool( hub_model_document.incremental_training_supported @@ -146,15 +148,19 @@ def make_model_specs_from_describe_hub_content_response( specs["inference_config_components"] = hub_model_document.inference_config_components specs["inference_config_rankings"] = hub_model_document.inference_config_rankings - hosting_artifact_bucket, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable - hub_model_document.hosting_artifact_uri - ) - specs["hosting_artifact_key"] = hosting_artifact_key - specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri - hosting_script_bucket, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable - hub_model_document.hosting_script_uri - ) - specs["hosting_script_key"] = hosting_script_key + if hub_model_document.hosting_artifact_uri: + _, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_artifact_uri + ) + specs["hosting_artifact_key"] = hosting_artifact_key + specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri + + if hub_model_document.hosting_script_uri: + _, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_script_uri + ) + specs["hosting_script_key"] = hosting_script_key + specs["inference_environment_variables"] = hub_model_document.inference_environment_variables specs["inference_vulnerable"] = False specs["inference_dependencies"] = hub_model_document.inference_dependencies @@ -201,7 +207,7 @@ def make_model_specs_from_describe_hub_content_response( default_payloads: Dict[str, Any] = {} if hub_model_document.default_payloads is not None: for alias, payload in hub_model_document.default_payloads.items(): - default_payloads[alias] = walk_and_apply_json(payload.to_json(), camel_to_snake) + default_payloads[alias] = walk_and_apply_json(payload.to_json(), pascal_to_snake) specs["default_payloads"] = default_payloads specs["gated_bucket"] = hub_model_document.gated_bucket specs["inference_volume_size"] = hub_model_document.inference_volume_size @@ -219,6 +225,10 @@ def make_model_specs_from_describe_hub_content_response( if hub_model_document.hosting_model_package_arn: specs["hosting_model_package_arns"] = {region: hub_model_document.hosting_model_package_arn} + + specs["model_subscription_link"] = hub_model_document.model_subscription_link + + specs["model_subscription_link"] = hub_model_document.model_subscription_link specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index 535bc5e9be..b2298d4314 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -14,7 +14,7 @@ """This module contains utilities related to SageMaker JumpStart Hub.""" from __future__ import absolute_import import re -from typing import Optional +from typing import Optional, List, Any from sagemaker.jumpstart.hub.types import S3ObjectLocation from sagemaker.s3_utils import parse_s3_url from sagemaker.session import Session @@ -23,6 +23,8 @@ from sagemaker.jumpstart import constants from packaging.specifiers import SpecifierSet, InvalidSpecifier +PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:" + def get_info_from_hub_resource_arn( arn: str, @@ -117,8 +119,8 @@ def generate_hub_arn_for_init_kwargs( hub_arn = None if hub_name: - if hub_name == constants.JUMPSTART_MODEL_HUB_NAME: - return None + # if hub_name == constants.JUMPSTART_MODEL_HUB_NAME: + # return None match = re.match(constants.HUB_ARN_REGEX, hub_name) if match: hub_arn = hub_name @@ -207,6 +209,24 @@ def get_hub_model_version( except Exception as ex: raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") + open_weight_hub_content_version = _get_open_weight_hub_model_version( + hub_content_summaries, hub_model_version + ) + if open_weight_hub_content_version: + return open_weight_hub_content_version + + proprietary_hub_content_version = _get_proprietary_hub_model_version( + hub_content_summaries, hub_model_version + ) + if proprietary_hub_content_version: + return proprietary_hub_content_version + + raise KeyError(f"Could not find HubContent with specified version: {hub_model_version}") + + +def _get_open_weight_hub_model_version( + hub_content_summaries: List[Any], hub_model_version: Optional[str] = None +) -> Optional[str]: available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries] if hub_model_version == "*" or hub_model_version is None: @@ -215,10 +235,41 @@ def get_hub_model_version( try: spec = SpecifierSet(f"=={hub_model_version}") except InvalidSpecifier: - raise KeyError(f"Bad semantic version: {hub_model_version}") + return None available_versions_filtered = list(spec.filter(available_model_versions)) if not available_versions_filtered: - raise KeyError("Model version not available in the Hub") + return None hub_model_version = str(max(available_versions_filtered)) return hub_model_version + + +def _get_proprietary_hub_model_version( + hub_content_summaries: List[Any], proprietary_hub_model_version: str +) -> Optional[str]: + + for model in hub_content_summaries: + model_search_keywords = model.get("HubContentSearchKeywords", []) + if _hub_search_keywords_contains_proprietary_version( + model_search_keywords, proprietary_hub_model_version + ): + return model.get("HubContentVersion") + + return None + + +def _hub_search_keywords_contains_proprietary_version( + model_search_keywords: List[str], proprietary_hub_model_version: str +) -> bool: + proprietary_version_keyword = next( + filter(lambda s: s.startswith(PROPRIETARY_VERSION_KEYWORD), model_search_keywords), None + ) + + if not proprietary_version_keyword: + return False + + proprietary_version = proprietary_version_keyword.lstrip(PROPRIETARY_VERSION_KEYWORD) + if proprietary_version == proprietary_hub_model_version: + return True + + return False diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index f3313b3862..b716cac057 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -39,7 +39,7 @@ from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.enums import EndpointType from sagemaker.jumpstart.hub.parser_utils import ( - camel_to_snake, + pascal_to_snake, walk_and_apply_json, ) @@ -239,7 +239,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: return if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.framework = json_obj.get("framework") self.framework_version = json_obj.get("framework_version") @@ -293,7 +293,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: """ if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -361,7 +361,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Args: json_obj (Dict[str, Any]): Dictionary representation of environment variable. """ - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -411,7 +411,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: return if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.default_content_type = json_obj["default_content_type"] self.supported_content_types = json_obj["supported_content_types"] self.default_accept_type = json_obj["default_accept_type"] @@ -465,7 +465,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: return if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.raw_payload = json_obj self.content_type = json_obj["content_type"] self.body = json_obj.get("body") @@ -538,7 +538,7 @@ def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]]) if response is None: return - response = walk_and_apply_json(response, camel_to_snake) + response = walk_and_apply_json(response, pascal_to_snake) self.aliases: Optional[dict] = response.get("aliases") self.regional_aliases = None self.variants: Optional[dict] = response.get("variants") @@ -1174,7 +1174,7 @@ def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False): spec (Dict[str, Any]): Dictionary representation of training config ranking. """ if is_hub_content: - spec = walk_and_apply_json(spec, camel_to_snake) + spec = walk_and_apply_json(spec, pascal_to_snake) self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -1200,6 +1200,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "url", "version", "min_sdk_version", + "model_types", + "capabilities", "incremental_training_supported", "hosting_ecr_specs", "hosting_ecr_uri", @@ -1278,7 +1280,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of spec. """ if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.model_id: str = json_obj.get("model_id") self.url: str = json_obj.get("url") self.version: str = json_obj.get("version") @@ -1287,6 +1289,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj.get("incremental_training_supported", False) ) if self._is_hub_content: + self.capabilities: Optional[List[str]] = json_obj.get("capabilities") + self.model_types: Optional[List[str]] = json_obj.get("model_types") self.hosting_ecr_uri: Optional[str] = json_obj.get("hosting_ecr_uri") self._non_serializable_slots.append("hosting_ecr_specs") else: @@ -1505,7 +1509,7 @@ def __init__( ValueError: If the component field is invalid. """ if is_hub_content: - component = walk_and_apply_json(component, camel_to_snake) + component = walk_and_apply_json(component, pascal_to_snake) self.component_name = component_name super().__init__(component, is_hub_content) self.from_json(component) @@ -1558,8 +1562,8 @@ def __init__( The list of components that are used to construct the resolved config. """ if is_hub_content: - config = walk_and_apply_json(config, camel_to_snake) - base_fields = walk_and_apply_json(base_fields, camel_to_snake) + config = walk_and_apply_json(config, pascal_to_snake) + base_fields = walk_and_apply_json(base_fields, pascal_to_snake) self.base_fields = base_fields self.config_components: Dict[str, JumpStartConfigComponent] = config_components self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = ( @@ -1725,7 +1729,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: """ super().from_json(json_obj) if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.inference_config_components: Optional[Dict[str, JumpStartConfigComponent]] = ( { component_name: JumpStartConfigComponent(component_name, component) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 4d784c8275..7f2af99d61 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -33,7 +33,7 @@ from sagemaker.jumpstart import constants, enums from sagemaker.jumpstart import accessors -from sagemaker.jumpstart.hub.parser_utils import camel_to_snake, snake_to_upper_camel +from sagemaker.jumpstart.hub.parser_utils import pascal_to_snake, snake_to_upper_camel from sagemaker.s3 import parse_s3_url from sagemaker.jumpstart.exceptions import ( DeprecatedJumpStartModelError, @@ -451,6 +451,35 @@ def add_hub_content_arn_tags( ) return tags +def add_bedrock_store_tags( + tags: Optional[List[TagsDict]], + compatibility: str, +) -> Optional[List[TagsDict]]: + """Adds custom Hub arn tag to JumpStart related resources.""" + + tags = add_single_jumpstart_tag( + compatibility, + enums.JumpStartTag.BEDROCK, + tags, + is_uri=False, + ) + return tags + + +def add_bedrock_store_tags( + tags: Optional[List[TagsDict]], + compatibility: str, +) -> Optional[List[TagsDict]]: + """Adds custom Hub arn tag to JumpStart related resources.""" + + tags = add_single_jumpstart_tag( + compatibility, + enums.JumpStartTag.BEDROCK, + tags, + is_uri=False, + ) + return tags + def add_jumpstart_uri_tags( tags: Optional[List[TagsDict]] = None, @@ -856,7 +885,10 @@ def validate_model_id_and_get_type( if not isinstance(model_id, str): return None if hub_arn: - return None + model_types = _validate_hub_service_model_id_and_get_type( + model_id, hub_arn, region, model_version, sagemaker_session + ) + return model_types[0] # Currently this function only supports one model type s3_client = sagemaker_session.s3_client if sagemaker_session else None region = region or constants.JUMPSTART_DEFAULT_REGION_NAME @@ -881,6 +913,35 @@ def validate_model_id_and_get_type( return None +def _validate_hub_service_model_id_and_get_type( + model_id: Optional[str], + hub_arn: str, + region: Optional[str] = None, + model_version: Optional[str] = None, + sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> List[enums.JumpStartModelType]: + """Returns a list of JumpStartModelType based off the HubContent. + + Only returns valid JumpStartModelType. Returns an empty array if none are found. + """ + hub_model_specs = accessors.JumpStartModelsAccessor.get_model_specs( + region=region, + model_id=model_id, + version=model_version, + hub_arn=hub_arn, + sagemaker_session=sagemaker_session, + ) + + hub_content_model_types = [] + for model_type in getattr(hub_model_specs, "model_types", []): + try: + hub_content_model_types.append(enums.JumpStartModelType[model_type]) + except ValueError: + continue + + return hub_content_model_types + + def _extract_value_from_list_of_tags( tag_keys: List[str], list_tags_result: List[str], @@ -1113,7 +1174,7 @@ def get_jumpstart_configs( return ( { config_name: metadata_configs.configs[ - camel_to_snake(snake_to_upper_camel(config_name)) + pascal_to_snake(snake_to_upper_camel(config_name)) ] for config_name in config_names } diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 7733041579..a3d29f7e36 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -416,3 +416,54 @@ def test_jumpstart_session_with_config_name(): "md/js_model_id#meta-textgeneration-llama-2-7b md/js_model_ver#* md/js_config#tgi" in mock_make_request.call_args[0][1]["headers"]["User-Agent"] ) + + +# Currently JumpStartModel does not pull from HubService for the Public Hub. +# def test_bedrock_store_model_tags_from_hub_service(setup): + +# model_id = "huggingface-llm-gemma-2b-instruct" + +# model = JumpStartModel( +# model_id=model_id, +# hub_name="SageMakerPublicHub", +# role=get_sm_session().get_caller_identity_arn(), +# sagemaker_session=get_sm_session(), +# ) + +# predictor = model.deploy( +# tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], +# accept_eula=True, +# ) + +# endpoint_arn = ( +# f"arn:aws:sagemaker:{get_sm_session().boto_region_name}:" +# f"{get_sm_session().account_id()}:endpoint/{predictor.endpoint_name}" +# ) +# tags = get_sm_session().list_tags(endpoint_arn) +# expected_tag = {"Key": "sagemaker-sdk:bedrock", "Value": "compatible"} +# assert expected_tag in tags + +# def test_proprietary_from_hub_service(setup): + +# model_id = "upstage-solar-mini-chat" + +# model = JumpStartModel( +# model_id=model_id, +# hub_name="SageMakerPublicHub", +# role=get_sm_session_with_override().get_caller_identity_arn(), +# sagemaker_session=get_sm_session_with_override(), +# model_version="240612.5" +# ) + +# predictor = model.deploy( +# tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], +# accept_eula=True, +# ) + +# payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1} + +# response = predictor.predict(payload) + +# predictor.delete_predictor() + +# assert response is not None diff --git a/tests/integ/sagemaker/jumpstart/utils.py b/tests/integ/sagemaker/jumpstart/utils.py index 0f2fd01572..8f1e443573 100644 --- a/tests/integ/sagemaker/jumpstart/utils.py +++ b/tests/integ/sagemaker/jumpstart/utils.py @@ -54,17 +54,12 @@ def get_sm_session() -> Session: # # boto3.set_stream_logger(name='botocore', level=logging.DEBUG) # boto_session = boto3.Session(region_name="us-west-2") # sagemaker = boto3.client( -# service_name="sagemaker-internal", -# endpoint_url="https://sagemaker.beta.us-west-2.ml-platform.aws.a2z.com", -# ) -# sagemaker_runtime = boto3.client( -# service_name="runtime.maeve", -# endpoint_url="https://maeveruntime.beta.us-west-2.ml-platform.aws.a2z.com", +# service_name="sagemaker", +# endpoint_url="https://sagemaker.gamma.us-west-2.ml-platform.aws.a2z.com", # ) # return Session( # boto_session=boto_session, # sagemaker_client=sagemaker, -# sagemaker_runtime_client=sagemaker_runtime, # ) diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py index 22bc527b18..0ef740398c 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -268,7 +268,7 @@ def test_walk_and_apply_json(): } result = parser_utils.walk_and_apply_json( - test_json, parser_utils.camel_to_snake, ["ignore_my_children"] + test_json, parser_utils.pascal_to_snake, ["ignore_my_children"] ) assert result == { "camel_case_key": "value", @@ -294,7 +294,7 @@ def test_walk_and_apply_json_no_stop(): "CamelCaseObjectListKey": {"instance.ml.type.xlarge": [{"ShouldChangeMe": "string"}]}, } - result = parser_utils.walk_and_apply_json(test_json, parser_utils.camel_to_snake) + result = parser_utils.walk_and_apply_json(test_json, parser_utils.pascal_to_snake) assert result == { "camel_case_key": "value", "camel_case_object_key": { From b4aced23839bbb68507ed9f4c09111c765ac8a35 Mon Sep 17 00:00:00 2001 From: chrstfu Date: Wed, 16 Oct 2024 20:15:01 +0000 Subject: [PATCH 2/3] fix: reverting some changes --- src/sagemaker/jumpstart/enums.py | 8 -------- src/sagemaker/jumpstart/factory/model.py | 8 +------- src/sagemaker/jumpstart/hub/utils.py | 4 ++-- 3 files changed, 3 insertions(+), 17 deletions(-) diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index 91f547afb6..a83964e394 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -82,12 +82,6 @@ class VariableTypes(str, Enum): BOOL = "bool" -class HubContentCapability(str, Enum): - """Enum class for HubContent capabilities.""" - - BEDROCK_CONSOLE = "BEDROCK_CONSOLE" - - class JumpStartTag(str, Enum): """Enum class for tag keys to apply to JumpStart models.""" @@ -105,8 +99,6 @@ class JumpStartTag(str, Enum): HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn" - BEDROCK = "sagemaker-sdk:bedrock" - class SerializerType(str, Enum): """Enum class for serializers associated with JumpStart models.""" diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index d8f0b252c2..ccafed844d 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -41,7 +41,7 @@ from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines -from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType, HubContentCapability +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.types import ( HubContentType, JumpStartModelDeployKwargs, @@ -51,9 +51,7 @@ ) from sagemaker.jumpstart.utils import ( add_hub_content_arn_tags, - add_bedrock_store_tags, add_jumpstart_model_info_tags, - add_bedrock_store_tags, get_default_jumpstart_session_with_user_agent_suffix, get_neo_content_bucket, get_top_ranked_config_name, @@ -490,10 +488,6 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: ) kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn) - if hasattr(kwargs.specs, "capabilities") and kwargs.specs.capabilities is not None: - if HubContentCapability.BEDROCK_CONSOLE in kwargs.specs.capabilities: - kwargs.tags = add_bedrock_store_tags(kwargs.tags, compatibility="compatible") - return kwargs diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index b2298d4314..be4a3da419 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -119,8 +119,8 @@ def generate_hub_arn_for_init_kwargs( hub_arn = None if hub_name: - # if hub_name == constants.JUMPSTART_MODEL_HUB_NAME: - # return None + if hub_name == constants.JUMPSTART_MODEL_HUB_NAME: + return None match = re.match(constants.HUB_ARN_REGEX, hub_name) if match: hub_arn = hub_name From 1b5440004117259a6e9886bf6cd80af275cb2eec Mon Sep 17 00:00:00 2001 From: chrstfu Date: Wed, 16 Oct 2024 20:32:09 +0000 Subject: [PATCH 3/3] fix: removing a test --- .../jumpstart/model/test_jumpstart_model.py | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index a3d29f7e36..95593be139 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -419,30 +419,6 @@ def test_jumpstart_session_with_config_name(): # Currently JumpStartModel does not pull from HubService for the Public Hub. -# def test_bedrock_store_model_tags_from_hub_service(setup): - -# model_id = "huggingface-llm-gemma-2b-instruct" - -# model = JumpStartModel( -# model_id=model_id, -# hub_name="SageMakerPublicHub", -# role=get_sm_session().get_caller_identity_arn(), -# sagemaker_session=get_sm_session(), -# ) - -# predictor = model.deploy( -# tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], -# accept_eula=True, -# ) - -# endpoint_arn = ( -# f"arn:aws:sagemaker:{get_sm_session().boto_region_name}:" -# f"{get_sm_session().account_id()}:endpoint/{predictor.endpoint_name}" -# ) -# tags = get_sm_session().list_tags(endpoint_arn) -# expected_tag = {"Key": "sagemaker-sdk:bedrock", "Value": "compatible"} -# assert expected_tag in tags - # def test_proprietary_from_hub_service(setup): # model_id = "upstage-solar-mini-chat"