@@ -1632,16 +1632,16 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
16321632 return neo_bucket
16331633
16341634
1635- def remove_env_var_from_estimator_kwargs_if_accept_eula_present (
1636- init_kwargs : dict , accept_eula : Optional [ bool ]
1635+ def remove_env_var_from_estimator_kwargs_if_model_access_config_present (
1636+ init_kwargs : dict , model_access_config : dict | None
16371637):
1638- """Remove env vars if access configs are used
1638+ """Remove env vars if ModelAccessConfig is used
16391639
16401640 Args:
16411641 init_kwargs (dict): Dictionary of kwargs when Estimator is instantiated.
16421642 accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit().
16431643 """
1644- if accept_eula is not None and init_kwargs .get ("environment" ) is not None :
1644+ if model_access_config is not None and init_kwargs .get ("environment" ) is not None :
16451645 if (
16461646 constants .SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY
16471647 in init_kwargs ["environment" ]
@@ -1672,7 +1672,7 @@ def get_model_access_config(accept_eula: Optional[bool], environment: Optional[d
16721672 accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit().
16731673 """
16741674 env_var_eula = environment .get ("accept_eula" ) if environment else None
1675- if env_var_eula and accept_eula is not None :
1675+ if env_var_eula is not None and accept_eula is not None :
16761676 raise ValueError (
16771677 "Cannot pass in both accept_eula and environment variables. "
16781678 "Please remove the environment variable and pass in the accept_eula parameter."
0 commit comments