     Metadata,
     Volumes,
     HostPath,
-    PersistentVolumeClaim
+    PersistentVolumeClaim,
+    ElasticPolicy
 )
 from sagemaker.hyperpod.training.hyperpod_pytorch_job import HyperPodPytorchJob
 import yaml
@@ -239,6 +240,38 @@ class PyTorchJobConfig(BaseModel):
         alias="required_topology",
         description="Required topology annotation for scheduling",
     )
+    elastic_replica_increment_step: Optional[int] = Field(
+        default=None,
+        alias="elastic_replica_increment_step",
+        description="Scaling step size for elastic training",
+        ge=1,
+    )
+    max_node_count: Optional[int] = Field(
+        default=None,
+        alias="max_node_count",
+        description="Maximum number of nodes for elastic training",
+        ge=1,
+    )
+    elastic_graceful_shutdown_timeout_in_seconds: Optional[int] = Field(
+        default=None,
+        alias="elastic_graceful_shutdown_timeout_in_seconds",
+        description="Graceful shutdown timeout in seconds for elastic scaling operations",
+    )
+    elastic_scaling_timeout_in_seconds: Optional[int] = Field(
+        default=None,
+        alias="elastic_scaling_timeout_in_seconds",
+        description="Scaling timeout in seconds for elastic training",
+    )
+    elastic_scale_up_snooze_time_in_seconds: Optional[int] = Field(
+        default=None,
+        alias="elastic_scale_up_snooze_time_in_seconds",
+        description="Period after a job restart during which no scale-up or workload admission is allowed",
+    )
+    elastic_replica_discrete_values: Optional[List[int]] = Field(
+        default=None,
+        alias="elastic_replica_discrete_values",
+        description="Alternative to elastic_replica_increment_step; specifies the exact allowed total replica counts",
+    )

     @field_validator('tasks_per_node', mode='before')
     @classmethod
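
For orientation, here is a hedged sketch of how a caller might set these new fields. Only node_count, tasks_per_node, and max_retry are visible elsewhere in this diff; job_name and image are assumed placeholder fields shown purely for illustration.

# Sketch only: job_name and image are assumed fields, not confirmed by this diff.
config = PyTorchJobConfig(
    job_name="demo-elastic-job",           # assumed required field
    image="<training-image-uri>",          # assumed required field
    node_count=4,                          # becomes min_replicas of the ElasticPolicy built below
    max_node_count=16,                     # upper bound for elastic scaling
    elastic_replica_increment_step=4,      # allowed replica counts: 4, 8, 12, 16
    elastic_graceful_shutdown_timeout_in_seconds=120,
    elastic_scaling_timeout_in_seconds=600,
)

# Mutually exclusive alternative to the increment step:
#   elastic_replica_discrete_values=[4, 8, 16]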
@@ -363,6 +396,45 @@ def validate_accelerator_partition_options(self):
             )
         if not valid:
             raise ValueError(error)
+
+        return self
+
+    @model_validator(mode='after')
+    def validate_elastic_replica_config(self):
+        """Validate elastic replica configuration."""
+        has_increment_step = self.elastic_replica_increment_step is not None
+        has_discrete_values = self.elastic_replica_discrete_values is not None
+
+        # Check mutual exclusivity
+        if has_increment_step and has_discrete_values:
+            raise ValueError(
+                "Only one of 'elastic_replica_increment_step' or 'elastic_replica_discrete_values' "
+                "can be specified, not both. Please use either:\n"
+                "  - elastic_replica_increment_step for uniform scaling steps, or\n"
+                "  - elastic_replica_discrete_values for specific replica counts"
+            )
+
+        # Validate discrete values are within the valid range
+        if has_discrete_values:
+            discrete_values = self.elastic_replica_discrete_values
+
+            # Check that all values are positive
+            if any(val <= 0 for val in discrete_values):
+                raise ValueError(
+                    f"All values in 'elastic_replica_discrete_values' must be positive integers. "
+                    f"Got: {discrete_values}"
+                )
+
+            # Check against max_node_count if specified
+            if self.max_node_count is not None:
+                invalid_values = [val for val in discrete_values if val > self.max_node_count]
+                if invalid_values:
+                    raise ValueError(
+                        f"All values in 'elastic_replica_discrete_values' must be ≤ max_node_count ({self.max_node_count}). "
+                        f"Invalid values: {invalid_values}. "
+                        f"Please either increase max_node_count or remove values exceeding it."
+                    )
+
         return self

     def to_domain(self) -> Dict:
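
The validator above can be exercised in isolation. The following is a minimal, self-contained sketch (a stand-in model, not the real PyTorchJobConfig) that reproduces the two rules so the expected failure modes are visible.

from typing import List, Optional
from pydantic import BaseModel, model_validator

# Stand-in model that mirrors only the elastic fields and the validation rules
# added in this change; it is not the production PyTorchJobConfig.
class ElasticConfigSketch(BaseModel):
    max_node_count: Optional[int] = None
    elastic_replica_increment_step: Optional[int] = None
    elastic_replica_discrete_values: Optional[List[int]] = None

    @model_validator(mode='after')
    def check_elastic_fields(self):
        # Rule 1: increment step and discrete values are mutually exclusive.
        if (self.elastic_replica_increment_step is not None
                and self.elastic_replica_discrete_values is not None):
            raise ValueError("increment step and discrete values are mutually exclusive")
        # Rule 2: every discrete value must not exceed max_node_count.
        if self.elastic_replica_discrete_values and self.max_node_count is not None:
            bad = [v for v in self.elastic_replica_discrete_values if v > self.max_node_count]
            if bad:
                raise ValueError(f"values exceed max_node_count: {bad}")
        return self

ElasticConfigSketch(max_node_count=8, elastic_replica_discrete_values=[2, 4, 8])   # passes
# ElasticConfigSketch(elastic_replica_increment_step=2,
#                     elastic_replica_discrete_values=[2, 4])                      # raises ValueError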
@@ -467,15 +539,61 @@ def build_dict(**kwargs):
         replica_kwargs = build_dict(
             name="pod",
             template=Template(metadata=Metadata(**metadata_kwargs), spec=Spec(**spec_kwargs)),
-            replicas=self.node_count
+            replicas=self.node_count,
+            max_replicas=self.max_node_count
         )

+        # Build elastic policy
+        elastic_policy = None
+        if any([
+            self.elastic_replica_increment_step is not None,
+            self.max_node_count is not None,
+            self.elastic_graceful_shutdown_timeout_in_seconds is not None,
+            self.elastic_scaling_timeout_in_seconds is not None,
+            self.elastic_replica_discrete_values is not None
+        ]):
+            # Build base elastic policy kwargs
+            elastic_policy_kwargs = build_dict(
+                min_replicas=self.node_count,
+                max_replicas=self.max_node_count,
+                graceful_shutdown_timeout_in_seconds=self.elastic_graceful_shutdown_timeout_in_seconds,
+                scaling_timeout_in_seconds=self.elastic_scaling_timeout_in_seconds
+            )
+
+            if self.elastic_replica_discrete_values is not None:
+                elastic_policy_kwargs['replica_discrete_values'] = self.elastic_replica_discrete_values
+            elif self.elastic_replica_increment_step is not None:
+                elastic_policy_kwargs['replica_increment_step'] = self.elastic_replica_increment_step
+
+            elastic_policy = ElasticPolicy(**elastic_policy_kwargs)
+
+        # Build run policy
+        run_policy = None
+        if self.max_retry is not None or self.elastic_scale_up_snooze_time_in_seconds is not None:
+            from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import RestartPolicy
+
+            run_policy_kwargs = build_dict(
+                clean_pod_policy="None",
+                job_max_retry_count=self.max_retry
+            )
+
+            # Add a restart policy when elastic_scale_up_snooze_time_in_seconds is provided
+            if self.elastic_scale_up_snooze_time_in_seconds is not None:
+                restart_policy = RestartPolicy(
+                    eval_period_seconds=3600,
+                    scale_up_snooze_time_in_seconds=self.elastic_scale_up_snooze_time_in_seconds
+                )
+                run_policy_kwargs['restart_policy'] = restart_policy
+
+            run_policy = RunPolicy(**run_policy_kwargs)
+
         # Build job
         job_kwargs = build_dict(
             metadata=metadata_kwargs,
             replica_specs=[ReplicaSpec(**replica_kwargs)],
             nproc_per_node=str(self.tasks_per_node) if self.tasks_per_node else None,
-            run_policy=RunPolicy(clean_pod_policy="None", job_max_retry_count=self.max_retry) if self.max_retry else None
+            run_policy=run_policy,
+            elastic_policy=elastic_policy
         )

         result = HyperPodPytorchJob(**job_kwargs)
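
To make the wiring above concrete: assuming build_dict simply drops None-valued entries (as its usage here suggests), the sample config from the first sketch, plus max_retry=3 and elastic_scale_up_snooze_time_in_seconds=1800, would yield roughly the following keyword arguments.

# Rough trace of the kwargs built in to_domain for the sample config above.
elastic_policy_kwargs = {
    "min_replicas": 4,                              # from node_count
    "max_replicas": 16,                             # from max_node_count
    "graceful_shutdown_timeout_in_seconds": 120,
    "scaling_timeout_in_seconds": 600,
    "replica_increment_step": 4,                    # used only when discrete values are absent
}

run_policy_kwargs = {
    "clean_pod_policy": "None",
    "job_max_retry_count": 3,                       # from max_retry
    "restart_policy": {                             # built only when the snooze time is set;
        "eval_period_seconds": 3600,                # hard-coded in this change
        "scale_up_snooze_time_in_seconds": 1800,    # shown as a dict here; the code builds a RestartPolicy object
    },
}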