@@ -1310,6 +1310,7 @@ def fit(
13101310 logs : str = "All" ,
13111311 job_name : Optional [str ] = None ,
13121312 experiment_config : Optional [Dict [str , str ]] = None ,
1313+ accept_eula : Optional [bool ] = None ,
13131314 ):
13141315 """Train a model using the input training dataset.
13151316
@@ -1363,14 +1364,21 @@ def fit(
13631364 * Both `ExperimentName` and `TrialName` will be ignored if the Estimator instance
13641365 is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
13651366 However, the value of `TrialComponentDisplayName` is honored for display in Studio.
1367+ accept_eula (bool): For models that require a Model Access Config, specify True or
1368+ False to indicate whether model terms of use have been accepted.
1369+ The `accept_eula` value must be explicitly defined as `True` in order to
1370+ accept the end-user license agreement (EULA) that some
1371+ models require. (Default: None).
13661372 Returns:
13671373 None or pipeline step arguments in case the Estimator instance is built with
13681374 :class:`~sagemaker.workflow.pipeline_context.PipelineSession`
13691375 """
13701376 self ._prepare_for_training (job_name = job_name )
13711377
13721378 experiment_config = check_and_get_run_experiment_config (experiment_config )
1373- self .latest_training_job = _TrainingJob .start_new (self , inputs , experiment_config )
1379+ self .latest_training_job = _TrainingJob .start_new (
1380+ self , inputs , experiment_config , accept_eula
1381+ )
13741382 self .jobs .append (self .latest_training_job )
13751383 forward_to_mlflow_tracking_server = False
13761384 if os .environ .get ("MLFLOW_TRACKING_URI" ) and self .enable_network_isolation ():
@@ -2484,7 +2492,7 @@ class _TrainingJob(_Job):
24842492 """Placeholder docstring"""
24852493
24862494 @classmethod
2487- def start_new (cls , estimator , inputs , experiment_config ):
2495+ def start_new (cls , estimator , inputs , experiment_config , accept_eula ):
24882496 """Create a new Amazon SageMaker training job from the estimator.
24892497
24902498 Args:
@@ -2504,19 +2512,24 @@ def start_new(cls, estimator, inputs, experiment_config):
25042512 will be unassociated.
25052513 * `TrialComponentDisplayName` is used for display in Studio.
25062514 * `RunName` is used to record an experiment run.
2515+ accept_eula (bool): For models that require a Model Access Config, specify True or
2516+ False to indicate whether model terms of use have been accepted.
2517+ The `accept_eula` value must be explicitly defined as `True` in order to
2518+ accept the end-user license agreement (EULA) that some
2519+ models require. (Default: None).
25072520 Returns:
25082521 sagemaker.estimator._TrainingJob: Constructed object that captures
25092522 all information about the started training job.
25102523 """
2511- train_args = cls ._get_train_args (estimator , inputs , experiment_config )
2524+ train_args = cls ._get_train_args (estimator , inputs , experiment_config , accept_eula )
25122525
25132526 logger .debug ("Train args after processing defaults: %s" , train_args )
25142527 estimator .sagemaker_session .train (** train_args )
25152528
25162529 return cls (estimator .sagemaker_session , estimator ._current_job_name )
25172530
25182531 @classmethod
2519- def _get_train_args (cls , estimator , inputs , experiment_config ):
2532+ def _get_train_args (cls , estimator , inputs , experiment_config , accept_eula ):
25202533 """Constructs a dict of arguments for an Amazon SageMaker training job from the estimator.
25212534
25222535 Args:
@@ -2536,6 +2549,11 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
25362549 will be unassociated.
25372550 * `TrialComponentDisplayName` is used for display in Studio.
25382551 * `RunName` is used to record an experiment run.
2552+ accept_eula (bool): For models that require a Model Access Config, specify True or
2553+ False to indicate whether model terms of use have been accepted.
2554+ The `accept_eula` value must be explicitly defined as `True` in order to
2555+ accept the end-user license agreement (EULA) that some
2556+ models require. (Default: None).
25392557
25402558 Returns:
25412559 Dict: dict for `sagemaker.session.Session.train` method
@@ -2652,6 +2670,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
26522670 if estimator .get_session_chaining_config () is not None :
26532671 train_args ["session_chaining_config" ] = estimator .get_session_chaining_config ()
26542672
2673+ if accept_eula is not None :
2674+ cls ._set_accept_eula_for_input_data_config (train_args , accept_eula )
2675+
26552676 return train_args
26562677
26572678 @classmethod
@@ -2674,6 +2695,42 @@ def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):
26742695 raise ValueError ("Setting checkpoint_local_path is not supported in local mode." )
26752696 train_args ["checkpoint_local_path" ] = estimator .checkpoint_local_path
26762697
2698+ @classmethod
2699+ def _set_accept_eula_for_input_data_config (cls , train_args , accept_eula ):
2700+ """Set the AcceptEula flag for all input data configurations.
2701+
2702+ This method sets the AcceptEula flag in the ModelAccessConfig for all S3DataSources
2703+ in the InputDataConfig array. It handles cases where keys might not exist in the
2704+ nested dictionary structure.
2705+
2706+ Args:
2707+ train_args (dict): The training job arguments dictionary
2708+ accept_eula (bool): The value to set for AcceptEula flag
2709+ """
2710+ if "InputDataConfig" not in train_args :
2711+ return
2712+
2713+ eula_count = 0
2714+ s3_uris = []
2715+
2716+ for idx in range (len (train_args ["InputDataConfig" ])):
2717+ if "DataSource" in train_args ["InputDataConfig" ][idx ]:
2718+ data_source = train_args ["InputDataConfig" ][idx ]["DataSource" ]
2719+ if "S3DataSource" in data_source :
2720+ s3_data_source = data_source ["S3DataSource" ]
2721+ if "ModelAccessConfig" not in s3_data_source :
2722+ s3_data_source ["ModelAccessConfig" ] = {}
2723+ s3_data_source ["ModelAccessConfig" ]["AcceptEula" ] = accept_eula
2724+ eula_count += 1
2725+
2726+ # Collect S3 URI if available
2727+ if "S3Uri" in s3_data_source :
2728+ s3_uris .append (s3_data_source ["S3Uri" ])
2729+
2730+ # Log info if more than one EULA needs to be accepted
2731+ if eula_count > 1 :
2732+ logger .info ("Accepting EULA for %d S3 data sources: %s" , eula_count , ", " .join (s3_uris ))
2733+
26772734 @classmethod
26782735 def _is_local_channel (cls , input_uri ):
26792736 """Placeholder docstring"""
0 commit comments