-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: alt configs model deployment and training issues #4833
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3045e4d
d5d8324
7dc1803
f52351b
c9add06
e88cb4d
bca701f
149267a
62b2596
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -259,6 +259,7 @@ def _add_instance_type_to_kwargs( | |
sagemaker_session=kwargs.sagemaker_session, | ||
model_type=kwargs.model_type, | ||
config_name=kwargs.config_name, | ||
hub_arn=kwargs.hub_arn, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We're going to be finding these lines forever aren't we.... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is an optional field but looks like integ tests could be a possible solution here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, we desperately need integ tests for (at least):
|
||
) | ||
|
||
if specs.inference_configs and kwargs.config_name not in specs.inference_configs.configs: | ||
|
@@ -780,6 +781,7 @@ def _add_config_name_to_deploy_kwargs( | |
sagemaker_session=temp_session, | ||
model_type=kwargs.model_type, | ||
config_name=kwargs.config_name, | ||
hub_arn=kwargs.hub_arn, | ||
) | ||
default_config_name = _select_inference_config_from_training_config( | ||
specs=specs, training_config_name=training_config_name | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,16 +10,20 @@ | |
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
# pylint: skip-file | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: generally not a fan of skipping entire file, we should fix in a follow-up PR There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. me and @bencrabtree just discuss this, there's one apparently pylint false positive which we are not able to get rid of. Skipping it to unblock customer but sure I can followup |
||
"""This module contains utilities related to SageMaker JumpStart Hub.""" | ||
from __future__ import absolute_import | ||
|
||
import re | ||
from typing import Any, Dict | ||
from typing import Any, Dict, List, Optional | ||
|
||
|
||
def camel_to_snake(camel_case_string: str) -> str: | ||
"""Converts camelCaseString or UpperCamelCaseString to snake_case_string.""" | ||
snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string) | ||
if "-" in snake_case_string: | ||
# remove any hyphen from the string for accurate conversion. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: unnecessary comment |
||
snake_case_string = snake_case_string.replace("-", "") | ||
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower() | ||
|
||
|
||
|
@@ -29,20 +33,29 @@ def snake_to_upper_camel(snake_case_string: str) -> str: | |
return upper_camel_case_string | ||
|
||
|
||
def walk_and_apply_json(json_obj: Dict[Any, Any], apply) -> Dict[Any, Any]: | ||
"""Recursively walks a json object and applies a given function to the keys.""" | ||
def walk_and_apply_json( | ||
json_obj: Dict[Any, Any], apply, stop_keys: Optional[List[str]] = ["metrics"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. non-blocking: apply typing please |
||
) -> Dict[Any, Any]: | ||
"""Recursively walks a json object and applies a given function to the keys. | ||
|
||
stop_keys (Optional[list[str]]): List of field keys that should stop the application function. | ||
Any children of these keys will not have the application function applied to them. | ||
""" | ||
|
||
def _walk_and_apply_json(json_obj, new): | ||
if isinstance(json_obj, dict) and isinstance(new, dict): | ||
for key, value in json_obj.items(): | ||
new_key = apply(key) | ||
if isinstance(value, dict): | ||
new[new_key] = {} | ||
_walk_and_apply_json(value, new=new[new_key]) | ||
elif isinstance(value, list): | ||
new[new_key] = [] | ||
for item in value: | ||
_walk_and_apply_json(item, new=new[new_key]) | ||
if (stop_keys and new_key not in stop_keys) or stop_keys is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1, conditional is too complex. If you want to keep this condition, create a utils and unit-test separately please. |
||
if isinstance(value, dict): | ||
new[new_key] = {} | ||
_walk_and_apply_json(value, new=new[new_key]) | ||
elif isinstance(value, list): | ||
new[new_key] = [] | ||
for item in value: | ||
_walk_and_apply_json(item, new=new[new_key]) | ||
else: | ||
new[new_key] = value | ||
else: | ||
new[new_key] = value | ||
elif isinstance(json_obj, dict) and isinstance(new, list): | ||
|
@@ -51,6 +64,8 @@ def _walk_and_apply_json(json_obj, new): | |
new.update(json_obj) | ||
elif isinstance(json_obj, list) and isinstance(new, list): | ||
new.append(json_obj) | ||
elif isinstance(json_obj, str) and isinstance(new, list): | ||
new.append(json_obj) | ||
return new | ||
|
||
return _walk_and_apply_json(json_obj, new={}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,7 @@ | |
|
||
from sagemaker.jumpstart import constants, enums | ||
from sagemaker.jumpstart import accessors | ||
from sagemaker.jumpstart.hub.parser_utils import camel_to_snake, snake_to_upper_camel | ||
from sagemaker.s3 import parse_s3_url | ||
from sagemaker.jumpstart.exceptions import ( | ||
DeprecatedJumpStartModelError, | ||
|
@@ -1103,6 +1104,17 @@ def get_jumpstart_configs( | |
metadata_configs.config_rankings.get("overall").rankings if metadata_configs else [] | ||
) | ||
|
||
if hub_arn: | ||
return ( | ||
{ | ||
config_name: metadata_configs.configs[ | ||
camel_to_snake(snake_to_upper_camel(config_name)) | ||
] | ||
for config_name in config_names | ||
} | ||
if metadata_configs | ||
else {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ternary + list-comprehension is a recipe for unreadability. Consider: if metadata_configs:
return { ... }
else:
return {} |
||
) | ||
return ( | ||
{config_name: metadata_configs.configs[config_name] for config_name in config_names} | ||
if metadata_configs | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: