Skip to content

Commit 3fc60f2

Browse files
committed
chore: address minor issues
1 parent 49a1c82 commit 3fc60f2

File tree

2 files changed

+30
-26
lines changed

2 files changed

+30
-26
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 29 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -246,6 +246,32 @@ def _add_instance_type_to_kwargs(
246246
kwargs.instance_type,
247247
)
248248

249+
specs = verify_model_region_and_return_specs(
250+
model_id=kwargs.model_id,
251+
version=kwargs.model_version,
252+
scope=JumpStartScriptScope.INFERENCE,
253+
region=kwargs.region,
254+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
255+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
256+
sagemaker_session=kwargs.sagemaker_session,
257+
model_type=kwargs.model_type,
258+
config_name=kwargs.config_name,
259+
)
260+
261+
if specs.inference_configs and kwargs.config_name not in specs.inference_configs.configs:
262+
return kwargs
263+
264+
resolved_config = (
265+
specs.inference_configs.configs[kwargs.config_name].resolved_config
266+
if specs.inference_configs
267+
else None
268+
)
269+
if resolved_config is None:
270+
return kwargs
271+
supported_instance_types = resolved_config.get("supported_inference_instance_types", [])
272+
if kwargs.instance_type not in supported_instance_types:
273+
JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type)
274+
249275
return kwargs
250276

251277

@@ -683,28 +709,6 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
683709
if kwargs.config_name is None:
684710
return kwargs
685711

686-
specs = verify_model_region_and_return_specs(
687-
model_id=kwargs.model_id,
688-
version=kwargs.model_version,
689-
scope=JumpStartScriptScope.INFERENCE,
690-
region=kwargs.region,
691-
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
692-
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
693-
sagemaker_session=temp_session,
694-
model_type=kwargs.model_type,
695-
config_name=kwargs.config_name,
696-
)
697-
698-
resolved_config = (
699-
specs.inference_configs.configs[kwargs.config_name].resolved_config
700-
if specs.inference_configs
701-
else None
702-
)
703-
if resolved_config is None:
704-
return kwargs
705-
supported_instance_types = resolved_config.get("supported_inference_instance_types", [])
706-
if kwargs.instance_type not in supported_instance_types:
707-
JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type)
708712
return kwargs
709713

710714

@@ -873,10 +877,10 @@ def get_deploy_kwargs(
873877
kwargs=deploy_kwargs, training_config_name=training_config_name
874878
)
875879

876-
deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(kwargs=deploy_kwargs)
877-
878880
deploy_kwargs = _add_model_version_to_kwargs(kwargs=deploy_kwargs)
879881

882+
deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(kwargs=deploy_kwargs)
883+
880884
deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)
881885

882886
deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs)
@@ -1060,8 +1064,8 @@ def get_init_kwargs(
10601064
)
10611065

10621066
model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs)
1063-
model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)
10641067
model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)
1068+
model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)
10651069

10661070
model_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(
10671071
kwargs=model_init_kwargs

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -399,7 +399,7 @@ def test_jumpstart_model_with_deployment_configs(setup):
399399

400400

401401
def test_jumpstart_session_with_config_name():
402-
model = JumpStartModel(model_id="meta-textgeneration-llama-2-7b", model_version="*")
402+
model = JumpStartModel(model_id="meta-textgeneration-llama-2-7b")
403403
assert model.config_name != None
404404
session = model.sagemaker_session
405405

0 commit comments

Comments
 (0)