@@ -246,6 +246,32 @@ def _add_instance_type_to_kwargs(
246
246
kwargs .instance_type ,
247
247
)
248
248
249
+ specs = verify_model_region_and_return_specs (
250
+ model_id = kwargs .model_id ,
251
+ version = kwargs .model_version ,
252
+ scope = JumpStartScriptScope .INFERENCE ,
253
+ region = kwargs .region ,
254
+ tolerate_vulnerable_model = kwargs .tolerate_vulnerable_model ,
255
+ tolerate_deprecated_model = kwargs .tolerate_deprecated_model ,
256
+ sagemaker_session = kwargs .sagemaker_session ,
257
+ model_type = kwargs .model_type ,
258
+ config_name = kwargs .config_name ,
259
+ )
260
+
261
+ if specs .inference_configs and kwargs .config_name not in specs .inference_configs .configs :
262
+ return kwargs
263
+
264
+ resolved_config = (
265
+ specs .inference_configs .configs [kwargs .config_name ].resolved_config
266
+ if specs .inference_configs
267
+ else None
268
+ )
269
+ if resolved_config is None :
270
+ return kwargs
271
+ supported_instance_types = resolved_config .get ("supported_inference_instance_types" , [])
272
+ if kwargs .instance_type not in supported_instance_types :
273
+ JUMPSTART_LOGGER .warning ("Overriding instance type to %s" , kwargs .instance_type )
274
+
249
275
return kwargs
250
276
251
277
@@ -683,28 +709,6 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
683
709
if kwargs .config_name is None :
684
710
return kwargs
685
711
686
- specs = verify_model_region_and_return_specs (
687
- model_id = kwargs .model_id ,
688
- version = kwargs .model_version ,
689
- scope = JumpStartScriptScope .INFERENCE ,
690
- region = kwargs .region ,
691
- tolerate_vulnerable_model = kwargs .tolerate_vulnerable_model ,
692
- tolerate_deprecated_model = kwargs .tolerate_deprecated_model ,
693
- sagemaker_session = temp_session ,
694
- model_type = kwargs .model_type ,
695
- config_name = kwargs .config_name ,
696
- )
697
-
698
- resolved_config = (
699
- specs .inference_configs .configs [kwargs .config_name ].resolved_config
700
- if specs .inference_configs
701
- else None
702
- )
703
- if resolved_config is None :
704
- return kwargs
705
- supported_instance_types = resolved_config .get ("supported_inference_instance_types" , [])
706
- if kwargs .instance_type not in supported_instance_types :
707
- JUMPSTART_LOGGER .warning ("Overriding instance type to %s" , kwargs .instance_type )
708
712
return kwargs
709
713
710
714
@@ -873,10 +877,10 @@ def get_deploy_kwargs(
873
877
kwargs = deploy_kwargs , training_config_name = training_config_name
874
878
)
875
879
876
- deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs (kwargs = deploy_kwargs )
877
-
878
880
deploy_kwargs = _add_model_version_to_kwargs (kwargs = deploy_kwargs )
879
881
882
+ deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs (kwargs = deploy_kwargs )
883
+
880
884
deploy_kwargs = _add_endpoint_name_to_kwargs (kwargs = deploy_kwargs )
881
885
882
886
deploy_kwargs = _add_instance_type_to_kwargs (kwargs = deploy_kwargs )
@@ -1060,8 +1064,8 @@ def get_init_kwargs(
1060
1064
)
1061
1065
1062
1066
model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs (kwargs = model_init_kwargs )
1063
- model_init_kwargs = _add_config_name_to_init_kwargs (kwargs = model_init_kwargs )
1064
1067
model_init_kwargs = _add_model_version_to_kwargs (kwargs = model_init_kwargs )
1068
+ model_init_kwargs = _add_config_name_to_init_kwargs (kwargs = model_init_kwargs )
1065
1069
1066
1070
model_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs (
1067
1071
kwargs = model_init_kwargs
0 commit comments