Skip to content

Commit 7adccde

Browse files
committed
fix: config logic support
1 parent 40c791c commit 7adccde

File tree

4 files changed

+20
-10
lines changed

4 files changed

+20
-10
lines changed

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -837,6 +837,7 @@ def _add_config_name_to_kwargs(
837837

838838
kwargs.config_name = kwargs.config_name or get_top_ranked_config_name(
839839
scope=JumpStartScriptScope.TRAINING,
840+
instance_type=kwargs.instance_type,
840841
**get_model_info_default_kwargs(kwargs, include_config_name=False),
841842
)
842843

src/sagemaker/jumpstart/factory/model.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -559,6 +559,7 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
559559
kwargs.config_name = kwargs.config_name or get_top_ranked_config_name(
560560
**get_model_info_default_kwargs(kwargs, include_config_name=False),
561561
scope=JumpStartScriptScope.INFERENCE,
562+
instance_type=kwargs.instance_type,
562563
)
563564

564565
if kwargs.config_name is None:
@@ -618,6 +619,7 @@ def _add_config_name_to_deploy_kwargs(
618619
default_config_name = kwargs.config_name or get_top_ranked_config_name(
619620
**get_model_info_default_kwargs(kwargs, include_config_name=False),
620621
scope=JumpStartScriptScope.INFERENCE,
622+
instance_type=kwargs.instance_type,
621623
)
622624

623625
kwargs.config_name = kwargs.config_name or default_config_name
@@ -927,6 +929,12 @@ def get_init_kwargs(
927929

928930
model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs)
929931
model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)
932+
933+
# Add instance type before config selection so config compatibility can be checked
934+
model_init_kwargs = _add_instance_type_to_kwargs(
935+
kwargs=model_init_kwargs, disable_instance_type_logging=disable_instance_type_logging
936+
)
937+
930938
model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)
931939

932940
model_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(
@@ -936,10 +944,6 @@ def get_init_kwargs(
936944

937945
model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs)
938946

939-
model_init_kwargs = _add_instance_type_to_kwargs(
940-
kwargs=model_init_kwargs, disable_instance_type_logging=disable_instance_type_logging
941-
)
942-
943947
model_init_kwargs = _add_image_uri_to_kwargs(kwargs=model_init_kwargs)
944948

945949
if hub_arn:

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1723,10 +1723,10 @@ def get_top_config_from_ranking(
17231723
ranked_config_names = rankings.rankings
17241724
for config_name in ranked_config_names:
17251725
resolved_config = self.configs[config_name].resolved_config
1726-
if instance_type and instance_type not in getattr(
1727-
resolved_config, instance_type_attribute
1728-
):
1729-
continue
1726+
if instance_type:
1727+
supported_instance_types = getattr(resolved_config, instance_type_attribute, [])
1728+
if supported_instance_types and instance_type not in supported_instance_types:
1729+
continue
17301730
return self.configs[config_name]
17311731

17321732
return None

src/sagemaker/jumpstart/utils.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1233,9 +1233,14 @@ def get_top_ranked_config_name(
12331233
tolerate_vulnerable_model: bool = False,
12341234
hub_arn: Optional[str] = None,
12351235
ranking_name: enums.JumpStartConfigRankingName = enums.JumpStartConfigRankingName.DEFAULT,
1236+
instance_type: Optional[str] = None,
12361237
) -> Optional[str]:
12371238
"""Returns the top ranked config name for the given model ID and region.
12381239
1240+
Args:
1241+
instance_type (Optional[str]): The instance type to filter configs by compatibility.
1242+
If provided, only configs that support this instance type will be considered.
1243+
12391244
Raises:
12401245
ValueError: If the script scope is not supported by JumpStart.
12411246
"""
@@ -1254,15 +1259,15 @@ def get_top_ranked_config_name(
12541259
if scope == enums.JumpStartScriptScope.INFERENCE:
12551260
return (
12561261
model_specs.inference_configs.get_top_config_from_ranking(
1257-
ranking_name=ranking_name
1262+
ranking_name=ranking_name, instance_type=instance_type
12581263
).config_name
12591264
if model_specs.inference_configs
12601265
else None
12611266
)
12621267
if scope == enums.JumpStartScriptScope.TRAINING:
12631268
return (
12641269
model_specs.training_configs.get_top_config_from_ranking(
1265-
ranking_name=ranking_name
1270+
ranking_name=ranking_name, instance_type=instance_type
12661271
).config_name
12671272
if model_specs.training_configs
12681273
else None

0 commit comments

Comments
 (0)