Skip to content

Commit c64811d

Browse files
mollyheamazonshantanutripmohamedzeidan2021Mohamed Zeidanyungwenh-aws
authored
Reinvent Keynote3 Elastic Training Feature Support (#341)
* Upgrade Inference Operator Version (#327) * pyproj version update (#328) Co-authored-by: Mohamed Zeidan <[email protected]> * version change (#329) Co-authored-by: Mohamed Zeidan <[email protected]> * elastic training to keynote3 (#307) * feat: Implement elastic training cli arguments (#273) * feat: Implement elastic training cli arguments * Add elastic training unified config and unit test * Add graceful shutdown and scaling timeout to cli args * Revert "feat: Implement elastic training cli arguments (#273)" This reverts commit 18428ef2b1c0562bf51a9a4b4aa2914eed441259. * feat: Implement elastic training cli arguments (#295) * feat: implement elastic training cli args * Rename args name to match crd for elastic training * Add unit test for replica discrete values * Add integ test for elastic training cli --------- Co-authored-by: Sophia <[email protected]> Co-authored-by: Molly He <[email protected]> Co-authored-by: Mohamed Zeidan <[email protected]> * version update for v3.5.0 --------- Co-authored-by: Shantanu Tripathi <[email protected]> Co-authored-by: Mohamed Zeidan <[email protected]> Co-authored-by: Mohamed Zeidan <[email protected]> Co-authored-by: Sophia <[email protected]>
1 parent 1aafd60 commit c64811d

File tree

12 files changed

+741
-8
lines changed

12 files changed

+741
-8
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Changelog
22

3+
## v.3.5.0 (2025-12-03)
4+
5+
### Features
6+
* Elastic training support for HyperPodTrainingOperator, released in re:Invent 2025 Keynote 3. This feature dynamically scales distributed machine learning training workloads.
7+
8+
39
## v.3.4.0 (2025-11-20)
410

511
### Features

hyperpod-pytorch-job-template/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
## v1.3.0 (2025-12-03)
2+
3+
### Features
4+
5+
* Support for elastic training
6+
17
## v1.2.0 (2025-11-20)
28

39
### Features

hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/model.py

Lines changed: 121 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
Metadata,
1212
Volumes,
1313
HostPath,
14-
PersistentVolumeClaim
14+
PersistentVolumeClaim,
15+
ElasticPolicy
1516
)
1617
from sagemaker.hyperpod.training.hyperpod_pytorch_job import HyperPodPytorchJob
1718
import yaml
@@ -239,6 +240,38 @@ class PyTorchJobConfig(BaseModel):
239240
alias="required_topology",
240241
description="Required topology annotation for scheduling",
241242
)
243+
# --- Elastic training configuration (consumed by to_domain / ElasticPolicy) ---
elastic_replica_increment_step: Optional[int] = Field(
    default=None,
    alias="elastic_replica_increment_step",
    description="Scaling step size for elastic training",
    ge=1,
)
max_node_count: Optional[int] = Field(
    default=None,
    alias="max_node_count",
    description="Maximum number of nodes for elastic training",
    ge=1,
)
elastic_graceful_shutdown_timeout_in_seconds: Optional[int] = Field(
    default=None,
    alias="elastic_graceful_shutdown_timeout_in_seconds",
    description="Graceful shutdown timeout in seconds for elastic scaling operations",
    # schema.json publishes "minimum": 0 for this field; mirror it here so the
    # pydantic model and the published JSON schema agree on valid input.
    ge=0,
)
elastic_scaling_timeout_in_seconds: Optional[int] = Field(
    default=None,
    alias="elastic_scaling_timeout_in_seconds",
    description="Scaling timeout for elastic training",
    ge=0,  # keep in sync with schema.json ("minimum": 0)
)
elastic_scale_up_snooze_time_in_seconds: Optional[int] = Field(
    default=None,
    alias="elastic_scale_up_snooze_time_in_seconds",
    description="Timeout period after job restart during which no scale up/workload admission is allowed",
    ge=0,  # keep in sync with schema.json ("minimum": 0)
)
elastic_replica_discrete_values: Optional[List[int]] = Field(
    default=None,
    alias="elastic_replica_discrete_values",
    # Positivity of each entry and its relation to max_node_count are enforced
    # by the validate_elastic_replica_config model validator, not here.
    description="Alternative to replica increment step. Provides exact values for total replicas count",
)
242275

243276
@field_validator('tasks_per_node', mode='before')
244277
@classmethod
@@ -363,6 +396,45 @@ def validate_accelerator_partition_options(self):
363396
)
364397
if not valid:
365398
raise ValueError(error)
399+
400+
return self
401+
402+
@model_validator(mode='after')
def validate_elastic_replica_config(self):
    """Validate elastic replica configuration.

    Enforces that the two replica-scaling options are mutually exclusive,
    and that any discrete replica values are positive and do not exceed
    ``max_node_count`` when that ceiling is configured.
    """
    step = self.elastic_replica_increment_step
    values = self.elastic_replica_discrete_values

    # The two scaling strategies are mutually exclusive.
    if step is not None and values is not None:
        raise ValueError(
            "Only one of 'elastic_replica_increment_step' or 'elastic_replica_discrete_values' "
            "can be specified, not both. Please use either:\n"
            " - elastic_replica_increment_step for uniform scaling steps, or\n"
            " - elastic_replica_discrete_values for specific replica counts"
        )

    # Nothing further to check unless discrete values were supplied.
    if values is None:
        return self

    # Every discrete replica count must be a positive integer.
    if any(v <= 0 for v in values):
        raise ValueError(
            f"All values in 'elastic_replica_discrete_values' must be positive integers. "
            f"Got: {values}"
        )

    # When a node ceiling is configured, no discrete value may exceed it.
    if self.max_node_count is not None:
        too_large = [v for v in values if v > self.max_node_count]
        if too_large:
            raise ValueError(
                f"All values in 'elastic_replica_discrete_values' must be ≤ max_node_count ({self.max_node_count}). "
                f"Invalid values: {too_large}. "
                f"Please either increase max_node_count or remove values exceeding it."
            )

    return self
367439

368440
def to_domain(self) -> Dict:
@@ -467,15 +539,61 @@ def build_dict(**kwargs):
467539
replica_kwargs = build_dict(
468540
name="pod",
469541
template=Template(metadata=Metadata(**metadata_kwargs), spec=Spec(**spec_kwargs)),
470-
replicas=self.node_count
542+
replicas=self.node_count,
543+
max_replicas=self.max_node_count
471544
)
472545

546+
# Build elastic policy
547+
elastic_policy = None
548+
if any([
549+
self.elastic_replica_increment_step is not None,
550+
self.max_node_count is not None,
551+
self.elastic_graceful_shutdown_timeout_in_seconds is not None,
552+
self.elastic_scaling_timeout_in_seconds is not None,
553+
self.elastic_replica_discrete_values is not None
554+
]):
555+
# Build base elastic policy kwargs
556+
elastic_policy_kwargs = build_dict(
557+
min_replicas=self.node_count,
558+
max_replicas=self.max_node_count,
559+
graceful_shutdown_timeout_in_seconds=self.elastic_graceful_shutdown_timeout_in_seconds,
560+
scaling_timeout_in_seconds=self.elastic_scaling_timeout_in_seconds
561+
)
562+
563+
if self.elastic_replica_discrete_values is not None:
564+
elastic_policy_kwargs['replica_discrete_values'] = self.elastic_replica_discrete_values
565+
elif self.elastic_replica_increment_step is not None:
566+
elastic_policy_kwargs['replica_increment_step'] = self.elastic_replica_increment_step
567+
568+
elastic_policy = ElasticPolicy(**elastic_policy_kwargs)
569+
570+
# Build run policy
571+
run_policy = None
572+
if self.max_retry is not None or self.elastic_scale_up_snooze_time_in_seconds is not None:
573+
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import RestartPolicy
574+
575+
run_policy_kwargs = build_dict(
576+
clean_pod_policy="None",
577+
job_max_retry_count=self.max_retry
578+
)
579+
580+
# Add restart policy if scale_up_snooze_interval is provided
581+
if self.elastic_scale_up_snooze_time_in_seconds is not None:
582+
restart_policy = RestartPolicy(
583+
eval_period_seconds=3600,
584+
scale_up_snooze_time_in_seconds=self.elastic_scale_up_snooze_time_in_seconds
585+
)
586+
run_policy_kwargs['restart_policy'] = restart_policy
587+
588+
run_policy = RunPolicy(**run_policy_kwargs)
589+
473590
# Build job
474591
job_kwargs = build_dict(
475592
metadata=metadata_kwargs,
476593
replica_specs=[ReplicaSpec(**replica_kwargs)],
477594
nproc_per_node=str(self.tasks_per_node) if self.tasks_per_node else None,
478-
run_policy=RunPolicy(clean_pod_policy="None", job_max_retry_count=self.max_retry) if self.max_retry else None
595+
run_policy=run_policy,
596+
elastic_policy=elastic_policy
479597
)
480598

481599
result = HyperPodPytorchJob(**job_kwargs)

hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/schema.json

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,94 @@
395395
"type": "string",
396396
"description": "Required topology annotation for scheduling",
397397
"$ref": "#/$defs/topologyLabels"
398+
},
399+
"elastic_replica_increment_step": {
400+
"anyOf": [
401+
{
402+
"minimum": 1,
403+
"type": "integer"
404+
},
405+
{
406+
"type": "null"
407+
}
408+
],
409+
"default": null,
410+
"description": "Scaling step size for elastic training",
411+
"title": "Elastic Training Replica Increment Step"
412+
},
413+
"max_node_count": {
414+
"anyOf": [
415+
{
416+
"minimum": 1,
417+
"type": "integer"
418+
},
419+
{
420+
"type": "null"
398421
}
422+
],
423+
"default": null,
424+
"description": "Maximum number of nodes for elastic training",
425+
"title": "Max Node Count"
426+
},
427+
"elastic_graceful_shutdown_timeout_in_seconds": {
428+
"anyOf": [
429+
{
430+
"minimum": 0,
431+
"type": "integer"
432+
},
433+
{
434+
"type": "null"
435+
}
436+
],
437+
"default": null,
438+
"description": "Graceful shutdown timeout in seconds for elastic scaling operations",
439+
"title": "Elastic Graceful Shutdown Timeout In Seconds"
440+
},
441+
"elastic_scaling_timeout_in_seconds": {
442+
"anyOf": [
443+
{
444+
"minimum": 0,
445+
"type": "integer"
446+
},
447+
{
448+
"type": "null"
449+
}
450+
],
451+
"default": null,
452+
"description": "Scaling timeout for elastic training",
453+
"title": "Elastic Scaling Timeout In Seconds"
454+
},
455+
"elastic_scale_up_snooze_time_in_seconds": {
456+
"anyOf": [
457+
{
458+
"minimum": 0,
459+
"type": "integer"
460+
},
461+
{
462+
"type": "null"
463+
}
464+
],
465+
"default": null,
466+
"description": "Timeout period after job restart during which no scale up/workload admission is allowed",
467+
"title": "Elastic Scale Up Snooze Time In Seconds"
468+
},
469+
"elastic_replica_discrete_values": {
470+
"anyOf": [
471+
{
472+
"items": {
473+
"type": "integer"
474+
},
475+
"type": "array"
476+
},
477+
{
478+
"type": "null"
479+
}
480+
],
481+
"default": null,
482+
"description": "Alternative to replica increment step. Provides exact values for total replicas count",
483+
"title": "Elastic Replica Discrete Values"
484+
}
485+
399486
},
400487
"required": [
401488
"job_name",

hyperpod-pytorch-job-template/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "hyperpod-pytorch-job-template"
7-
version = "1.2.0"
7+
version = "1.3.0"
88
readme = "README.md"
99
authors = [{name = "Amazon Web Services"}]
1010
license = {text = "Apache-2.0"}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
55
[project]
66
dynamic = ["dependencies"]
77
name = "sagemaker-hyperpod"
8-
version = "3.4.0"
8+
version = "3.5.0"
99
description = "Amazon SageMaker HyperPod SDK and CLI"
1010
readme = "README.md"
1111
requires-python = ">=3.8"

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
setup(
4848
data_files=sagemaker_hyperpod_recipes,
4949
name="sagemaker-hyperpod",
50-
version="3.4.0",
50+
version="3.5.0",
5151
description="Amazon SageMaker HyperPod SDK and CLI",
5252
long_description=open("README.md").read(),
5353
long_description_content_type="text/markdown",

src/sagemaker/hyperpod/cli/training_utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,16 @@ def _parse_list_flag(ctx, param, value):
4242
return None
4343
# Remove brackets and split by comma
4444
value = value.strip("[]")
45-
return [item.strip() for item in value.split(",") if item.strip()]
45+
items = [item.strip() for item in value.split(",") if item.strip()]
46+
47+
# Convert to integers for elastic_replica_discrete_values
48+
if param and hasattr(param, 'name') and param.name == 'elastic_replica_discrete_values':
49+
try:
50+
return [int(item) for item in items]
51+
except ValueError as e:
52+
raise click.BadParameter(f"elastic-replica-discrete-values must contain only integers: {e}")
53+
54+
return items
4655

4756
def _parse_volume_param(ctx, param, value):
4857
"""Parse volume parameters from command line format to dictionary format."""
@@ -134,11 +143,12 @@ def wrapped_func(*args, **kwargs):
134143
list_params = {
135144
"command": "List of command arguments",
136145
"args": "List of script arguments, e.g. '[--batch-size, 32, --learning-rate, 0.001]'",
146+
"elastic_replica_discrete_values": "List of discrete replica values for elastic training, e.g. '[2, 4, 8, 16]'",
137147
}
138148

139149
for param_name, help_text in list_params.items():
140150
wrapped_func = click.option(
141-
f"--{param_name}",
151+
f"--{param_name.replace('_', '-')}",
142152
callback=_parse_list_flag,
143153
type=str,
144154
default=None,
@@ -154,6 +164,7 @@ def wrapped_func(*args, **kwargs):
154164
"command",
155165
"args",
156166
"volume",
167+
"elastic_replica_discrete_values"
157168
]
158169
)
159170

0 commit comments

Comments
 (0)