Skip to content

Commit 8ae6835

Browse files
committed
fix: jumpstart estimator for gated uncompressed training
1 parent 23f4907 commit 8ae6835

File tree

4 files changed

+67
-7
lines changed

4 files changed

+67
-7
lines changed

src/sagemaker/estimator.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,6 +1310,7 @@ def fit(
13101310
logs: str = "All",
13111311
job_name: Optional[str] = None,
13121312
experiment_config: Optional[Dict[str, str]] = None,
1313+
accept_eula: Optional[bool] = None,
13131314
):
13141315
"""Train a model using the input training dataset.
13151316
@@ -1363,14 +1364,21 @@ def fit(
13631364
* Both `ExperimentName` and `TrialName` will be ignored if the Estimator instance
13641365
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
13651366
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
1367+
accept_eula (bool): For models that require a Model Access Config, specify True or
1368+
False to indicate whether model terms of use have been accepted.
1369+
The `accept_eula` value must be explicitly defined as `True` in order to
1370+
accept the end-user license agreement (EULA) that some
1371+
models require. (Default: None).
13661372
Returns:
13671373
None or pipeline step arguments in case the Estimator instance is built with
13681374
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
13691375
"""
13701376
self._prepare_for_training(job_name=job_name)
13711377

13721378
experiment_config = check_and_get_run_experiment_config(experiment_config)
1373-
self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config)
1379+
self.latest_training_job = _TrainingJob.start_new(
1380+
self, inputs, experiment_config, accept_eula
1381+
)
13741382
self.jobs.append(self.latest_training_job)
13751383
forward_to_mlflow_tracking_server = False
13761384
if os.environ.get("MLFLOW_TRACKING_URI") and self.enable_network_isolation():
@@ -2484,7 +2492,7 @@ class _TrainingJob(_Job):
24842492
"""Placeholder docstring"""
24852493

24862494
@classmethod
2487-
def start_new(cls, estimator, inputs, experiment_config):
2495+
def start_new(cls, estimator, inputs, experiment_config, accept_eula):
24882496
"""Create a new Amazon SageMaker training job from the estimator.
24892497
24902498
Args:
@@ -2504,19 +2512,24 @@ def start_new(cls, estimator, inputs, experiment_config):
25042512
will be unassociated.
25052513
* `TrialComponentDisplayName` is used for display in Studio.
25062514
* `RunName` is used to record an experiment run.
2515+
accept_eula (bool): For models that require a Model Access Config, specify True or
2516+
False to indicate whether model terms of use have been accepted.
2517+
The `accept_eula` value must be explicitly defined as `True` in order to
2518+
accept the end-user license agreement (EULA) that some
2519+
models require. (Default: None).
25072520
Returns:
25082521
sagemaker.estimator._TrainingJob: Constructed object that captures
25092522
all information about the started training job.
25102523
"""
2511-
train_args = cls._get_train_args(estimator, inputs, experiment_config)
2524+
train_args = cls._get_train_args(estimator, inputs, experiment_config, accept_eula)
25122525

25132526
logger.debug("Train args after processing defaults: %s", train_args)
25142527
estimator.sagemaker_session.train(**train_args)
25152528

25162529
return cls(estimator.sagemaker_session, estimator._current_job_name)
25172530

25182531
@classmethod
2519-
def _get_train_args(cls, estimator, inputs, experiment_config):
2532+
def _get_train_args(cls, estimator, inputs, experiment_config, accept_eula):
25202533
"""Constructs a dict of arguments for an Amazon SageMaker training job from the estimator.
25212534
25222535
Args:
@@ -2536,6 +2549,11 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
25362549
will be unassociated.
25372550
* `TrialComponentDisplayName` is used for display in Studio.
25382551
* `RunName` is used to record an experiment run.
2552+
accept_eula (bool): For models that require a Model Access Config, specify True or
2553+
False to indicate whether model terms of use have been accepted.
2554+
The `accept_eula` value must be explicitly defined as `True` in order to
2555+
accept the end-user license agreement (EULA) that some
2556+
models require. (Default: None).
25392557
25402558
Returns:
25412559
Dict: dict for `sagemaker.session.Session.train` method
@@ -2652,6 +2670,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
26522670
if estimator.get_session_chaining_config() is not None:
26532671
train_args["session_chaining_config"] = estimator.get_session_chaining_config()
26542672

2673+
if accept_eula is not None:
2674+
cls._set_accept_eula_for_input_data_config(train_args, accept_eula)
2675+
26552676
return train_args
26562677

26572678
@classmethod
@@ -2674,6 +2695,42 @@ def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):
26742695
raise ValueError("Setting checkpoint_local_path is not supported in local mode.")
26752696
train_args["checkpoint_local_path"] = estimator.checkpoint_local_path
26762697

2698+
@classmethod
2699+
def _set_accept_eula_for_input_data_config(cls, train_args, accept_eula):
2700+
"""Set the AcceptEula flag for all input data configurations.
2701+
2702+
This method sets the AcceptEula flag in the ModelAccessConfig for all S3DataSources
2703+
in the InputDataConfig array. It handles cases where keys might not exist in the
2704+
nested dictionary structure.
2705+
2706+
Args:
2707+
train_args (dict): The training job arguments dictionary
2708+
accept_eula (bool): The value to set for AcceptEula flag
2709+
"""
2710+
if "InputDataConfig" not in train_args:
2711+
return
2712+
2713+
eula_count = 0
2714+
s3_uris = []
2715+
2716+
for idx in range(len(train_args["InputDataConfig"])):
2717+
if "DataSource" in train_args["InputDataConfig"][idx]:
2718+
data_source = train_args["InputDataConfig"][idx]["DataSource"]
2719+
if "S3DataSource" in data_source:
2720+
s3_data_source = data_source["S3DataSource"]
2721+
if "ModelAccessConfig" not in s3_data_source:
2722+
s3_data_source["ModelAccessConfig"] = {}
2723+
s3_data_source["ModelAccessConfig"]["AcceptEula"] = accept_eula
2724+
eula_count += 1
2725+
2726+
# Collect S3 URI if available
2727+
if "S3Uri" in s3_data_source:
2728+
s3_uris.append(s3_data_source["S3Uri"])
2729+
2730+
# Log info if more than one EULA needs to be accepted
2731+
if eula_count > 1:
2732+
logger.info("Accepting EULA for %d S3 data sources: %s", eula_count, ", ".join(s3_uris))
2733+
26772734
@classmethod
26782735
def _is_local_channel(cls, input_uri):
26792736
"""Placeholder docstring"""

src/sagemaker/jumpstart/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,7 @@ def fit(
713713
sagemaker_session=self.sagemaker_session,
714714
config_name=self.config_name,
715715
hub_access_config=self.hub_access_config,
716+
accept_eula=accept_eula,
716717
)
717718
remove_env_var_from_estimator_kwargs_if_model_access_config_present(
718719
self.init_kwargs, self.model_access_config

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def get_fit_kwargs(
266266
sagemaker_session: Optional[Session] = None,
267267
config_name: Optional[str] = None,
268268
hub_access_config: Optional[Dict] = None,
269+
accept_eula: Optional[bool] = None,
269270
) -> JumpStartEstimatorFitKwargs:
270271
"""Returns kwargs required to call `fit` on `sagemaker.estimator.Estimator` object."""
271272

@@ -283,6 +284,7 @@ def get_fit_kwargs(
283284
tolerate_vulnerable_model=tolerate_vulnerable_model,
284285
sagemaker_session=sagemaker_session,
285286
config_name=config_name,
287+
accept_eula=accept_eula,
286288
)
287289

288290
estimator_fit_kwargs, _ = _set_temp_sagemaker_session_if_not_set(kwargs=estimator_fit_kwargs)

src/sagemaker/jumpstart/types.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1940,9 +1940,6 @@ def use_inference_script_uri(self) -> bool:
19401940

19411941
def use_training_model_artifact(self) -> bool:
19421942
"""Returns True if the model should use a model uri when kicking off training job."""
1943-
# gated model never use training model artifact
1944-
if self.gated_bucket:
1945-
return False
19461943

19471944
# otherwise, return true if a training model package is not set
19481945
return len(self.training_model_package_artifact_uris or {}) == 0
@@ -2595,6 +2592,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
25952592
"sagemaker_session",
25962593
"config_name",
25972594
"specs",
2595+
"accept_eula",
25982596
]
25992597

26002598
SERIALIZATION_EXCLUSION_SET = {
@@ -2625,6 +2623,7 @@ def __init__(
26252623
tolerate_vulnerable_model: Optional[bool] = None,
26262624
sagemaker_session: Optional[Session] = None,
26272625
config_name: Optional[str] = None,
2626+
accept_eula: Optional[bool] = None,
26282627
) -> None:
26292628
"""Instantiates JumpStartEstimatorFitKwargs object."""
26302629

@@ -2642,6 +2641,7 @@ def __init__(
26422641
self.tolerate_vulnerable_model = tolerate_vulnerable_model
26432642
self.sagemaker_session = sagemaker_session
26442643
self.config_name = config_name
2644+
self.accept_eula = accept_eula
26452645

26462646

26472647
class JumpStartEstimatorDeployKwargs(JumpStartKwargs):

0 commit comments

Comments
 (0)