-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: alt configs model deployment and training issues #4833
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
3045e4d
d5d8324
7dc1803
f52351b
c9add06
e88cb4d
bca701f
149267a
62b2596
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -259,6 +259,7 @@ def _add_instance_type_to_kwargs( | |
sagemaker_session=kwargs.sagemaker_session, | ||
model_type=kwargs.model_type, | ||
config_name=kwargs.config_name, | ||
hub_arn=kwargs.hub_arn, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We're going to be finding these lines forever aren't we.... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is an optional field but looks like integ tests could be a possible solution here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, we desperately need integ tests for (at least):
|
||
) | ||
|
||
if specs.inference_configs and kwargs.config_name not in specs.inference_configs.configs: | ||
|
@@ -780,6 +781,7 @@ def _add_config_name_to_deploy_kwargs( | |
sagemaker_session=temp_session, | ||
model_type=kwargs.model_type, | ||
config_name=kwargs.config_name, | ||
hub_arn=kwargs.hub_arn, | ||
) | ||
default_config_name = _select_inference_config_from_training_config( | ||
specs=specs, training_config_name=training_config_name | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,12 +14,15 @@ | |
from __future__ import absolute_import | ||
|
||
import re | ||
from typing import Any, Dict | ||
from typing import Any, Dict, List, Optional | ||
|
||
|
||
def camel_to_snake(camel_case_string: str) -> str: | ||
"""Converts camelCaseString or UpperCamelCaseString to snake_case_string.""" | ||
snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string) | ||
if "-" in snake_case_string: | ||
#remove any hyphen from the string for accurate conversion. | ||
snake_case_string = snake_case_string.replace("-", "") | ||
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower() | ||
|
||
|
||
|
@@ -29,20 +32,28 @@ def snake_to_upper_camel(snake_case_string: str) -> str: | |
return upper_camel_case_string | ||
|
||
|
||
def walk_and_apply_json(json_obj: Dict[Any, Any], apply) -> Dict[Any, Any]: | ||
"""Recursively walks a json object and applies a given function to the keys.""" | ||
def walk_and_apply_json( | ||
json_obj: Dict[Any, Any], apply, stop_keys: Optional[List[str]] = ["metrics"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. non-blocking: apply typing please |
||
) -> Dict[Any, Any]: | ||
"""Recursively walks a json object and applies a given function to the keys. | ||
stop_keys (Optional[list[str]]): List of field keys that should stop the application function. | ||
Any children of these keys will not have the application function applied to them. | ||
""" | ||
|
||
def _walk_and_apply_json(json_obj, new): | ||
if isinstance(json_obj, dict) and isinstance(new, dict): | ||
for key, value in json_obj.items(): | ||
new_key = apply(key) | ||
if isinstance(value, dict): | ||
new[new_key] = {} | ||
_walk_and_apply_json(value, new=new[new_key]) | ||
elif isinstance(value, list): | ||
new[new_key] = [] | ||
for item in value: | ||
_walk_and_apply_json(item, new=new[new_key]) | ||
new_key = apply(key) | ||
if (stop_keys and new_key not in stop_keys) or stop_keys is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1, conditional is too complex. If you want to keep this condition, create a utils and unit-test separately please. |
||
if isinstance(value, dict): | ||
new[new_key] = {} | ||
_walk_and_apply_json(value, new=new[new_key]) | ||
elif isinstance(value, list): | ||
new[new_key] = [] | ||
for item in value: | ||
_walk_and_apply_json(item, new=new[new_key]) | ||
else: | ||
new[new_key] = value | ||
else: | ||
new[new_key] = value | ||
elif isinstance(json_obj, dict) and isinstance(new, list): | ||
|
@@ -51,6 +62,8 @@ def _walk_and_apply_json(json_obj, new): | |
new.update(json_obj) | ||
elif isinstance(json_obj, list) and isinstance(new, list): | ||
new.append(json_obj) | ||
elif isinstance(json_obj, str) and isinstance(new, list): | ||
new.append(json_obj) | ||
return new | ||
|
||
return _walk_and_apply_json(json_obj, new={}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1174,7 +1174,7 @@ def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False): | |
spec (Dict[str, Any]): Dictionary representation of training config ranking. | ||
""" | ||
if is_hub_content: | ||
spec = {camel_to_snake(key): val for key, val in spec.items()} | ||
spec = walk_and_apply_json(spec, camel_to_snake) | ||
self.from_json(spec) | ||
|
||
def from_json(self, json_obj: Dict[str, Any]) -> None: | ||
|
@@ -1278,7 +1278,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: | |
json_obj (Dict[str, Any]): Dictionary representation of spec. | ||
""" | ||
if self._is_hub_content: | ||
json_obj = walk_and_apply_json(json_obj, camel_to_snake) | ||
json_obj = walk_and_apply_json(json_obj, camel_to_snake, ["metrics"]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit- you can remove |
||
self.model_id: str = json_obj.get("model_id") | ||
self.url: str = json_obj.get("url") | ||
self.version: str = json_obj.get("version") | ||
|
@@ -1400,7 +1400,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: | |
|
||
if self.training_supported: | ||
if self._is_hub_content: | ||
self.training_ecr_uri: Optional[str] = json_obj["training_ecr_uri"] | ||
self.training_ecr_uri: Optional[str] = json_obj.get("training_ecr_uri") | ||
self._non_serializable_slots.append("training_ecr_specs") | ||
else: | ||
self.training_ecr_specs: Optional[JumpStartECRSpecs] = ( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: