@@ -1632,17 +1632,29 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
16321632 return neo_bucket
16331633
16341634
1635- def remove_env_var_from_estimator_kwargs_if_accept_eula_present (
1636- init_kwargs : dict , accept_eula : Optional [bool ]
1635+ def remove_env_var_from_estimator_kwargs_if_model_access_config_present (
1636+ init_kwargs : dict , model_access_config : Optional [dict ]
16371637):
1638- """Remove env vars if access configs are used
1638+ """Remove env vars if ModelAccessConfig is used
16391639
16401640 Args:
16411641 init_kwargs (dict): Dictionary of kwargs when Estimator is instantiated.
16421642 accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit().
16431643 """
1644- if accept_eula is not None and init_kwargs ["environment" ]:
1645- del init_kwargs ["environment" ][constants .SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY ]
1644+ if (
1645+ model_access_config is not None
1646+ and init_kwargs .get ("environment" ) is not None
1647+ and init_kwargs .get ("model_uri" ) is not None
1648+ ):
1649+ if (
1650+ constants .SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY
1651+ in init_kwargs ["environment" ]
1652+ ):
1653+ del init_kwargs ["environment" ][
1654+ constants .SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY
1655+ ]
1656+ if "accept_eula" in init_kwargs ["environment" ]:
1657+ del init_kwargs ["environment" ]["accept_eula" ]
16461658
16471659
16481660def get_hub_access_config (hub_content_arn : Optional [str ]):
@@ -1659,16 +1671,24 @@ def get_hub_access_config(hub_content_arn: Optional[str]):
16591671 return hub_access_config
16601672
16611673
1662- def get_model_access_config (accept_eula : Optional [bool ]):
1674+ def get_model_access_config (accept_eula : Optional [bool ], environment : Optional [ dict ] ):
16631675 """Get access configs
16641676
16651677 Args:
16661678 accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit().
16671679 """
1680+ env_var_eula = environment .get ("accept_eula" ) if environment else None
1681+ if env_var_eula is not None and accept_eula is not None :
1682+ raise ValueError (
1683+ "Cannot pass in both accept_eula and environment variables. "
1684+ "Please remove the environment variable and pass in the accept_eula parameter."
1685+ )
1686+
1687+ model_access_config = None
1688+ if env_var_eula is not None :
1689+ model_access_config = {"AcceptEula" : env_var_eula == "true" }
16681690 if accept_eula is not None :
16691691 model_access_config = {"AcceptEula" : accept_eula }
1670- else :
1671- model_access_config = None
16721692
16731693 return model_access_config
16741694
0 commit comments