Skip to content

feat: Add support for InstancePlacementConfig in Estimator #5259

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def __init__(
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
training_plan: Optional[Union[str, PipelineVariable]] = None,
instance_placement_config: Optional[Dict] = None,
**kwargs,
):
"""Initialize an ``EstimatorBase`` instance.
Expand Down Expand Up @@ -560,6 +561,21 @@ def __init__(
Specifies whether SessionTagChaining is enabled for the training job.
training_plan (str or PipelineVariable): Optional.
Specifies which training plan arn to use for the training job
instance_placement_config (dict): Optional.
Specifies UltraServer placement configuration for the training job
.. code:: python
instance_placement_config={
"EnableMultipleJobs": True,
"PlacementSpecifications":[
{
"UltraServerId": "ultraserver-1",
"InstanceCount": "2"
}
]
}
"""
instance_count = renamed_kwargs(
"train_instance_count", "instance_count", instance_count, kwargs
Expand Down Expand Up @@ -813,6 +829,8 @@ def __init__(

self.training_plan = training_plan

self.instance_placement_config = instance_placement_config

# Internal flag
self._is_output_path_set_from_default_bucket_and_prefix = False

Expand Down Expand Up @@ -1997,6 +2015,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
if "TrainingPlanArn" in job_details["ResourceConfig"]:
init_params["training_plan"] = job_details["ResourceConfig"]["TrainingPlanArn"]

if "InstancePlacementConfig" in job_details["ResourceConfig"]:
init_params["instance_placement_config"] = job_details["ResourceConfig"][
"InstancePlacementConfig"
]

has_hps = "HyperParameters" in job_details
init_params["hyperparameters"] = job_details["HyperParameters"] if has_hps else {}

Expand Down Expand Up @@ -2882,6 +2905,7 @@ def __init__(
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
training_plan: Optional[Union[str, PipelineVariable]] = None,
instance_placement_config: Optional[Dict] = None,
**kwargs,
):
"""Initialize an ``Estimator`` instance.
Expand Down Expand Up @@ -3249,6 +3273,20 @@ def __init__(
Specifies whether SessionTagChaining is enabled for the training job
training_plan (str or PipelineVariable): Optional.
Specifies which training plan arn to use for the training job
instance_placement_config (dict): Optional.
Specifies UltraServer placement configuration for the training job
.. code:: python
instance_placement_config={
"EnableMultipleJobs": True,
"PlacementSpecifications":[
{
"UltraServerId": "ultraserver-1",
"InstanceCount": "2"
}
]
}
"""
self.image_uri = image_uri
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
Expand Down Expand Up @@ -3303,6 +3341,7 @@ def __init__(
enable_remote_debug=enable_remote_debug,
enable_session_tag_chaining=enable_session_tag_chaining,
training_plan=training_plan,
instance_placement_config=instance_placement_config,
**kwargs,
)

Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
estimator.volume_kms_key,
estimator.keep_alive_period_in_seconds,
estimator.training_plan,
estimator.instance_placement_config,
)
stop_condition = _Job._prepare_stop_condition(estimator.max_run, estimator.max_wait)
vpc_config = estimator.get_vpc_config()
Expand Down Expand Up @@ -333,6 +334,7 @@ def _prepare_resource_config(
volume_kms_key,
keep_alive_period_in_seconds,
training_plan,
instance_placement_config=None,
):
"""Placeholder docstring"""
resource_config = {
Expand Down Expand Up @@ -360,6 +362,8 @@ def _prepare_resource_config(
resource_config["InstanceType"] = instance_type
if training_plan is not None:
resource_config["TrainingPlanArn"] = training_plan
if instance_placement_config is not None:
resource_config["InstancePlacementConfig"] = instance_placement_config

return resource_config

Expand Down
16 changes: 16 additions & 0 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(
config_name: Optional[str] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
training_plan: Optional[Union[str, PipelineVariable]] = None,
instance_placement_config: Optional[Dict] = None,
):
"""Initializes a ``JumpStartEstimator``.
Expand Down Expand Up @@ -517,6 +518,20 @@ def __init__(
Specifies whether SessionTagChaining is enabled for the training job
training_plan (str or PipelineVariable): Optional.
Specifies which training plan arn to use for the training job
instance_placement_config (dict): Optional.
Specifies UltraServer placement configuration for the training job
.. code:: python
instance_placement_config={
"EnableMultipleJobs": True,
"PlacementSpecifications":[
{
"UltraServerId": "ultraserver-1",
"InstanceCount": "2"
}
]
}
Raises:
ValueError: If the model ID is not recognized by JumpStart.
Expand Down Expand Up @@ -606,6 +621,7 @@ def _validate_model_id_and_get_type_hook():
config_name=config_name,
enable_session_tag_chaining=enable_session_tag_chaining,
training_plan=training_plan,
instance_placement_config=instance_placement_config,
)

self.hub_arn = estimator_init_kwargs.hub_arn
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def get_init_kwargs(
config_name: Optional[str] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
training_plan: Optional[Union[str, PipelineVariable]] = None,
instance_placement_config: Optional[Dict] = None,
) -> JumpStartEstimatorInitKwargs:
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""

Expand Down Expand Up @@ -207,6 +208,7 @@ def get_init_kwargs(
config_name=config_name,
enable_session_tag_chaining=enable_session_tag_chaining,
training_plan=training_plan,
instance_placement_config=instance_placement_config,
)

estimator_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2445,6 +2445,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
"model_reference_arn",
"specs",
"training_plan",
"instance_placement_config",
]

SERIALIZATION_EXCLUSION_SET = {
Expand Down Expand Up @@ -2519,6 +2520,7 @@ def __init__(
config_name: Optional[str] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
training_plan: Optional[Union[str, PipelineVariable]] = None,
instance_placement_config: Optional[Dict] = None,
) -> None:
"""Instantiates JumpStartEstimatorInitKwargs object."""

Expand Down Expand Up @@ -2582,6 +2584,7 @@ def __init__(
self.config_name = config_name
self.enable_session_tag_chaining = enable_session_tag_chaining
self.training_plan = training_plan
self.instance_placement_config = instance_placement_config


class JumpStartEstimatorFitKwargs(JumpStartKwargs):
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@
)
from sagemaker.model_life_cycle import ModelLifeCycle

from tests.unit.test_job import INSTANCE_PLACEMENT_CONFIG

MODEL_DATA = "s3://bucket/model.tar.gz"
MODEL_IMAGE = "mi"
ENTRY_POINT = "blah.py"
Expand Down Expand Up @@ -879,6 +881,22 @@ def test_framework_with_training_plan(sagemaker_session):
assert args["resource_config"]["TrainingPlanArn"] == TRAINING_PLAN


def test_framework_with_instance_placement(sagemaker_session):
f = DummyFramework(
entry_point=SCRIPT_PATH,
role=ROLE,
sagemaker_session=sagemaker_session,
instance_type="ml.c4.xlarge",
instance_count=2,
training_plan=TRAINING_PLAN,
instance_placement_config=INSTANCE_PLACEMENT_CONFIG,
)
f.fit("s3://mydata")
sagemaker_session.train.assert_called_once()
_, args = sagemaker_session.train.call_args
assert args["resource_config"]["InstancePlacementConfig"] == INSTANCE_PLACEMENT_CONFIG


def test_framework_with_both_training_repository_config(sagemaker_session):
f = DummyFramework(
entry_point=SCRIPT_PATH,
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
INSTANCE_TYPE = "c4.4xlarge"
KEEP_ALIVE_PERIOD = 1800
TRAINING_PLAN = "arn:aws:sagemaker:us-west-2:336:training-plan/test_training_plan"
INSTANCE_PLACEMENT_CONFIG = {
"EnableMultipleJobs": True,
"PlacementSpecifications": [{"UltraServerId": "us-1", "InstanceCount": "2"}],
}
INSTANCE_GROUP = InstanceGroup("group", "ml.c4.xlarge", 1)
VOLUME_SIZE = 1
MAX_RUNTIME = 1
Expand Down Expand Up @@ -756,6 +760,28 @@ def test_prepare_resource_config_with_training_plan():
}


def test_prepare_resource_config_with_placement_config():
resource_config = _Job._prepare_resource_config(
INSTANCE_COUNT,
INSTANCE_TYPE,
None,
VOLUME_SIZE,
VOLUME_KMS_KEY,
None,
TRAINING_PLAN,
INSTANCE_PLACEMENT_CONFIG,
)

assert resource_config == {
"InstanceCount": INSTANCE_COUNT,
"InstanceType": INSTANCE_TYPE,
"VolumeSizeInGB": VOLUME_SIZE,
"VolumeKmsKeyId": VOLUME_KMS_KEY,
"TrainingPlanArn": TRAINING_PLAN,
"InstancePlacementConfig": INSTANCE_PLACEMENT_CONFIG,
}


def test_prepare_resource_config_with_keep_alive_period():
resource_config = _Job._prepare_resource_config(
INSTANCE_COUNT,
Expand Down