diff --git a/src/sagemaker/jumpstart/hub/parser_utils.py b/src/sagemaker/jumpstart/hub/parser_utils.py index 140c089b11..ccabde63cd 100644 --- a/src/sagemaker/jumpstart/hub/parser_utils.py +++ b/src/sagemaker/jumpstart/hub/parser_utils.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import re -from typing import Any, Dict +from typing import Any, Dict, Optional, List def camel_to_snake(camel_case_string: str) -> str: @@ -29,20 +29,29 @@ def snake_to_upper_camel(snake_case_string: str) -> str: return upper_camel_case_string -def walk_and_apply_json(json_obj: Dict[Any, Any], apply) -> Dict[Any, Any]: - """Recursively walks a json object and applies a given function to the keys.""" +def walk_and_apply_json( + json_obj: Dict[Any, Any], apply, stop_keys: Optional[List[str]] = None +) -> Dict[Any, Any]: + """Recursively walks a json object and applies a given function to the keys. + + stop_keys (Optional[list[str]]): List of field keys that should stop the application function. + Any children of these keys will not have the application function applied to them. + """ def _walk_and_apply_json(json_obj, new): if isinstance(json_obj, dict) and isinstance(new, dict): for key, value in json_obj.items(): new_key = apply(key) - if isinstance(value, dict): - new[new_key] = {} - _walk_and_apply_json(value, new=new[new_key]) - elif isinstance(value, list): - new[new_key] = [] - for item in value: - _walk_and_apply_json(item, new=new[new_key]) + if stop_keys and new_key not in stop_keys: + if isinstance(value, dict): + new[new_key] = {} + _walk_and_apply_json(value, new=new[new_key]) + elif isinstance(value, list): + new[new_key] = [] + for item in value: + _walk_and_apply_json(item, new=new[new_key]) + else: + new[new_key] = value else: new[new_key] = value elif isinstance(json_obj, dict) and isinstance(new, list): diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 88f1dd59e3..0cd970a40f 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1278,7 +1278,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of spec. """ if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, camel_to_snake, ["metrics"]) self.model_id: str = json_obj.get("model_id") self.url: str = json_obj.get("url") self.version: str = json_obj.get("version") diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 9117b2d26d..61543225f3 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -9813,7 +9813,7 @@ "ComponentNames": ["variant1"], "BenchmarkMetrics": { "ml.g5.12xlarge": [ - {"Name": "latency", "Unit": "sec", "Value": "0.19", "Concurrency": "1"}, + {"name": "latency", "unit": "sec", "value": "0.19", "concurrency": "1"}, ] }, }, diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py index ee50805792..f6ce1f238e 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -15,7 +15,7 @@ from unittest.mock import patch, Mock from sagemaker.jumpstart.types import HubArnExtractedInfo from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME -from sagemaker.jumpstart.hub import utils +from sagemaker.jumpstart.hub import utils, parser_utils def test_get_info_from_hub_resource_arn(): @@ -254,3 +254,52 @@ def test_get_hub_model_version_wildcard_char(mock_session): ) assert result == "2.0.0" + + +def test_walk_and_apply_json(): + test_json = { + "CamelCaseKey": "value", + "CamelCaseObjectKey": { + "CamelCaseObjectChildOne": "value1", + "CamelCaseObjectChildTwo": "value2", + }, + "IgnoreMyChildren": {"ShouldNotBeTouchedOne": "const1", "ShouldNotBeTouchedTwo": "const2"}, + "ShouldNotIgnoreMyChildren": {"NopeNope": "no"}, + } + + result = parser_utils.walk_and_apply_json( + test_json, parser_utils.camel_to_snake, ["ignore_my_children"] + ) + assert result == { + "camel_case_key": "value", + "camel_case_object_key": { + "camel_case_object_child_one": "value1", + "camel_case_object_child_two": "value2", + }, + "ignore_my_children": { + "ShouldNotBeTouchedOne": "const1", + "ShouldNotBeTouchedTwo": "const2", + }, + "should_not_ignore_my_children": {"nope_nope": "no"}, + } + + +def test_walk_and_apply_json_no_stop(): + test_json = { + "CamelCaseKey": "value", + "CamelCaseObjectKey": { + "CamelCaseObjectChildOne": "value1", + "CamelCaseObjectChildTwo": "value2", + }, + "CamelCaseObjectListKey": {"instance.ml.type.xlarge": [{"ShouldChangeMe": "string"}]}, + } + + result = parser_utils.walk_and_apply_json(test_json, parser_utils.camel_to_snake) + assert result == { + "camel_case_key": "value", + "camel_case_object_key": { + "camel_case_object_child_one": "value1", + "camel_case_object_child_two": "value2", + }, + "camel_case_object_list_key": {"instance.ml.type.xlarge": [{"should_change_me": "string"}]}, + }