Skip to content

fix: ModelReference deployment for Alt Configs models #4813

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 1, 2024
Merged
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,7 @@ def _add_config_name_to_kwargs(
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
sagemaker_session=kwargs.sagemaker_session,
config_name=kwargs.config_name,
hub_arn=kwargs.hub_arn,
)

if specs.training_configs and specs.training_configs.get_top_config_from_ranking():
Expand Down
15 changes: 13 additions & 2 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,9 +672,14 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
sagemaker_session=kwargs.sagemaker_session,
model_type=kwargs.model_type,
config_name=kwargs.config_name,
hub_arn=kwargs.hub_arn,
)
if specs.inference_configs:
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
default_config_name = (
specs.inference_configs.get_top_config_from_ranking().config_name
if specs.inference_configs.get_top_config_from_ranking()
else None
)
kwargs.config_name = kwargs.config_name or default_config_name

if not kwargs.config_name:
Expand Down Expand Up @@ -707,6 +712,7 @@ def _add_additional_model_data_sources_to_kwargs(
sagemaker_session=kwargs.sagemaker_session,
model_type=kwargs.model_type,
config_name=kwargs.config_name,
hub_arn=kwargs.hub_arn,
)
# Append speculative decoding data source from metadata
speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources()
Expand Down Expand Up @@ -750,6 +756,7 @@ def _add_config_name_to_deploy_kwargs(
sagemaker_session=kwargs.sagemaker_session,
model_type=kwargs.model_type,
config_name=kwargs.config_name,
hub_arn=kwargs.hub_arn,
)

if training_config_name:
Expand All @@ -759,7 +766,11 @@ def _add_config_name_to_deploy_kwargs(
return kwargs

if specs.inference_configs:
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
default_config_name = (
specs.inference_configs.get_top_config_from_ranking().config_name
if specs.inference_configs.get_top_config_from_ranking()
else None
)
kwargs.config_name = kwargs.config_name or default_config_name

return kwargs
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/hub/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def list_sagemaker_public_hub_models(
f"arn:{info.partition}:"
f"sagemaker:{info.region}:"
f"aws:hub-content/{info.hub_name}/"
f"{HubContentType.MODEL}/{model[0]}"
f"{HubContentType.MODEL.value}/{model[0]}"
)
hub_content_summary = {
"hub_content_name": model[0],
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ def _validate_model_id_and_type():
model_version=self.model_version,
sagemaker_session=self.sagemaker_session,
model_type=self.model_type,
hub_arn=self.hub_arn,
)

def log_subscription_warning(self) -> None:
Expand Down
25 changes: 4 additions & 21 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1732,32 +1732,15 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
)
self.inference_config_rankings: Optional[Dict[str, JumpStartConfigRanking]] = (
{
alias: JumpStartConfigRanking(ranking)
alias: JumpStartConfigRanking(ranking, is_hub_content=self._is_hub_content)
for alias, ranking in json_obj["inference_config_rankings"].items()
}
if json_obj.get("inference_config_rankings")
else None
)
inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
{
alias: JumpStartMetadataConfig(
alias,
config,
json_obj,
(
{
component_name: self.inference_config_components.get(component_name)
for component_name in config.get("component_names")
}
if config and config.get("component_names")
else None
),
)
for alias, config in json_obj["inference_configs"].items()
}
if json_obj.get("inference_configs")
else None
)
inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = json_obj[
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this works: now that inference_configs_dict is a raw JSON object, how would we construct the JumpStartMetadataConfig instances without parsing the JSON?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1. I'm also curious whether we should use .get here to mitigate future errors.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we reintroduce the code above, please create a dedicated utility and unit test it separately.

Dict comprehensions coupled with ternary operators are cool, but they make the code much harder to read and maintain.

"inference_configs"
]["Configs"]
self.inference_configs: Optional[JumpStartMetadataConfigs] = (
JumpStartMetadataConfigs(
inference_configs_dict,
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,7 @@ def get_jumpstart_configs(
sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
hub_arn: Optional[str] = None,
) -> Dict[str, JumpStartMetadataConfig]:
"""Returns metadata configs for the given model ID and region.
Expand All @@ -1087,6 +1088,7 @@ def get_jumpstart_configs(
sagemaker_session=sagemaker_session,
scope=scope,
model_type=model_type,
hub_arn=hub_arn,
)

if scope == enums.JumpStartScriptScope.INFERENCE:
Expand Down
Loading