@@ -581,6 +581,56 @@ def get_instance_specific_environment_variables(self, instance_type: str) -> Dic
581581
582582 return instance_family_environment_variables
583583
584+ def get_instance_specific_default_inference_instance_type (
585+ self , instance_type : str
586+ ) -> Optional [str ]:
587+ """Returns instance specific default inference instance type.
588+
589+ Returns None if a model, instance type tuple does not have instance
590+ specific inference instance types.
591+ """
592+
593+ return self ._get_instance_specific_property (
594+ instance_type , "default_inference_instance_type"
595+ )
596+
597+ def get_instance_specific_supported_inference_instance_types (
598+ self , instance_type : str
599+ ) -> List [str ]:
600+ """Returns instance specific supported inference instance types.
601+
602+ Returns empty list if a model, instance type tuple does not have instance
603+ specific inference instance types.
604+ """
605+
606+ if self .variants is None :
607+ return []
608+
609+ instance_specific_inference_instance_types : List [str ] = (
610+ self .variants .get (instance_type , {})
611+ .get ("properties" , {})
612+ .get ("supported_inference_instance_types" , [])
613+ )
614+
615+ instance_type_family = get_instance_type_family (instance_type )
616+
617+ instance_family_inference_instance_types : List [str ] = (
618+ self .variants .get (instance_type_family , {})
619+ .get ("properties" , {})
620+ .get ("supported_inference_instance_types" , [])
621+ if instance_type_family not in {"" , None }
622+ else []
623+ )
624+
625+ return sorted (
626+ list (
627+ set (
628+ instance_specific_inference_instance_types
629+ + instance_family_inference_instance_types
630+ )
631+ )
632+ )
633+
584634 def get_image_uri (self , instance_type : str , region : str ) -> Optional [str ]:
585635 """Returns image uri from instance type and region.
586636
@@ -971,6 +1021,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
9711021 "dependencies" ,
9721022 "git_config" ,
9731023 "model_package_arn" ,
1024+ "training_instance_type" ,
9741025 ]
9751026
9761027 SERIALIZATION_EXCLUSION_SET = {
@@ -981,6 +1032,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
9811032 "tolerate_deprecated_model" ,
9821033 "region" ,
9831034 "model_package_arn" ,
1035+ "training_instance_type" ,
9841036 }
9851037
9861038 def __init__ (
@@ -1009,6 +1061,7 @@ def __init__(
10091061 tolerate_vulnerable_model : Optional [bool ] = None ,
10101062 tolerate_deprecated_model : Optional [bool ] = None ,
10111063 model_package_arn : Optional [str ] = None ,
1064+ training_instance_type : Optional [str ] = None ,
10121065 ) -> None :
10131066 """Instantiates JumpStartModelInitKwargs object."""
10141067
@@ -1036,6 +1089,7 @@ def __init__(
10361089 self .tolerate_deprecated_model = tolerate_deprecated_model
10371090 self .tolerate_vulnerable_model = tolerate_vulnerable_model
10381091 self .model_package_arn = model_package_arn
1092+ self .training_instance_type = training_instance_type
10391093
10401094
10411095class JumpStartModelDeployKwargs (JumpStartKwargs ):
@@ -1065,6 +1119,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
10651119 "tolerate_vulnerable_model" ,
10661120 "tolerate_deprecated_model" ,
10671121 "sagemaker_session" ,
1122+ "training_instance_type" ,
10681123 ]
10691124
10701125 SERIALIZATION_EXCLUSION_SET = {
@@ -1074,6 +1129,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
10741129 "tolerate_deprecated_model" ,
10751130 "tolerate_vulnerable_model" ,
10761131 "sagemaker_session" ,
1132+ "training_instance_type" ,
10771133 }
10781134
10791135 def __init__ (
@@ -1101,6 +1157,7 @@ def __init__(
11011157 tolerate_deprecated_model : Optional [bool ] = None ,
11021158 tolerate_vulnerable_model : Optional [bool ] = None ,
11031159 sagemaker_session : Optional [Session ] = None ,
1160+ training_instance_type : Optional [str ] = None ,
11041161 ) -> None :
11051162 """Instantiates JumpStartModelDeployKwargs object."""
11061163
@@ -1127,6 +1184,7 @@ def __init__(
11271184 self .tolerate_vulnerable_model = tolerate_vulnerable_model
11281185 self .tolerate_deprecated_model = tolerate_deprecated_model
11291186 self .sagemaker_session = sagemaker_session
1187+ self .training_instance_type = training_instance_type
11301188
11311189
11321190class JumpStartEstimatorInitKwargs (JumpStartKwargs ):
0 commit comments