Skip to content

Commit 7dc1803

Browse files
author
Malav Shastri
committed
fix training job with alt configs and telemetry changes
1 parent d5d8324 commit 7dc1803

File tree

4 files changed

+46
-23
lines changed

4 files changed

+46
-23
lines changed

src/sagemaker/jumpstart/artifacts/metric_definitions.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,17 @@ def _retrieve_default_training_metric_definitions(
9696
else []
9797
)
9898

99-
instance_specific_metric_name: str
100-
for instance_specific_metric_definition in instance_specific_metric_definitions:
101-
instance_specific_metric_name = instance_specific_metric_definition["Name"]
102-
default_metric_definitions = list(
103-
filter(
104-
lambda metric_definition: metric_definition["Name"]
105-
!= instance_specific_metric_name,
106-
default_metric_definitions,
99+
if instance_specific_metric_definitions:
100+
instance_specific_metric_name: str
101+
for instance_specific_metric_definition in instance_specific_metric_definitions:
102+
instance_specific_metric_name = instance_specific_metric_definition["Name"]
103+
default_metric_definitions = list(
104+
filter(
105+
lambda metric_definition: metric_definition["Name"]
106+
!= instance_specific_metric_name,
107+
default_metric_definitions,
108+
)
107109
)
108-
)
109-
default_metric_definitions.append(instance_specific_metric_definition)
110+
default_metric_definitions.append(instance_specific_metric_definition)
110111

111112
return default_metric_definitions

src/sagemaker/jumpstart/hub/parser_utils.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@
1414
from __future__ import absolute_import
1515

1616
import re
17-
from typing import Any, Dict
17+
from typing import Any, Dict, List, Optional
1818

1919

2020
def camel_to_snake(camel_case_string: str) -> str:
2121
"""Converts camelCaseString or UpperCamelCaseString to snake_case_string."""
2222
snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string)
23+
if "-" in snake_case_string:
24+
#remove any hyphen from the string for accurate conversion.
25+
snake_case_string = snake_case_string.replace("-", "")
2326
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower()
2427

2528

@@ -29,20 +32,28 @@ def snake_to_upper_camel(snake_case_string: str) -> str:
2932
return upper_camel_case_string
3033

3134

32-
def walk_and_apply_json(json_obj: Dict[Any, Any], apply) -> Dict[Any, Any]:
33-
"""Recursively walks a json object and applies a given function to the keys."""
35+
def walk_and_apply_json(
36+
json_obj: Dict[Any, Any], apply, stop_keys: Optional[List[str]] = ["metrics"]
37+
) -> Dict[Any, Any]:
38+
"""Recursively walks a json object and applies a given function to the keys.
39+
stop_keys (Optional[list[str]]): List of field keys that should stop the application function.
40+
Any children of these keys will not have the application function applied to them.
41+
"""
3442

3543
def _walk_and_apply_json(json_obj, new):
3644
if isinstance(json_obj, dict) and isinstance(new, dict):
3745
for key, value in json_obj.items():
38-
new_key = apply(key)
39-
if isinstance(value, dict):
40-
new[new_key] = {}
41-
_walk_and_apply_json(value, new=new[new_key])
42-
elif isinstance(value, list):
43-
new[new_key] = []
44-
for item in value:
45-
_walk_and_apply_json(item, new=new[new_key])
46+
new_key = apply(key)
47+
if (stop_keys and new_key not in stop_keys) or stop_keys is None:
48+
if isinstance(value, dict):
49+
new[new_key] = {}
50+
_walk_and_apply_json(value, new=new[new_key])
51+
elif isinstance(value, list):
52+
new[new_key] = []
53+
for item in value:
54+
_walk_and_apply_json(item, new=new[new_key])
55+
else:
56+
new[new_key] = value
4657
else:
4758
new[new_key] = value
4859
elif isinstance(json_obj, dict) and isinstance(new, list):
@@ -51,6 +62,8 @@ def _walk_and_apply_json(json_obj, new):
5162
new.update(json_obj)
5263
elif isinstance(json_obj, list) and isinstance(new, list):
5364
new.append(json_obj)
65+
elif isinstance(json_obj, str) and isinstance(new, list):
66+
new.append(json_obj)
5467
return new
5568

5669
return _walk_and_apply_json(json_obj, new={})

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,7 +1174,7 @@ def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False):
11741174
spec (Dict[str, Any]): Dictionary representation of training config ranking.
11751175
"""
11761176
if is_hub_content:
1177-
spec = {camel_to_snake(key): val for key, val in spec.items()}
1177+
spec = walk_and_apply_json(spec, camel_to_snake)
11781178
self.from_json(spec)
11791179

11801180
def from_json(self, json_obj: Dict[str, Any]) -> None:
@@ -1277,6 +1277,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12771277
Args:
12781278
json_obj (Dict[str, Any]): Dictionary representation of spec.
12791279
"""
1280+
if self._is_hub_content:
1281+
json_obj = walk_and_apply_json(json_obj, camel_to_snake, ["metrics"])
12801282
self.model_id: str = json_obj.get("model_id")
12811283
self.url: str = json_obj.get("url")
12821284
self.version: str = json_obj.get("version")
@@ -1398,7 +1400,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
13981400

13991401
if self.training_supported:
14001402
if self._is_hub_content:
1401-
self.training_ecr_uri: Optional[str] = json_obj["training_ecr_uri"]
1403+
self.training_ecr_uri: Optional[str] = json_obj.get("training_ecr_uri")
14021404
self._non_serializable_slots.append("training_ecr_specs")
14031405
else:
14041406
self.training_ecr_specs: Optional[JumpStartECRSpecs] = (

src/sagemaker/jumpstart/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from sagemaker.jumpstart import constants, enums
3535
from sagemaker.jumpstart import accessors
36+
from sagemaker.jumpstart.hub.parser_utils import camel_to_snake, snake_to_upper_camel
3637
from sagemaker.s3 import parse_s3_url
3738
from sagemaker.jumpstart.exceptions import (
3839
DeprecatedJumpStartModelError,
@@ -1103,6 +1104,12 @@ def get_jumpstart_configs(
11031104
metadata_configs.config_rankings.get("overall").rankings if metadata_configs else []
11041105
)
11051106

1107+
if hub_arn:
1108+
return (
1109+
{config_name: metadata_configs.configs[camel_to_snake(snake_to_upper_camel(config_name))] for config_name in config_names}
1110+
if metadata_configs
1111+
else {}
1112+
)
11061113
return (
11071114
{config_name: metadata_configs.configs[config_name] for config_name in config_names}
11081115
if metadata_configs

0 commit comments

Comments
 (0)