diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 2ed2deb803..9ebc2880bc 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -25,6 +25,7 @@ from sagemaker.jumpstart.hub.utils import ( construct_hub_model_arn_from_inputs, construct_hub_model_reference_arn_from_inputs, + generate_hub_arn_for_init_kwargs, ) from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.session import Session @@ -291,6 +292,10 @@ def get_model_specs( # Users only input model id, not contentType, so first try to describe with ModelReference, then with Model if hub_arn: try: + hub_arn = generate_hub_arn_for_init_kwargs( + hub_name=hub_arn, region=region, session=sagemaker_session + ) + hub_model_arn = construct_hub_model_reference_arn_from_inputs( hub_arn=hub_arn, model_name=model_id, version=version ) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index af2fb5bc54..4daf9b1810 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -41,7 +41,7 @@ validate_model_id_and_get_type, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, - remove_env_var_from_estimator_kwargs_if_accept_eula_present, + remove_env_var_from_estimator_kwargs_if_model_access_config_present, get_model_access_config, get_hub_access_config, ) @@ -616,6 +616,7 @@ def _validate_model_id_and_get_type_hook(): self.tolerate_vulnerable_model = estimator_init_kwargs.tolerate_vulnerable_model self.instance_count = estimator_init_kwargs.instance_count self.region = estimator_init_kwargs.region + self.environment = estimator_init_kwargs.environment self.orig_predictor_cls = None self.role = estimator_init_kwargs.role self.sagemaker_session = estimator_init_kwargs.sagemaker_session @@ -693,7 +694,7 @@ def fit( accept the end-user license agreement (EULA) that some models require. (Default: None). """ - self.model_access_config = get_model_access_config(accept_eula) + self.model_access_config = get_model_access_config(accept_eula, self.environment) self.hub_access_config = get_hub_access_config( hub_content_arn=self.init_kwargs.get("model_reference_arn", None) ) @@ -713,7 +714,9 @@ def fit( config_name=self.config_name, hub_access_config=self.hub_access_config, ) - remove_env_var_from_estimator_kwargs_if_accept_eula_present(self.init_kwargs, accept_eula) + remove_env_var_from_estimator_kwargs_if_model_access_config_present( + self.init_kwargs, self.model_access_config + ) return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict()) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index bd81226727..15f9e9b52e 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1632,17 +1632,29 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str: return neo_bucket -def remove_env_var_from_estimator_kwargs_if_accept_eula_present( - init_kwargs: dict, accept_eula: Optional[bool] +def remove_env_var_from_estimator_kwargs_if_model_access_config_present( + init_kwargs: dict, model_access_config: Optional[dict] ): - """Remove env vars if access configs are used + """Remove env vars if ModelAccessConfig is used Args: init_kwargs (dict): Dictionary of kwargs when Estimator is instantiated. accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit(). """ - if accept_eula is not None and init_kwargs["environment"]: - del init_kwargs["environment"][constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY] + if ( + model_access_config is not None + and init_kwargs.get("environment") is not None + and init_kwargs.get("model_uri") is not None + ): + if ( + constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY + in init_kwargs["environment"] + ): + del init_kwargs["environment"][ + constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY + ] + if "accept_eula" in init_kwargs["environment"]: + del init_kwargs["environment"]["accept_eula"] def get_hub_access_config(hub_content_arn: Optional[str]): @@ -1659,16 +1671,24 @@ def get_hub_access_config(hub_content_arn: Optional[str]): return hub_access_config -def get_model_access_config(accept_eula: Optional[bool]): +def get_model_access_config(accept_eula: Optional[bool], environment: Optional[dict]): """Get access configs Args: accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit(). """ + env_var_eula = environment.get("accept_eula") if environment else None + if env_var_eula is not None and accept_eula is not None: + raise ValueError( + "Cannot pass in both accept_eula and environment variables. " + "Please remove the environment variable and pass in the accept_eula parameter." + ) + + model_access_config = None + if env_var_eula is not None: + model_access_config = {"AcceptEula": env_var_eula == "true"} if accept_eula is not None: model_access_config = {"AcceptEula": accept_eula} - else: - model_access_config = None return model_access_config