Skip to content

Commit af1e554

Browse files
author
Malav Shastri
committed
fix: ModelReference deployment for Alt Configs models
1 parent 6c2c4c9 commit af1e554

File tree

5 files changed

+21
-23
lines changed

5 files changed

+21
-23
lines changed

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,7 @@ def _add_config_name_to_kwargs(
912912
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
913913
sagemaker_session=kwargs.sagemaker_session,
914914
config_name=kwargs.config_name,
915+
hub_arn=kwargs.hub_arn,
915916
)
916917

917918
if specs.training_configs and specs.training_configs.get_top_config_from_ranking():

src/sagemaker/jumpstart/factory/model.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -672,9 +672,14 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
672672
sagemaker_session=kwargs.sagemaker_session,
673673
model_type=kwargs.model_type,
674674
config_name=kwargs.config_name,
675+
hub_arn=kwargs.hub_arn,
675676
)
676677
if specs.inference_configs:
677-
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
678+
default_config_name = (
679+
specs.inference_configs.get_top_config_from_ranking().config_name
680+
if specs.inference_configs.get_top_config_from_ranking()
681+
else None
682+
)
678683
kwargs.config_name = kwargs.config_name or default_config_name
679684

680685
if not kwargs.config_name:
@@ -707,6 +712,7 @@ def _add_additional_model_data_sources_to_kwargs(
707712
sagemaker_session=kwargs.sagemaker_session,
708713
model_type=kwargs.model_type,
709714
config_name=kwargs.config_name,
715+
hub_arn=kwargs.hub_arn,
710716
)
711717
# Append speculative decoding data source from metadata
712718
speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources()
@@ -750,6 +756,7 @@ def _add_config_name_to_deploy_kwargs(
750756
sagemaker_session=kwargs.sagemaker_session,
751757
model_type=kwargs.model_type,
752758
config_name=kwargs.config_name,
759+
hub_arn=kwargs.hub_arn,
753760
)
754761

755762
if training_config_name:
@@ -759,7 +766,11 @@ def _add_config_name_to_deploy_kwargs(
759766
return kwargs
760767

761768
if specs.inference_configs:
762-
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
769+
default_config_name = (
770+
specs.inference_configs.get_top_config_from_ranking().config_name
771+
if specs.inference_configs.get_top_config_from_ranking()
772+
else None
773+
)
763774
kwargs.config_name = kwargs.config_name or default_config_name
764775

765776
return kwargs

src/sagemaker/jumpstart/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ def _validate_model_id_and_type():
393393
model_version=self.model_version,
394394
sagemaker_session=self.sagemaker_session,
395395
model_type=self.model_type,
396+
hub_arn=self.hub_arn,
396397
)
397398

398399
def log_subscription_warning(self) -> None:

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1732,32 +1732,15 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
17321732
)
17331733
self.inference_config_rankings: Optional[Dict[str, JumpStartConfigRanking]] = (
17341734
{
1735-
alias: JumpStartConfigRanking(ranking)
1735+
alias: JumpStartConfigRanking(ranking, is_hub_content=self._is_hub_content)
17361736
for alias, ranking in json_obj["inference_config_rankings"].items()
17371737
}
17381738
if json_obj.get("inference_config_rankings")
17391739
else None
17401740
)
1741-
inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
1742-
{
1743-
alias: JumpStartMetadataConfig(
1744-
alias,
1745-
config,
1746-
json_obj,
1747-
(
1748-
{
1749-
component_name: self.inference_config_components.get(component_name)
1750-
for component_name in config.get("component_names")
1751-
}
1752-
if config and config.get("component_names")
1753-
else None
1754-
),
1755-
)
1756-
for alias, config in json_obj["inference_configs"].items()
1757-
}
1758-
if json_obj.get("inference_configs")
1759-
else None
1760-
)
1741+
inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = json_obj[
1742+
"inference_configs"
1743+
]["Configs"]
17611744
self.inference_configs: Optional[JumpStartMetadataConfigs] = (
17621745
JumpStartMetadataConfigs(
17631746
inference_configs_dict,

src/sagemaker/jumpstart/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,6 +1074,7 @@ def get_jumpstart_configs(
10741074
sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
10751075
scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
10761076
model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
1077+
hub_arn: Optional[str] = None,
10771078
) -> Dict[str, JumpStartMetadataConfig]:
10781079
"""Returns metadata configs for the given model ID and region.
10791080
@@ -1087,6 +1088,7 @@ def get_jumpstart_configs(
10871088
sagemaker_session=sagemaker_session,
10881089
scope=scope,
10891090
model_type=model_type,
1091+
hub_arn=hub_arn,
10901092
)
10911093

10921094
if scope == enums.JumpStartScriptScope.INFERENCE:

0 commit comments

Comments
 (0)