1919
2020from sagemaker .jumpstart import utils as jumpstart_utils
2121from sagemaker .jumpstart import artifacts
22+ from sagemaker .jumpstart .constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
2223from sagemaker .jumpstart .enums import HyperparameterValidationMode
2324from sagemaker .jumpstart .validators import validate_hyperparameters
25+ from sagemaker .session import Session
2426
2527logger = logging .getLogger (__name__ )
2628
@@ -32,6 +34,7 @@ def retrieve_default(
3234 include_container_hyperparameters : bool = False ,
3335 tolerate_vulnerable_model : bool = False ,
3436 tolerate_deprecated_model : bool = False ,
37+ sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
3538) -> Dict [str , str ]:
3639 """Retrieves the default training hyperparameters for the model matching the given arguments.
3740
@@ -56,6 +59,10 @@ def retrieve_default(
5659 tolerate_deprecated_model (bool): True if deprecated models should be tolerated
5760 (exception not raised). False if these models should raise an exception.
5861 (Default: False).
62+ sagemaker_session (sagemaker.session.Session): A SageMaker Session
63+ object, used for SageMaker interactions. If not
64+ specified, one is created using the default AWS configuration
65+ chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
5966 Returns:
6067 dict: The hyperparameters to use for the model.
6168
@@ -74,6 +81,7 @@ def retrieve_default(
7481 include_container_hyperparameters ,
7582 tolerate_vulnerable_model ,
7683 tolerate_deprecated_model ,
84+ sagemaker_session = sagemaker_session ,
7785 )
7886
7987
@@ -83,6 +91,9 @@ def validate(
8391 model_version : Optional [str ] = None ,
8492 hyperparameters : Optional [dict ] = None ,
8593 validation_mode : HyperparameterValidationMode = HyperparameterValidationMode .VALIDATE_PROVIDED ,
94+ tolerate_vulnerable_model : bool = False ,
95+ tolerate_deprecated_model : bool = False ,
96+ sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
8697) -> None :
8798 """Validates hyperparameters for models.
8899
@@ -100,6 +111,17 @@ def validate(
100111 If set to``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated.
101112 If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated.
102113 (Default: None).
114+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
115+ specifications should be tolerated (exception not raised). If False, raises an
116+ exception if the script used by this version of the model has dependencies with known
117+ security vulnerabilities. (Default: False).
118+ tolerate_deprecated_model (bool): True if deprecated models should be tolerated
119+ (exception not raised). False if these models should raise an exception.
120+ (Default: False).
121+ sagemaker_session (sagemaker.session.Session): A SageMaker Session
122+ object, used for SageMaker interactions. If not
123+ specified, one is created using the default AWS configuration
124+ chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
103125
104126 Raises:
105127 JumpStartHyperparametersError: If the hyperparameter is not formatted correctly,
@@ -125,4 +147,7 @@ def validate(
125147 hyperparameters = hyperparameters ,
126148 validation_mode = validation_mode ,
127149 region = region ,
150+ tolerate_vulnerable_model = tolerate_vulnerable_model ,
151+ tolerate_deprecated_model = tolerate_deprecated_model ,
152+ sagemaker_session = sagemaker_session ,
128153 )
0 commit comments