1919 SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY ,
2020)
2121from sagemaker .jumpstart .enums import (
22+ JumpStartModelType ,
2223 JumpStartScriptScope ,
2324)
2425from sagemaker .jumpstart .utils import (
@@ -41,6 +42,7 @@ def _retrieve_default_environment_variables(
4142 instance_type : Optional [str ] = None ,
4243 script : JumpStartScriptScope = JumpStartScriptScope .INFERENCE ,
4344 config_name : Optional [str ] = None ,
45+ model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS ,
4446) -> Dict [str , str ]:
4547 """Retrieves the inference environment variables for the model matching the given arguments.
4648
@@ -73,6 +75,8 @@ def _retrieve_default_environment_variables(
7375 script (JumpStartScriptScope): The JumpStart script for which to retrieve
7476 environment variables.
7577 config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
78+ model_type (JumpStartModelType): The type of the model, can be open weights model
79+ or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
7680 Returns:
7781 dict: the inference environment variables to use for the model.
7882 """
@@ -91,6 +95,7 @@ def _retrieve_default_environment_variables(
9195 tolerate_deprecated_model = tolerate_deprecated_model ,
9296 sagemaker_session = sagemaker_session ,
9397 config_name = config_name ,
98+ model_type = model_type ,
9499 )
95100
96101 default_environment_variables : Dict [str , str ] = {}
@@ -130,6 +135,7 @@ def _retrieve_default_environment_variables(
130135 sagemaker_session = sagemaker_session ,
131136 instance_type = instance_type ,
132137 config_name = config_name ,
138+ model_type = model_type ,
133139 )
134140 )
135141
@@ -178,6 +184,7 @@ def _retrieve_gated_model_uri_env_var_value(
178184 sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
179185 instance_type : Optional [str ] = None ,
180186 config_name : Optional [str ] = None ,
187+ model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS ,
181188) -> Optional [str ]:
182189 """Retrieves the gated model env var URI matching the given arguments.
183190
@@ -204,7 +211,8 @@ def _retrieve_gated_model_uri_env_var_value(
204211 instance_type (str): An instance type to optionally supply in order to get
205212 environment variables specific for the instance type.
206213 config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
207-
214+ model_type (JumpStartModelType): The type of the model, can be open weights model
215+ or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
208216 Returns:
209217 Optional[str]: the s3 URI to use for the environment variable, or None if the model does not
210218 have gated training artifacts.
@@ -227,6 +235,7 @@ def _retrieve_gated_model_uri_env_var_value(
227235 tolerate_deprecated_model = tolerate_deprecated_model ,
228236 sagemaker_session = sagemaker_session ,
229237 config_name = config_name ,
238+ model_type = model_type ,
230239 )
231240
232241 s3_key : Optional [str ] = (
0 commit comments