@@ -1310,6 +1310,7 @@ def fit(
1310
1310
logs : str = "All" ,
1311
1311
job_name : Optional [str ] = None ,
1312
1312
experiment_config : Optional [Dict [str , str ]] = None ,
1313
+ accept_eula : Optional [bool ] = None ,
1313
1314
):
1314
1315
"""Train a model using the input training dataset.
1315
1316
@@ -1363,14 +1364,21 @@ def fit(
1363
1364
* Both `ExperimentName` and `TrialName` will be ignored if the Estimator instance
1364
1365
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
1365
1366
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
1367
+ accept_eula (bool): For models that require a Model Access Config, specify True or
1368
+ False to indicate whether model terms of use have been accepted.
1369
+ The `accept_eula` value must be explicitly defined as `True` in order to
1370
+ accept the end-user license agreement (EULA) that some
1371
+ models require. (Default: None).
1366
1372
Returns:
1367
1373
None or pipeline step arguments in case the Estimator instance is built with
1368
1374
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
1369
1375
"""
1370
1376
self ._prepare_for_training (job_name = job_name )
1371
1377
1372
1378
experiment_config = check_and_get_run_experiment_config (experiment_config )
1373
- self .latest_training_job = _TrainingJob .start_new (self , inputs , experiment_config )
1379
+ self .latest_training_job = _TrainingJob .start_new (
1380
+ self , inputs , experiment_config , accept_eula
1381
+ )
1374
1382
self .jobs .append (self .latest_training_job )
1375
1383
forward_to_mlflow_tracking_server = False
1376
1384
if os .environ .get ("MLFLOW_TRACKING_URI" ) and self .enable_network_isolation ():
@@ -2484,7 +2492,7 @@ class _TrainingJob(_Job):
2484
2492
"""Placeholder docstring"""
2485
2493
2486
2494
@classmethod
2487
- def start_new (cls , estimator , inputs , experiment_config ):
2495
+ def start_new (cls , estimator , inputs , experiment_config , accept_eula ):
2488
2496
"""Create a new Amazon SageMaker training job from the estimator.
2489
2497
2490
2498
Args:
@@ -2504,19 +2512,24 @@ def start_new(cls, estimator, inputs, experiment_config):
2504
2512
will be unassociated.
2505
2513
* `TrialComponentDisplayName` is used for display in Studio.
2506
2514
* `RunName` is used to record an experiment run.
2515
+ accept_eula (bool): For models that require a Model Access Config, specify True or
2516
+ False to indicate whether model terms of use have been accepted.
2517
+ The `accept_eula` value must be explicitly defined as `True` in order to
2518
+ accept the end-user license agreement (EULA) that some
2519
+ models require. (Default: None).
2507
2520
Returns:
2508
2521
sagemaker.estimator._TrainingJob: Constructed object that captures
2509
2522
all information about the started training job.
2510
2523
"""
2511
- train_args = cls ._get_train_args (estimator , inputs , experiment_config )
2524
+ train_args = cls ._get_train_args (estimator , inputs , experiment_config , accept_eula )
2512
2525
2513
2526
logger .debug ("Train args after processing defaults: %s" , train_args )
2514
2527
estimator .sagemaker_session .train (** train_args )
2515
2528
2516
2529
return cls (estimator .sagemaker_session , estimator ._current_job_name )
2517
2530
2518
2531
@classmethod
2519
- def _get_train_args (cls , estimator , inputs , experiment_config ):
2532
+ def _get_train_args (cls , estimator , inputs , experiment_config , accept_eula ):
2520
2533
"""Constructs a dict of arguments for an Amazon SageMaker training job from the estimator.
2521
2534
2522
2535
Args:
@@ -2536,6 +2549,11 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
2536
2549
will be unassociated.
2537
2550
* `TrialComponentDisplayName` is used for display in Studio.
2538
2551
* `RunName` is used to record an experiment run.
2552
+ accept_eula (bool): For models that require a Model Access Config, specify True or
2553
+ False to indicate whether model terms of use have been accepted.
2554
+ The `accept_eula` value must be explicitly defined as `True` in order to
2555
+ accept the end-user license agreement (EULA) that some
2556
+ models require. (Default: None).
2539
2557
2540
2558
Returns:
2541
2559
Dict: dict for `sagemaker.session.Session.train` method
@@ -2652,6 +2670,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
2652
2670
if estimator .get_session_chaining_config () is not None :
2653
2671
train_args ["session_chaining_config" ] = estimator .get_session_chaining_config ()
2654
2672
2673
+ if accept_eula is not None :
2674
+ cls ._set_accept_eula_for_input_data_config (train_args , accept_eula )
2675
+
2655
2676
return train_args
2656
2677
2657
2678
@classmethod
@@ -2674,6 +2695,42 @@ def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):
2674
2695
raise ValueError ("Setting checkpoint_local_path is not supported in local mode." )
2675
2696
train_args ["checkpoint_local_path" ] = estimator .checkpoint_local_path
2676
2697
2698
+ @classmethod
2699
+ def _set_accept_eula_for_input_data_config (cls , train_args , accept_eula ):
2700
+ """Set the AcceptEula flag for all input data configurations.
2701
+
2702
+ This method sets the AcceptEula flag in the ModelAccessConfig for all S3DataSources
2703
+ in the InputDataConfig array. It handles cases where keys might not exist in the
2704
+ nested dictionary structure.
2705
+
2706
+ Args:
2707
+ train_args (dict): The training job arguments dictionary
2708
+ accept_eula (bool): The value to set for AcceptEula flag
2709
+ """
2710
+ if "InputDataConfig" not in train_args :
2711
+ return
2712
+
2713
+ eula_count = 0
2714
+ s3_uris = []
2715
+
2716
+ for idx in range (len (train_args ["InputDataConfig" ])):
2717
+ if "DataSource" in train_args ["InputDataConfig" ][idx ]:
2718
+ data_source = train_args ["InputDataConfig" ][idx ]["DataSource" ]
2719
+ if "S3DataSource" in data_source :
2720
+ s3_data_source = data_source ["S3DataSource" ]
2721
+ if "ModelAccessConfig" not in s3_data_source :
2722
+ s3_data_source ["ModelAccessConfig" ] = {}
2723
+ s3_data_source ["ModelAccessConfig" ]["AcceptEula" ] = accept_eula
2724
+ eula_count += 1
2725
+
2726
+ # Collect S3 URI if available
2727
+ if "S3Uri" in s3_data_source :
2728
+ s3_uris .append (s3_data_source ["S3Uri" ])
2729
+
2730
+ # Log info if more than one EULA needs to be accepted
2731
+ if eula_count > 1 :
2732
+ logger .info ("Accepting EULA for %d S3 data sources: %s" , eula_count , ", " .join (s3_uris ))
2733
+
2677
2734
@classmethod
2678
2735
def _is_local_channel (cls , input_uri ):
2679
2736
"""Placeholder docstring"""
0 commit comments