Skip to content

Commit f65a28e

Browse files
gkatkov and Greg Katkov authored
feature: Add support for InstancePlacementConfig in Estimator for training jobs running on ultraserver capacity (#5259)
--------- Co-authored-by: Greg Katkov <[email protected]>
1 parent 754c3a5 commit f65a28e

File tree

7 files changed

+108
-0
lines changed

7 files changed

+108
-0
lines changed

src/sagemaker/estimator.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ def __init__(
186186
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
187187
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
188188
training_plan: Optional[Union[str, PipelineVariable]] = None,
189+
instance_placement_config: Optional[Dict] = None,
189190
**kwargs,
190191
):
191192
"""Initialize an ``EstimatorBase`` instance.
@@ -560,6 +561,21 @@ def __init__(
560561
Specifies whether SessionTagChaining is enabled for the training job.
561562
training_plan (str or PipelineVariable): Optional.
562563
Specifies which training plan arn to use for the training job
564+
instance_placement_config (dict): Optional.
565+
Specifies the UltraServer placement configuration for the training job. For example:
566+
567+
.. code:: python
568+
569+
instance_placement_config={
570+
"EnableMultipleJobs": True,
571+
"PlacementSpecifications":[
572+
{
573+
"UltraServerId": "ultraserver-1",
574+
"InstanceCount": 2
575+
}
576+
]
577+
}
578+
563579
"""
564580
instance_count = renamed_kwargs(
565581
"train_instance_count", "instance_count", instance_count, kwargs
@@ -813,6 +829,8 @@ def __init__(
813829

814830
self.training_plan = training_plan
815831

832+
self.instance_placement_config = instance_placement_config
833+
816834
# Internal flag
817835
self._is_output_path_set_from_default_bucket_and_prefix = False
818836

@@ -1997,6 +2015,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
19972015
if "TrainingPlanArn" in job_details["ResourceConfig"]:
19982016
init_params["training_plan"] = job_details["ResourceConfig"]["TrainingPlanArn"]
19992017

2018+
if "InstancePlacementConfig" in job_details["ResourceConfig"]:
2019+
init_params["instance_placement_config"] = job_details["ResourceConfig"][
2020+
"InstancePlacementConfig"
2021+
]
2022+
20002023
has_hps = "HyperParameters" in job_details
20012024
init_params["hyperparameters"] = job_details["HyperParameters"] if has_hps else {}
20022025

@@ -2882,6 +2905,7 @@ def __init__(
28822905
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
28832906
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
28842907
training_plan: Optional[Union[str, PipelineVariable]] = None,
2908+
instance_placement_config: Optional[Dict] = None,
28852909
**kwargs,
28862910
):
28872911
"""Initialize an ``Estimator`` instance.
@@ -3249,6 +3273,20 @@ def __init__(
32493273
Specifies whether SessionTagChaining is enabled for the training job
32503274
training_plan (str or PipelineVariable): Optional.
32513275
Specifies which training plan arn to use for the training job
3276+
instance_placement_config (dict): Optional.
3277+
Specifies the UltraServer placement configuration for the training job. For example:
3278+
3279+
.. code:: python
3280+
3281+
instance_placement_config={
3282+
"EnableMultipleJobs": True,
3283+
"PlacementSpecifications":[
3284+
{
3285+
"UltraServerId": "ultraserver-1",
3286+
"InstanceCount": 2
3287+
}
3288+
]
3289+
}
32523290
"""
32533291
self.image_uri = image_uri
32543292
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
@@ -3303,6 +3341,7 @@ def __init__(
33033341
enable_remote_debug=enable_remote_debug,
33043342
enable_session_tag_chaining=enable_session_tag_chaining,
33053343
training_plan=training_plan,
3344+
instance_placement_config=instance_placement_config,
33063345
**kwargs,
33073346
)
33083347

src/sagemaker/job.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
8585
estimator.volume_kms_key,
8686
estimator.keep_alive_period_in_seconds,
8787
estimator.training_plan,
88+
estimator.instance_placement_config,
8889
)
8990
stop_condition = _Job._prepare_stop_condition(estimator.max_run, estimator.max_wait)
9091
vpc_config = estimator.get_vpc_config()
@@ -333,6 +334,7 @@ def _prepare_resource_config(
333334
volume_kms_key,
334335
keep_alive_period_in_seconds,
335336
training_plan,
337+
instance_placement_config=None,
336338
):
337339
"""Assemble and return the ``resource_config`` dict for a job from the given settings."""
338340
resource_config = {
@@ -360,6 +362,8 @@ def _prepare_resource_config(
360362
resource_config["InstanceType"] = instance_type
361363
if training_plan is not None:
362364
resource_config["TrainingPlanArn"] = training_plan
365+
if instance_placement_config is not None:
366+
resource_config["InstancePlacementConfig"] = instance_placement_config
363367

364368
return resource_config
365369

src/sagemaker/jumpstart/estimator.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def __init__(
119119
config_name: Optional[str] = None,
120120
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
121121
training_plan: Optional[Union[str, PipelineVariable]] = None,
122+
instance_placement_config: Optional[Dict] = None,
122123
):
123124
"""Initializes a ``JumpStartEstimator``.
124125
@@ -517,6 +518,20 @@ def __init__(
517518
Specifies whether SessionTagChaining is enabled for the training job
518519
training_plan (str or PipelineVariable): Optional.
519520
Specifies which training plan arn to use for the training job
521+
instance_placement_config (dict): Optional.
522+
Specifies the UltraServer placement configuration for the training job. For example:
523+
524+
.. code:: python
525+
526+
instance_placement_config={
527+
"EnableMultipleJobs": True,
528+
"PlacementSpecifications":[
529+
{
530+
"UltraServerId": "ultraserver-1",
531+
"InstanceCount": 2
532+
}
533+
]
534+
}
520535
521536
Raises:
522537
ValueError: If the model ID is not recognized by JumpStart.
@@ -606,6 +621,7 @@ def _validate_model_id_and_get_type_hook():
606621
config_name=config_name,
607622
enable_session_tag_chaining=enable_session_tag_chaining,
608623
training_plan=training_plan,
624+
instance_placement_config=instance_placement_config,
609625
)
610626

611627
self.hub_arn = estimator_init_kwargs.hub_arn

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def get_init_kwargs(
145145
config_name: Optional[str] = None,
146146
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
147147
training_plan: Optional[Union[str, PipelineVariable]] = None,
148+
instance_placement_config: Optional[Dict] = None,
148149
) -> JumpStartEstimatorInitKwargs:
149150
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""
150151

@@ -207,6 +208,7 @@ def get_init_kwargs(
207208
config_name=config_name,
208209
enable_session_tag_chaining=enable_session_tag_chaining,
209210
training_plan=training_plan,
211+
instance_placement_config=instance_placement_config,
210212
)
211213

212214
estimator_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(

src/sagemaker/jumpstart/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2445,6 +2445,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
24452445
"model_reference_arn",
24462446
"specs",
24472447
"training_plan",
2448+
"instance_placement_config",
24482449
]
24492450

24502451
SERIALIZATION_EXCLUSION_SET = {
@@ -2519,6 +2520,7 @@ def __init__(
25192520
config_name: Optional[str] = None,
25202521
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
25212522
training_plan: Optional[Union[str, PipelineVariable]] = None,
2523+
instance_placement_config: Optional[Dict] = None,
25222524
) -> None:
25232525
"""Instantiates JumpStartEstimatorInitKwargs object."""
25242526

@@ -2582,6 +2584,7 @@ def __init__(
25822584
self.config_name = config_name
25832585
self.enable_session_tag_chaining = enable_session_tag_chaining
25842586
self.training_plan = training_plan
2587+
self.instance_placement_config = instance_placement_config
25852588

25862589

25872590
class JumpStartEstimatorFitKwargs(JumpStartKwargs):

tests/unit/test_estimator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@
7676
)
7777
from sagemaker.model_life_cycle import ModelLifeCycle
7878

79+
from tests.unit.test_job import INSTANCE_PLACEMENT_CONFIG
80+
7981
MODEL_DATA = "s3://bucket/model.tar.gz"
8082
MODEL_IMAGE = "mi"
8183
ENTRY_POINT = "blah.py"
@@ -879,6 +881,22 @@ def test_framework_with_training_plan(sagemaker_session):
879881
assert args["resource_config"]["TrainingPlanArn"] == TRAINING_PLAN
880882

881883

884+
def test_framework_with_instance_placement(sagemaker_session):
885+
f = DummyFramework(
886+
entry_point=SCRIPT_PATH,
887+
role=ROLE,
888+
sagemaker_session=sagemaker_session,
889+
instance_type="ml.c4.xlarge",
890+
instance_count=2,
891+
training_plan=TRAINING_PLAN,
892+
instance_placement_config=INSTANCE_PLACEMENT_CONFIG,
893+
)
894+
f.fit("s3://mydata")
895+
sagemaker_session.train.assert_called_once()
896+
_, args = sagemaker_session.train.call_args
897+
assert args["resource_config"]["InstancePlacementConfig"] == INSTANCE_PLACEMENT_CONFIG
898+
899+
882900
def test_framework_with_both_training_repository_config(sagemaker_session):
883901
f = DummyFramework(
884902
entry_point=SCRIPT_PATH,

tests/unit/test_job.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@
3232
INSTANCE_TYPE = "c4.4xlarge"
3333
KEEP_ALIVE_PERIOD = 1800
3434
TRAINING_PLAN = "arn:aws:sagemaker:us-west-2:336:training-plan/test_training_plan"
35+
INSTANCE_PLACEMENT_CONFIG = {
36+
"EnableMultipleJobs": True,
37+
"PlacementSpecifications": [{"UltraServerId": "us-1", "InstanceCount": "2"}],
38+
}
3539
INSTANCE_GROUP = InstanceGroup("group", "ml.c4.xlarge", 1)
3640
VOLUME_SIZE = 1
3741
MAX_RUNTIME = 1
@@ -756,6 +760,28 @@ def test_prepare_resource_config_with_training_plan():
756760
}
757761

758762

763+
def test_prepare_resource_config_with_placement_config():
764+
resource_config = _Job._prepare_resource_config(
765+
INSTANCE_COUNT,
766+
INSTANCE_TYPE,
767+
None,
768+
VOLUME_SIZE,
769+
VOLUME_KMS_KEY,
770+
None,
771+
TRAINING_PLAN,
772+
INSTANCE_PLACEMENT_CONFIG,
773+
)
774+
775+
assert resource_config == {
776+
"InstanceCount": INSTANCE_COUNT,
777+
"InstanceType": INSTANCE_TYPE,
778+
"VolumeSizeInGB": VOLUME_SIZE,
779+
"VolumeKmsKeyId": VOLUME_KMS_KEY,
780+
"TrainingPlanArn": TRAINING_PLAN,
781+
"InstancePlacementConfig": INSTANCE_PLACEMENT_CONFIG,
782+
}
783+
784+
759785
def test_prepare_resource_config_with_keep_alive_period():
760786
resource_config = _Job._prepare_resource_config(
761787
INSTANCE_COUNT,

0 commit comments

Comments
 (0)