diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index 5e5c0d79a0..16e81b2785 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -96,16 +96,17 @@ def _retrieve_default_training_metric_definitions( else [] ) - instance_specific_metric_name: str - for instance_specific_metric_definition in instance_specific_metric_definitions: - instance_specific_metric_name = instance_specific_metric_definition["Name"] - default_metric_definitions = list( - filter( - lambda metric_definition: metric_definition["Name"] - != instance_specific_metric_name, - default_metric_definitions, + if instance_specific_metric_definitions: + instance_specific_metric_name: str + for instance_specific_metric_definition in instance_specific_metric_definitions: + instance_specific_metric_name = instance_specific_metric_definition["Name"] + default_metric_definitions = list( + filter( + lambda metric_definition: metric_definition["Name"] + != instance_specific_metric_name, + default_metric_definitions, + ) ) - ) - default_metric_definitions.append(instance_specific_metric_definition) + default_metric_definitions.append(instance_specific_metric_definition) return default_metric_definitions diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 49d9c93dd4..a193732ca1 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -259,6 +259,7 @@ def _add_instance_type_to_kwargs( sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, config_name=kwargs.config_name, + hub_arn=kwargs.hub_arn, ) if specs.inference_configs and kwargs.config_name not in specs.inference_configs.configs: @@ -780,6 +781,7 @@ def _add_config_name_to_deploy_kwargs( sagemaker_session=temp_session, model_type=kwargs.model_type, config_name=kwargs.config_name, + hub_arn=kwargs.hub_arn, ) default_config_name = _select_inference_config_from_training_config( specs=specs, training_config_name=training_config_name diff --git a/src/sagemaker/jumpstart/hub/parser_utils.py b/src/sagemaker/jumpstart/hub/parser_utils.py index 140c089b11..54147dd8e6 100644 --- a/src/sagemaker/jumpstart/hub/parser_utils.py +++ b/src/sagemaker/jumpstart/hub/parser_utils.py @@ -10,16 +10,20 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +# pylint: skip-file """This module contains utilities related to SageMaker JumpStart Hub.""" from __future__ import absolute_import import re -from typing import Any, Dict +from typing import Any, Dict, List, Optional def camel_to_snake(camel_case_string: str) -> str: """Converts camelCaseString or UpperCamelCaseString to snake_case_string.""" snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string) + if "-" in snake_case_string: + # remove any hyphen from the string for accurate conversion. + snake_case_string = snake_case_string.replace("-", "") return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower() @@ -29,20 +33,29 @@ def snake_to_upper_camel(snake_case_string: str) -> str: return upper_camel_case_string -def walk_and_apply_json(json_obj: Dict[Any, Any], apply) -> Dict[Any, Any]: - """Recursively walks a json object and applies a given function to the keys.""" +def walk_and_apply_json( + json_obj: Dict[Any, Any], apply, stop_keys: Optional[List[str]] = ["metrics"] +) -> Dict[Any, Any]: + """Recursively walks a json object and applies a given function to the keys. + + stop_keys (Optional[list[str]]): List of field keys that should stop the application function. + Any children of these keys will not have the application function applied to them. + """ def _walk_and_apply_json(json_obj, new): if isinstance(json_obj, dict) and isinstance(new, dict): for key, value in json_obj.items(): new_key = apply(key) - if isinstance(value, dict): - new[new_key] = {} - _walk_and_apply_json(value, new=new[new_key]) - elif isinstance(value, list): - new[new_key] = [] - for item in value: - _walk_and_apply_json(item, new=new[new_key]) + if (stop_keys and new_key not in stop_keys) or stop_keys is None: + if isinstance(value, dict): + new[new_key] = {} + _walk_and_apply_json(value, new=new[new_key]) + elif isinstance(value, list): + new[new_key] = [] + for item in value: + _walk_and_apply_json(item, new=new[new_key]) + else: + new[new_key] = value else: new[new_key] = value elif isinstance(json_obj, dict) and isinstance(new, list): @@ -51,6 +64,8 @@ def _walk_and_apply_json(json_obj, new): new.update(json_obj) elif isinstance(json_obj, list) and isinstance(new, list): new.append(json_obj) + elif isinstance(json_obj, str) and isinstance(new, list): + new.append(json_obj) return new return _walk_and_apply_json(json_obj, new={}) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 88f1dd59e3..9d5acf6c6e 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1174,7 +1174,7 @@ def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False): spec (Dict[str, Any]): Dictionary representation of training config ranking. """ if is_hub_content: - spec = {camel_to_snake(key): val for key, val in spec.items()} + spec = walk_and_apply_json(spec, camel_to_snake) self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -1400,7 +1400,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if self.training_supported: if self._is_hub_content: - self.training_ecr_uri: Optional[str] = json_obj["training_ecr_uri"] + self.training_ecr_uri: Optional[str] = json_obj.get("training_ecr_uri") self._non_serializable_slots.append("training_ecr_specs") else: self.training_ecr_specs: Optional[JumpStartECRSpecs] = ( diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index bebf14d5c0..c881bce482 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -33,6 +33,7 @@ from sagemaker.jumpstart import constants, enums from sagemaker.jumpstart import accessors +from sagemaker.jumpstart.hub.parser_utils import camel_to_snake, snake_to_upper_camel from sagemaker.s3 import parse_s3_url from sagemaker.jumpstart.exceptions import ( DeprecatedJumpStartModelError, @@ -1103,6 +1104,17 @@ def get_jumpstart_configs( metadata_configs.config_rankings.get("overall").rankings if metadata_configs else [] ) + if hub_arn: + return ( + { + config_name: metadata_configs.configs[ + camel_to_snake(snake_to_upper_camel(config_name)) + ] + for config_name in config_names + } + if metadata_configs + else {} + ) return ( {config_name: metadata_configs.configs[config_name] for config_name in config_names} if metadata_configs diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py index ee50805792..22bc527b18 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -15,7 +15,7 @@ from unittest.mock import patch, Mock from sagemaker.jumpstart.types import HubArnExtractedInfo from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME -from sagemaker.jumpstart.hub import utils +from sagemaker.jumpstart.hub import parser_utils, utils def test_get_info_from_hub_resource_arn(): @@ -254,3 +254,52 @@ def test_get_hub_model_version_wildcard_char(mock_session): ) assert result == "2.0.0" + + +def test_walk_and_apply_json(): + test_json = { + "CamelCaseKey": "value", + "CamelCaseObjectKey": { + "CamelCaseObjectChildOne": "value1", + "CamelCaseObjectChildTwo": "value2", + }, + "IgnoreMyChildren": {"ShouldNotBeTouchedOne": "const1", "ShouldNotBeTouchedTwo": "const2"}, + "ShouldNotIgnoreMyChildren": {"NopeNope": "no"}, + } + + result = parser_utils.walk_and_apply_json( + test_json, parser_utils.camel_to_snake, ["ignore_my_children"] + ) + assert result == { + "camel_case_key": "value", + "camel_case_object_key": { + "camel_case_object_child_one": "value1", + "camel_case_object_child_two": "value2", + }, + "ignore_my_children": { + "ShouldNotBeTouchedOne": "const1", + "ShouldNotBeTouchedTwo": "const2", + }, + "should_not_ignore_my_children": {"nope_nope": "no"}, + } + + +def test_walk_and_apply_json_no_stop(): + test_json = { + "CamelCaseKey": "value", + "CamelCaseObjectKey": { + "CamelCaseObjectChildOne": "value1", + "CamelCaseObjectChildTwo": "value2", + }, + "CamelCaseObjectListKey": {"instance.ml.type.xlarge": [{"ShouldChangeMe": "string"}]}, + } + + result = parser_utils.walk_and_apply_json(test_json, parser_utils.camel_to_snake) + assert result == { + "camel_case_key": "value", + "camel_case_object_key": { + "camel_case_object_child_one": "value1", + "camel_case_object_child_two": "value2", + }, + "camel_case_object_list_key": {"instance.ml.type.xlarge": [{"should_change_me": "string"}]}, + }