Skip to content

Commit 7a5b79d

Browse files
HollowTube and Tritin Truong
authored and committed
feature: Support SageMakerTrainingPlan for training jobs (#1544)
Co-authored-by: Tritin Truong <[email protected]>
1 parent ae3e1da commit 7a5b79d

File tree

8 files changed

+159
-7
lines changed

8 files changed

+159
-7
lines changed

src/sagemaker/estimator.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def __init__(
185185
disable_output_compression: bool = False,
186186
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
187187
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
188+
training_plan: Optional[Union[str, PipelineVariable]] = None,
188189
**kwargs,
189190
):
190191
"""Initialize an ``EstimatorBase`` instance.
@@ -554,6 +555,8 @@ def __init__(
554555
Specifies whether RemoteDebug is enabled for the training job.
555556
enable_session_tag_chaining (bool or PipelineVariable): Optional.
556557
Specifies whether SessionTagChaining is enabled for the training job.
558+
training_plan (str or PipelineVariable): Optional.
559+
Specifies the ARN of the training plan to use for the training job.
557560
"""
558561
instance_count = renamed_kwargs(
559562
"train_instance_count", "instance_count", instance_count, kwargs
@@ -762,8 +765,7 @@ def __init__(
762765

763766
self.tensorboard_output_config = tensorboard_output_config
764767

765-
self.debugger_rule_configs = None
766-
self.collection_configs = None
768+
self.debugger_rule_configs, self.collection_configs = None, None
767769

768770
self.enable_sagemaker_metrics = enable_sagemaker_metrics
769771

@@ -774,6 +776,7 @@ def __init__(
774776
sagemaker_session=self.sagemaker_session,
775777
)
776778

779+
self.profiler_rule_configs, self.profiler_rules = None, None
777780
self.profiler_config = profiler_config
778781
self.disable_profiler = resolve_value_from_config(
779782
direct_input=disable_profiler,
@@ -796,8 +799,6 @@ def __init__(
796799
) or _instance_type_supports_profiler(self.instance_type):
797800
self.disable_profiler = True
798801

799-
self.profiler_rule_configs = None
800-
self.profiler_rules = None
801802
self.debugger_rules = None
802803
self.disable_output_compression = disable_output_compression
803804
validate_source_code_input_against_pipeline_variables(
@@ -807,6 +808,8 @@ def __init__(
807808
enable_network_isolation=self._enable_network_isolation,
808809
)
809810

811+
self.training_plan = training_plan
812+
810813
# Internal flag
811814
self._is_output_path_set_from_default_bucket_and_prefix = False
812815

@@ -1960,6 +1963,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
19601963
"KeepAlivePeriodInSeconds"
19611964
]
19621965

1966+
if "TrainingPlanArn" in job_details["ResourceConfig"]:
1967+
init_params["training_plan"] = job_details["ResourceConfig"]["TrainingPlanArn"]
1968+
19631969
has_hps = "HyperParameters" in job_details
19641970
init_params["hyperparameters"] = job_details["HyperParameters"] if has_hps else {}
19651971

@@ -2840,6 +2846,7 @@ def __init__(
28402846
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
28412847
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
28422848
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
2849+
training_plan: Optional[Union[str, PipelineVariable]] = None,
28432850
**kwargs,
28442851
):
28452852
"""Initialize an ``Estimator`` instance.
@@ -3205,6 +3212,8 @@ def __init__(
32053212
Specifies whether RemoteDebug is enabled for the training job
32063213
enable_session_tag_chaining (bool or PipelineVariable): Optional.
32073214
Specifies whether SessionTagChaining is enabled for the training job
3215+
training_plan (str or PipelineVariable): Optional.
3216+
Specifies the ARN of the training plan to use for the training job.
32083217
"""
32093218
self.image_uri = image_uri
32103219
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
@@ -3258,6 +3267,7 @@ def __init__(
32583267
disable_output_compression=disable_output_compression,
32593268
enable_remote_debug=enable_remote_debug,
32603269
enable_session_tag_chaining=enable_session_tag_chaining,
3270+
training_plan=training_plan,
32613271
**kwargs,
32623272
)
32633273

src/sagemaker/job.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
8383
estimator.volume_size,
8484
estimator.volume_kms_key,
8585
estimator.keep_alive_period_in_seconds,
86+
estimator.training_plan,
8687
)
8788
stop_condition = _Job._prepare_stop_condition(estimator.max_run, estimator.max_wait)
8889
vpc_config = estimator.get_vpc_config()
@@ -294,6 +295,7 @@ def _prepare_resource_config(
294295
volume_size,
295296
volume_kms_key,
296297
keep_alive_period_in_seconds,
298+
training_plan,
297299
):
298300
"""Placeholder docstring"""
299301
resource_config = {
@@ -319,6 +321,8 @@ def _prepare_resource_config(
319321
)
320322
resource_config["InstanceCount"] = instance_count
321323
resource_config["InstanceType"] = instance_type
324+
if training_plan is not None:
325+
resource_config["TrainingPlanArn"] = training_plan
322326

323327
return resource_config
324328

src/sagemaker/jumpstart/estimator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def __init__(
115115
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
116116
config_name: Optional[str] = None,
117117
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
118+
training_plan: Optional[Union[str, PipelineVariable]] = None,
118119
):
119120
"""Initializes a ``JumpStartEstimator``.
120121
@@ -511,6 +512,8 @@ def __init__(
511512
Name of the training configuration to apply to the Estimator. (Default: None).
512513
enable_session_tag_chaining (bool or PipelineVariable): Optional.
513514
Specifies whether SessionTagChaining is enabled for the training job
515+
training_plan (str or PipelineVariable): Optional.
516+
Specifies the ARN of the training plan to use for the training job.
514517
515518
Raises:
516519
ValueError: If the model ID is not recognized by JumpStart.
@@ -599,6 +602,7 @@ def _validate_model_id_and_get_type_hook():
599602
enable_remote_debug=enable_remote_debug,
600603
config_name=config_name,
601604
enable_session_tag_chaining=enable_session_tag_chaining,
605+
training_plan=training_plan,
602606
)
603607

604608
self.hub_arn = estimator_init_kwargs.hub_arn

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def get_init_kwargs(
144144
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
145145
config_name: Optional[str] = None,
146146
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
147+
training_plan: Optional[Union[str, PipelineVariable]] = None,
147148
) -> JumpStartEstimatorInitKwargs:
148149
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""
149150

@@ -205,6 +206,7 @@ def get_init_kwargs(
205206
enable_remote_debug=enable_remote_debug,
206207
config_name=config_name,
207208
enable_session_tag_chaining=enable_session_tag_chaining,
209+
training_plan=training_plan,
208210
)
209211

210212
estimator_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(

src/sagemaker/jumpstart/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2406,6 +2406,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
24062406
"hub_content_type",
24072407
"model_reference_arn",
24082408
"specs",
2409+
"training_plan",
24092410
]
24102411

24112412
SERIALIZATION_EXCLUSION_SET = {
@@ -2479,6 +2480,7 @@ def __init__(
24792480
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
24802481
config_name: Optional[str] = None,
24812482
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
2483+
training_plan: Optional[Union[str, PipelineVariable]] = None,
24822484
) -> None:
24832485
"""Instantiates JumpStartEstimatorInitKwargs object."""
24842486

@@ -2541,6 +2543,7 @@ def __init__(
25412543
self.enable_remote_debug = enable_remote_debug
25422544
self.config_name = config_name
25432545
self.enable_session_tag_chaining = enable_session_tag_chaining
2546+
self.training_plan = training_plan
25442547

25452548

25462549
class JumpStartEstimatorFitKwargs(JumpStartKwargs):

src/sagemaker/session.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2470,6 +2470,75 @@ def describe_training_job(self, job_name):
24702470
"""
24712471
return self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
24722472

2473+
def describe_training_plan(self, training_plan_name):
    """Fetch the description of a training plan via the DescribeTrainingPlan API.

    Args:
        training_plan_name (str): Name of the training plan to look up.

    Returns:
        dict: The response returned by the SageMaker client for this
            training plan, passed through unchanged.
    """
    client = self.sagemaker_client
    description = client.describe_training_plan(TrainingPlanName=training_plan_name)
    return description
2483+
2484+
def list_training_plans(
    self,
    filters=None,
    requested_start_time_after=None,
    requested_start_time_before=None,
    start_time_after=None,
    start_time_before=None,
    sort_order=None,
    sort_by=None,
    max_results=None,
    next_token=None,
):
    """Calls the ListTrainingPlans API with the given filters and returns the response.

    Only arguments that are not ``None`` are forwarded to the API; every
    other value (including falsy ones such as ``0`` or an empty list) is
    sent as given. This method issues a single request and does not follow
    pagination itself; pass the returned ``NextToken`` back in to fetch
    the next page.

    Args:
        filters (list[dict]): Filter conditions forwarded unchanged as the
            ``Filters`` request parameter (the service API documents this as
            a list of name/value filter objects). Default to None.
        requested_start_time_after (datetime): A timestamp that filters the results
            to only include training plans with a requested start time after this timestamp.
        requested_start_time_before (datetime): A timestamp that filters the results
            to only include training plans with a requested start time before this timestamp.
        start_time_after (datetime): A timestamp that filters the results
            to only include training plans with an actual start time after this timestamp.
        start_time_before (datetime): A timestamp that filters the results
            to only include training plans with an actual start time before this timestamp.
        sort_order (str): The order in which the training plans are listed.
            Default to None.
        sort_by (str): The field the training plans are sorted by.
            Default to None.
        max_results (int): The maximum number of training plans returned in
            one response, between 1 and 100. Default to None, in which case
            ``MaxResults`` is not sent and the page size is left to the service.
        next_token (str): The pagination token. Default to None.

    Returns:
        dict: A dictionary containing the following keys:
            - "TrainingPlanSummaries": A list of dictionaries, where each dictionary represents
            a training plan.
            - "NextToken": A token to retrieve the next set of results, if there are more
            than the maximum number of results returned.
    """
    # Map each API request parameter to the value supplied by the caller.
    candidate_args = {
        "Filters": filters,
        "SortBy": sort_by,
        "SortOrder": sort_order,
        "RequestedStartTimeAfter": requested_start_time_after,
        "RequestedStartTimeBefore": requested_start_time_before,
        "StartTimeAfter": start_time_after,
        "StartTimeBefore": start_time_before,
        "NextToken": next_token,
        "MaxResults": max_results,
    }
    # Drop unset arguments so the request carries only what the caller
    # specified. The check is ``is not None`` (not truthiness) on purpose.
    list_training_plan_args = {
        key: value for key, value in candidate_args.items() if value is not None
    }

    return self.sagemaker_client.list_training_plans(**list_training_plan_args)
2541+
24732542
def auto_ml(
24742543
self,
24752544
input_config,

tests/unit/test_estimator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
INSTANCE_COUNT = 1
9090
INSTANCE_TYPE = "c4.4xlarge"
9191
KEEP_ALIVE_PERIOD_IN_SECONDS = 1800
92+
TRAINING_PLAN = "arn:aws:sagemaker:us-west-2:336:training-plan/test_training_plan"
9293
ACCELERATOR_TYPE = "ml.eia.medium"
9394
ROLE = "DummyRole"
9495
IMAGE_URI = "fakeimage"
@@ -861,6 +862,23 @@ def test_framework_with_keep_alive_period(sagemaker_session):
861862
assert args["resource_config"]["KeepAlivePeriodInSeconds"] == KEEP_ALIVE_PERIOD_IN_SECONDS
862863

863864

865+
def test_framework_with_training_plan(sagemaker_session):
    """A training plan given to the estimator reaches the train() resource config."""
    instance_groups = [
        InstanceGroup("group1", "ml.c4.xlarge", 1),
        InstanceGroup("group2", "ml.m4.xlarge", 2),
    ]
    framework = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_groups=instance_groups,
        training_plan=TRAINING_PLAN,
    )

    framework.fit("s3://mydata")

    sagemaker_session.train.assert_called_once()
    # call_args is (positional_args, keyword_args); only the keywords matter here.
    train_kwargs = sagemaker_session.train.call_args[1]
    assert train_kwargs["resource_config"]["TrainingPlanArn"] == TRAINING_PLAN
880+
881+
864882
def test_framework_with_both_training_repository_config(sagemaker_session):
865883
f = DummyFramework(
866884
entry_point=SCRIPT_PATH,

tests/unit/test_job.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
INSTANCE_COUNT = 1
3232
INSTANCE_TYPE = "c4.4xlarge"
3333
KEEP_ALIVE_PERIOD = 1800
34+
TRAINING_PLAN = "arn:aws:sagemaker:us-west-2:336:training-plan/test_training_plan"
3435
INSTANCE_GROUP = InstanceGroup("group", "ml.c4.xlarge", 1)
3536
VOLUME_SIZE = 1
3637
MAX_RUNTIME = 1
@@ -633,7 +634,13 @@ def test_prepare_output_config_kms_key_none():
633634

634635
def test_prepare_resource_config():
635636
resource_config = _Job._prepare_resource_config(
636-
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, None, None
637+
INSTANCE_COUNT,
638+
INSTANCE_TYPE,
639+
None,
640+
VOLUME_SIZE,
641+
None,
642+
None,
643+
None,
637644
)
638645

639646
assert resource_config == {
@@ -643,9 +650,35 @@ def test_prepare_resource_config():
643650
}
644651

645652

653+
def test_prepare_resource_config_with_training_plan():
    """The training plan argument is emitted as ``TrainingPlanArn`` in the resource config."""
    expected_config = {
        "InstanceCount": INSTANCE_COUNT,
        "InstanceType": INSTANCE_TYPE,
        "VolumeSizeInGB": VOLUME_SIZE,
        "VolumeKmsKeyId": VOLUME_KMS_KEY,
        "TrainingPlanArn": TRAINING_PLAN,
    }

    # Arguments are positional; the last one is the training plan.
    positional_args = (
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        None,
        VOLUME_SIZE,
        VOLUME_KMS_KEY,
        None,
        TRAINING_PLAN,
    )
    actual_config = _Job._prepare_resource_config(*positional_args)

    assert actual_config == expected_config
671+
672+
646673
def test_prepare_resource_config_with_keep_alive_period():
647674
resource_config = _Job._prepare_resource_config(
648-
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY, KEEP_ALIVE_PERIOD
675+
INSTANCE_COUNT,
676+
INSTANCE_TYPE,
677+
None,
678+
VOLUME_SIZE,
679+
VOLUME_KMS_KEY,
680+
KEEP_ALIVE_PERIOD,
681+
None,
649682
)
650683

651684
assert resource_config == {
@@ -659,7 +692,13 @@ def test_prepare_resource_config_with_keep_alive_period():
659692

660693
def test_prepare_resource_config_with_volume_kms():
661694
resource_config = _Job._prepare_resource_config(
662-
INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY, None
695+
INSTANCE_COUNT,
696+
INSTANCE_TYPE,
697+
None,
698+
VOLUME_SIZE,
699+
VOLUME_KMS_KEY,
700+
None,
701+
None,
663702
)
664703

665704
assert resource_config == {
@@ -678,6 +717,7 @@ def test_prepare_resource_config_with_heterogeneous_cluster():
678717
VOLUME_SIZE,
679718
None,
680719
None,
720+
None,
681721
)
682722

683723
assert resource_config == {
@@ -698,6 +738,7 @@ def test_prepare_resource_config_with_instance_groups_instance_type_instance_cou
698738
VOLUME_SIZE,
699739
None,
700740
None,
741+
None,
701742
)
702743
assert "instance_count and instance_type cannot be set when instance_groups is set" in str(
703744
error
@@ -713,6 +754,7 @@ def test_prepare_resource_config_with_instance_groups_instance_type_instance_cou
713754
VOLUME_SIZE,
714755
None,
715756
None,
757+
None,
716758
)
717759
assert "instance_count and instance_type must be set if instance_groups is not set" in str(
718760
error

0 commit comments

Comments
 (0)