Skip to content

Commit 80ab347

Browse files
committed
fix model reference train by excluding certian keys from camelization
1 parent 6b6e6e2 commit 80ab347

File tree

3 files changed

+47
-12
lines changed

3 files changed

+47
-12
lines changed

src/sagemaker/jumpstart/hub/parser_utils.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import re
17-
from typing import Any, Dict
17+
from typing import Any, Dict, Optional, List
1818

1919

2020
def camel_to_snake(camel_case_string: str) -> str:
@@ -29,20 +29,29 @@ def snake_to_upper_camel(snake_case_string: str) -> str:
2929
return upper_camel_case_string
3030

3131

32-
def walk_and_apply_json(json_obj: Dict[Any, Any], apply) -> Dict[Any, Any]:
33-
"""Recursively walks a json object and applies a given function to the keys."""
32+
def walk_and_apply_json(
33+
json_obj: Dict[Any, Any], apply, stop_keys: Optional[List[str]] = []
34+
) -> Dict[Any, Any]:
35+
"""Recursively walks a json object and applies a given function to the keys.
36+
37+
stop_keys (Optional[list[str]]): List of field keys that should stop the application function. Any
38+
children of these keys will not have the application function applied to them.
39+
"""
3440

3541
def _walk_and_apply_json(json_obj, new):
3642
if isinstance(json_obj, dict) and isinstance(new, dict):
3743
for key, value in json_obj.items():
3844
new_key = apply(key)
39-
if isinstance(value, dict):
40-
new[new_key] = {}
41-
_walk_and_apply_json(value, new=new[new_key])
42-
elif isinstance(value, list):
43-
new[new_key] = []
44-
for item in value:
45-
_walk_and_apply_json(item, new=new[new_key])
45+
if new_key not in stop_keys:
46+
if isinstance(value, dict):
47+
new[new_key] = {}
48+
_walk_and_apply_json(value, new=new[new_key])
49+
elif isinstance(value, list):
50+
new[new_key] = []
51+
for item in value:
52+
_walk_and_apply_json(item, new=new[new_key])
53+
else:
54+
new[new_key] = value
4655
else:
4756
new[new_key] = value
4857
elif isinstance(json_obj, dict) and isinstance(new, list):

src/sagemaker/jumpstart/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1278,7 +1278,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12781278
json_obj (Dict[str, Any]): Dictionary representation of spec.
12791279
"""
12801280
if self._is_hub_content:
1281-
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
1281+
json_obj = walk_and_apply_json(json_obj, camel_to_snake, ["metrics"])
12821282
self.model_id: str = json_obj.get("model_id")
12831283
self.url: str = json_obj.get("url")
12841284
self.version: str = json_obj.get("version")

tests/unit/sagemaker/jumpstart/hub/test_utils.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from unittest.mock import patch, Mock
1616
from sagemaker.jumpstart.types import HubArnExtractedInfo
1717
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
18-
from sagemaker.jumpstart.hub import utils
18+
from sagemaker.jumpstart.hub import utils, parser_utils
1919

2020

2121
def test_get_info_from_hub_resource_arn():
@@ -254,3 +254,29 @@ def test_get_hub_model_version_wildcard_char(mock_session):
254254
)
255255

256256
assert result == "2.0.0"
257+
258+
259+
def test_walk_and_apply_json():
260+
test_json = {
261+
"CamelCaseKey": "value",
262+
"CamelCaseObjectKey": {
263+
"CamelCaseObjectChildOne": "value1",
264+
"CamelCaseObjectChildTwo": "value2",
265+
},
266+
"IgnoreMyChildren": {"ShouldNotBeTouchedOne": "const1", "ShouldNotBeTouchedTwo": "const2"},
267+
}
268+
269+
result = parser_utils.walk_and_apply_json(
270+
test_json, parser_utils.camel_to_snake, ["ignore_my_children"]
271+
)
272+
assert result == {
273+
"camel_case_key": "value",
274+
"camel_case_object_key": {
275+
"camel_case_object_child_one": "value1",
276+
"camel_case_object_child_two": "value2",
277+
},
278+
"ignore_my_children": {
279+
"ShouldNotBeTouchedOne": "const1",
280+
"ShouldNotBeTouchedTwo": "const2",
281+
},
282+
}

0 commit comments

Comments
 (0)