29
29
)
30
30
from sagemaker .jumpstart .artifacts .resource_names import _retrieve_resource_name_base
31
31
from sagemaker .jumpstart .constants import (
32
- DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
33
32
INFERENCE_ENTRY_POINT_SCRIPT_NAME ,
34
33
JUMPSTART_DEFAULT_REGION_NAME ,
35
34
JUMPSTART_LOGGER ,
63
62
64
63
from sagemaker .jumpstart .factory .utils import (
65
64
_set_temp_sagemaker_session_if_not_set ,
66
- get_model_info_kwargs ,
65
+ get_model_info_default_kwargs ,
67
66
)
68
67
from sagemaker .model_monitor .data_capture_config import DataCaptureConfig
69
68
from sagemaker .base_predictor import Predictor
@@ -224,7 +223,7 @@ def _add_instance_type_to_kwargs(
224
223
225
224
orig_instance_type = kwargs .instance_type
226
225
kwargs .instance_type = kwargs .instance_type or instance_types .retrieve_default (
227
- ** get_model_info_kwargs (kwargs ),
226
+ ** get_model_info_default_kwargs (kwargs ),
228
227
scope = JumpStartScriptScope .INFERENCE ,
229
228
training_instance_type = kwargs .training_instance_type ,
230
229
)
@@ -265,7 +264,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
265
264
return kwargs
266
265
267
266
kwargs .image_uri = kwargs .image_uri or image_uris .retrieve (
268
- ** get_model_info_kwargs (kwargs ),
267
+ ** get_model_info_default_kwargs (kwargs ),
269
268
framework = None ,
270
269
image_scope = JumpStartScriptScope .INFERENCE ,
271
270
instance_type = kwargs .instance_type ,
@@ -298,7 +297,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode
298
297
kwargs .model_data = None
299
298
return kwargs
300
299
301
- model_info_kwargs = get_model_info_kwargs (kwargs )
300
+ model_info_kwargs = get_model_info_default_kwargs (kwargs )
302
301
model_data : Union [str , dict ] = kwargs .model_data or model_uris .retrieve (
303
302
** model_info_kwargs ,
304
303
model_scope = JumpStartScriptScope .INFERENCE ,
@@ -336,9 +335,9 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode
336
335
337
336
source_dir = kwargs .source_dir
338
337
339
- if _model_supports_inference_script_uri (** get_model_info_kwargs (kwargs )):
338
+ if _model_supports_inference_script_uri (** get_model_info_default_kwargs (kwargs )):
340
339
source_dir = source_dir or script_uris .retrieve (
341
- ** get_model_info_kwargs (kwargs ), script_scope = JumpStartScriptScope .INFERENCE
340
+ ** get_model_info_default_kwargs (kwargs ), script_scope = JumpStartScriptScope .INFERENCE
342
341
)
343
342
344
343
kwargs .source_dir = source_dir
@@ -355,7 +354,7 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod
355
354
356
355
entry_point = kwargs .entry_point
357
356
358
- if _model_supports_inference_script_uri (** get_model_info_kwargs (kwargs )):
357
+ if _model_supports_inference_script_uri (** get_model_info_default_kwargs (kwargs )):
359
358
360
359
entry_point = entry_point or INFERENCE_ENTRY_POINT_SCRIPT_NAME
361
360
@@ -377,7 +376,7 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw
377
376
env = {}
378
377
379
378
extra_env_vars = environment_variables .retrieve_default (
380
- ** get_model_info_kwargs (kwargs ),
379
+ ** get_model_info_default_kwargs (kwargs ),
381
380
include_aws_sdk_env_vars = False ,
382
381
script = JumpStartScriptScope .INFERENCE ,
383
382
instance_type = kwargs .instance_type ,
@@ -402,7 +401,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt
402
401
"""Sets model package arn based on default or override, returns full kwargs."""
403
402
404
403
model_package_arn = kwargs .model_package_arn or _retrieve_model_package_arn (
405
- ** get_model_info_kwargs (kwargs ),
404
+ ** get_model_info_default_kwargs (kwargs ),
406
405
instance_type = kwargs .instance_type ,
407
406
scope = JumpStartScriptScope .INFERENCE ,
408
407
)
@@ -414,7 +413,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt
414
413
def _add_extra_model_kwargs (kwargs : JumpStartModelInitKwargs ) -> JumpStartModelInitKwargs :
415
414
"""Sets extra kwargs based on default or override, returns full kwargs."""
416
415
417
- model_kwargs_to_add = _retrieve_model_init_kwargs (** get_model_info_kwargs (kwargs ))
416
+ model_kwargs_to_add = _retrieve_model_init_kwargs (** get_model_info_default_kwargs (kwargs ))
418
417
419
418
for key , value in model_kwargs_to_add .items ():
420
419
if getattr (kwargs , key ) is None :
@@ -442,7 +441,7 @@ def _add_endpoint_name_to_kwargs(
442
441
) -> JumpStartModelDeployKwargs :
443
442
"""Sets resource name based on default or override, returns full kwargs."""
444
443
445
- default_endpoint_name = _retrieve_resource_name_base (** get_model_info_kwargs (kwargs ))
444
+ default_endpoint_name = _retrieve_resource_name_base (** get_model_info_default_kwargs (kwargs ))
446
445
447
446
kwargs .endpoint_name = kwargs .endpoint_name or (
448
447
name_from_base (default_endpoint_name ) if default_endpoint_name is not None else None
@@ -456,7 +455,7 @@ def _add_model_name_to_kwargs(
456
455
) -> JumpStartModelInitKwargs :
457
456
"""Sets resource name based on default or override, returns full kwargs."""
458
457
459
- default_model_name = _retrieve_resource_name_base (** get_model_info_kwargs (kwargs ))
458
+ default_model_name = _retrieve_resource_name_base (** get_model_info_default_kwargs (kwargs ))
460
459
461
460
kwargs .name = kwargs .name or (
462
461
name_from_base (default_model_name ) if default_model_name is not None else None
@@ -498,7 +497,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any]
498
497
"""Sets extra kwargs based on default or override, returns full kwargs."""
499
498
500
499
deploy_kwargs_to_add = _retrieve_model_deploy_kwargs (
501
- ** get_model_info_kwargs (kwargs ), instance_type = kwargs .instance_type
500
+ ** get_model_info_default_kwargs (kwargs ), instance_type = kwargs .instance_type
502
501
)
503
502
504
503
for key , value in deploy_kwargs_to_add .items ():
@@ -512,7 +511,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
512
511
"""Sets the resource requirements based on the default or an override. Returns full kwargs."""
513
512
514
513
kwargs .resources = kwargs .resources or resource_requirements .retrieve_default (
515
- ** get_model_info_kwargs (kwargs ),
514
+ ** get_model_info_default_kwargs (kwargs ),
516
515
scope = JumpStartScriptScope .INFERENCE ,
517
516
instance_type = kwargs .instance_type ,
518
517
)
@@ -548,7 +547,7 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
548
547
"""
549
548
550
549
kwargs .config_name = kwargs .config_name or get_top_ranked_config_name (
551
- ** get_model_info_kwargs (kwargs , include_config_name = False ),
550
+ ** get_model_info_default_kwargs (kwargs , include_config_name = False ),
552
551
scope = JumpStartScriptScope .INFERENCE ,
553
552
)
554
553
@@ -605,7 +604,7 @@ def _add_config_name_to_deploy_kwargs(
605
604
606
605
else :
607
606
default_config_name = kwargs .config_name or get_top_ranked_config_name (
608
- ** get_model_info_kwargs (kwargs , include_config_name = False ),
607
+ ** get_model_info_default_kwargs (kwargs , include_config_name = False ),
609
608
scope = JumpStartScriptScope .INFERENCE ,
610
609
)
611
610
@@ -689,14 +688,11 @@ def get_deploy_kwargs(
689
688
)
690
689
deploy_kwargs = _set_temp_sagemaker_session_if_not_set (kwargs = deploy_kwargs )
691
690
deploy_kwargs .specs = verify_model_region_and_return_specs (
692
- model_id = model_id ,
693
- version = model_version ,
694
- hub_arn = hub_arn ,
695
- model_type = model_type ,
696
- region = region ,
691
+ ** get_model_info_default_kwargs (
692
+ deploy_kwargs , include_model_version = False , include_tolerate_flags = False
693
+ ),
694
+ version = deploy_kwargs .model_version or "*" ,
697
695
scope = JumpStartScriptScope .INFERENCE ,
698
- sagemaker_session = deploy_kwargs .sagemaker_session ,
699
- config_name = config_name ,
700
696
# We set these flags to True to retrieve the json specs.
701
697
# Exceptions will be thrown later if these are not tolerated.
702
698
tolerate_deprecated_model = True ,
@@ -769,6 +765,7 @@ def get_register_kwargs(
769
765
register_kwargs = JumpStartModelRegisterKwargs (
770
766
model_id = model_id ,
771
767
model_version = model_version ,
768
+ config_name = config_name ,
772
769
hub_arn = hub_arn ,
773
770
model_type = model_type ,
774
771
region = region ,
@@ -802,14 +799,11 @@ def get_register_kwargs(
802
799
)
803
800
804
801
register_kwargs .specs = verify_model_region_and_return_specs (
805
- model_id = model_id ,
806
- version = model_version ,
807
- hub_arn = hub_arn ,
808
- model_type = model_type ,
809
- region = region ,
802
+ ** get_model_info_default_kwargs (
803
+ register_kwargs , include_model_version = False , include_tolerate_flags = False
804
+ ),
805
+ version = register_kwargs .model_version or "*" ,
810
806
scope = JumpStartScriptScope .INFERENCE ,
811
- sagemaker_session = sagemaker_session ,
812
- config_name = config_name ,
813
807
# We set these flags to True to retrieve the json specs.
814
808
# Exceptions will be thrown later if these are not tolerated.
815
809
tolerate_deprecated_model = True ,
@@ -898,14 +892,11 @@ def get_init_kwargs(
898
892
)
899
893
model_init_kwargs = _set_temp_sagemaker_session_if_not_set (kwargs = model_init_kwargs )
900
894
model_init_kwargs .specs = verify_model_region_and_return_specs (
901
- model_id = model_init_kwargs .model_id ,
902
- version = model_init_kwargs .model_version ,
903
- hub_arn = model_init_kwargs .hub_arn ,
904
- model_type = model_init_kwargs .model_type ,
905
- region = model_init_kwargs .region ,
895
+ ** get_model_info_default_kwargs (
896
+ model_init_kwargs , include_model_version = False , include_tolerate_flags = False
897
+ ),
898
+ version = model_init_kwargs .model_version or "*" ,
906
899
scope = JumpStartScriptScope .INFERENCE ,
907
- sagemaker_session = model_init_kwargs .sagemaker_session ,
908
- config_name = model_init_kwargs .config_name ,
909
900
# We set these flags to True to retrieve the json specs.
910
901
# Exceptions will be thrown later if these are not tolerated.
911
902
tolerate_deprecated_model = True ,
0 commit comments