Skip to content

Commit 508a442

Browse files
committed
update training sdk config to combine config and status file into one unified_config file
1 parent dfe9192 commit 508a442

File tree

7 files changed

+176
-2994
lines changed

7 files changed

+176
-2994
lines changed

examples/training/SDK/training_sdk_example.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
"source": [
4141
"from sagemaker.hyperpod.training import (\n",
4242
" HyperPodPytorchJob,\n",
43-
" Container,\n",
43+
" Containers,\n",
4444
" ReplicaSpec,\n",
4545
" Resources,\n",
4646
" RunPolicy,\n",
@@ -57,7 +57,7 @@
5757
" template=Template(\n",
5858
" spec=Spec(\n",
5959
" containers=[\n",
60-
" Container(\n",
60+
" Containers(\n",
6161
" name=\"container-name\",\n",
6262
" image=\"448049793756.dkr.ecr.us-west-2.amazonaws.com/ptjob:mnist\",\n",
6363
" image_pull_policy=\"Always\",\n",

hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_0/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pydantic import BaseModel, ConfigDict, Field
22
from typing import Optional, List, Dict, Union
3-
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_config import (
4-
Container,
3+
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import (
4+
Containers,
55
ReplicaSpec,
66
Resources,
77
RunPolicy,
@@ -103,7 +103,7 @@ def to_domain(self) -> Dict:
103103
]
104104

105105
# Create container object
106-
container = Container(**container_kwargs)
106+
container = Containers(**container_kwargs)
107107

108108
# Create pod spec kwargs
109109
spec_kwargs = {"containers": list([container])}

src/sagemaker/hyperpod/training/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_config import *
2-
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_status import (
3-
HyperPodPytorchJobStatus,
4-
)
1+
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import *
52
from sagemaker.hyperpod.training.hyperpod_pytorch_job import (
63
HyperPodPytorchJob,
74
_load_hp_job,

src/sagemaker/hyperpod/training/config/hyperpod_pytorch_job_config.py

Lines changed: 0 additions & 2977 deletions
This file was deleted.

src/sagemaker/hyperpod/training/config/hyperpod_pytorch_job_status.py renamed to src/sagemaker/hyperpod/training/config/hyperpod_pytorch_job_unified_config.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1610,6 +1610,23 @@ class LabelSelector(BaseModel):
16101610
)
16111611

16121612

1613+
class NamespaceSelector(BaseModel):
1614+
"""A label query over the set of namespaces that the term applies to. The term is applied to the union of the namespaces selected by this field and the ones listed in the namespaces field. null selector and null or empty namespaces list means "this pod's namespace". An empty selector ({}) matches all namespaces."""
1615+
1616+
model_config = ConfigDict(extra="forbid")
1617+
1618+
matchExpressions: Optional[List[MatchExpressions]] = Field(
1619+
default=None,
1620+
alias="match_expressions",
1621+
description="matchExpressions is a list of label selector requirements. The requirements are ANDed.",
1622+
)
1623+
matchLabels: Optional[Dict[str, str]] = Field(
1624+
default=None,
1625+
alias="match_labels",
1626+
description='matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels map is equivalent to an element of matchExpressions, whose key field is "key", the operator is "In", and the values array contains only "value". The requirements are ANDed.',
1627+
)
1628+
1629+
16131630
class TopologySpreadConstraints(BaseModel):
16141631
"""TopologySpreadConstraint specifies how to spread matching pods among the given topology."""
16151632

@@ -2955,6 +2972,134 @@ class Template(BaseModel):
29552972
)
29562973

29572974

2975+
class ReplicaSpec(BaseModel):
2976+
"""ReplicaSpec is a description of the replica"""
2977+
2978+
model_config = ConfigDict(extra="forbid")
2979+
2980+
name: str = Field(description="The name for the replica set")
2981+
replicas: Optional[int] = Field(
2982+
default=1,
2983+
description="Replicas is the desired number of replicas of the given template.",
2984+
)
2985+
spares: Optional[int] = Field(
2986+
default=0,
2987+
description="Spares requests spare resources from Kueue. E.g. If a job is configured with 4 replicas and 2 spares, job requests resources required to run 6 pods such as cpu, gpu",
2988+
)
2989+
template: Optional[Template] = Field(
2990+
default=None,
2991+
description="Template is the object that describes the pod that will be created for this replica.",
2992+
)
2993+
2994+
2995+
class LogMonitoringConfiguration(BaseModel):
2996+
"""LogMonitoringRule defines the criteria used to detect a SLOW or HANGING job"""
2997+
2998+
model_config = ConfigDict(extra="forbid")
2999+
3000+
expectedRecurringFrequencyInSeconds: Optional[int] = Field(
3001+
default=None,
3002+
alias="expected_recurring_frequency_in_seconds",
3003+
description="Time interval between two subsequent matches for LogPattern beyond which, the rule evaluates to HANGING. When not specified, there is no constraint on duration between two subsequent matches for LogPattern.",
3004+
)
3005+
expectedStartCutOffInSeconds: Optional[int] = Field(
3006+
default=None,
3007+
alias="expected_start_cut_off_in_seconds",
3008+
description="Time to first match for LogPattern beyond which, the rule evaluates to HANGING. When not specified, there is no constraint on time to first match for LogPattern.",
3009+
)
3010+
logPattern: str = Field(
3011+
alias="log_pattern",
3012+
description="Regex to identify log lines to apply the rule to when the rule is active. This regex can optionally include one capturing group to extract a numeric metric value.",
3013+
)
3014+
metricEvaluationDataPoints: Optional[int] = Field(
3015+
default=None,
3016+
alias="metric_evaluation_data_points",
3017+
description="The number of consecutive times that a rule must evaluate to SLOW in order to mark a job as SLOW. When not specified, the default value is 1.",
3018+
)
3019+
metricThreshold: Optional[int] = Field(
3020+
default=None,
3021+
alias="metric_threshold",
3022+
description="Threshold for value extracted by LogPattern if it has a capturing group. When not specified, Metric evaluation will not be performed.",
3023+
)
3024+
name: str = Field(description="Name of the rule")
3025+
operator: Optional[str] = Field(
3026+
default=None,
3027+
description="Operator to compare the value extracted by LogPattern to MetricThreshold. Rule evaluates to SLOW if value extracted by LogPattern compared to MetricThreshold using Operator evaluates to true. When not specified, Metric evaluation will not be performed. Following operator values are supported: gt (greater than) lt (lesser than) eq (equal to) gteq (greater than or equal to) lteq (less than or equal to)",
3028+
)
3029+
stopPattern: Optional[str] = Field(
3030+
default=None,
3031+
alias="stop_pattern",
3032+
description="Regex to identify the log line at which to deactivate the rule. When not specified, the rule will always be active.",
3033+
)
3034+
3035+
3036+
class RestartPolicy(BaseModel):
3037+
"""Additional restart limiting configurations"""
3038+
3039+
model_config = ConfigDict(extra="forbid")
3040+
3041+
evalPeriodSeconds: int = Field(
3042+
alias="eval_period_seconds",
3043+
description="The period of evaluating the restart limit in seconds",
3044+
)
3045+
maxFullJobRestarts: Optional[int] = Field(
3046+
default=None,
3047+
alias="max_full_job_restarts",
3048+
description="The max allowed number of full job restarts before failing the job",
3049+
)
3050+
numRestartBeforeFullJobRestart: Optional[int] = Field(
3051+
default=None,
3052+
alias="num_restart_before_full_job_restart",
3053+
description="The number of standard restarts before a full job restart",
3054+
)
3055+
3056+
3057+
class RunPolicy(BaseModel):
3058+
"""RunPolicy"""
3059+
3060+
model_config = ConfigDict(extra="forbid")
3061+
3062+
activeDeadlineSeconds: Optional[int] = Field(
3063+
default=None,
3064+
alias="active_deadline_seconds",
3065+
description="Specifies the duration in seconds relative to the startTime that the job may be active before the system tries to terminate it; value must be positive integer.",
3066+
)
3067+
cleanPodPolicy: Optional[str] = Field(
3068+
default="All",
3069+
alias="clean_pod_policy",
3070+
description="CleanPodPolicy defines the policy to kill pods after the job completes.",
3071+
)
3072+
faultDeadlineSeconds: Optional[int] = Field(
3073+
default=None,
3074+
alias="fault_deadline_seconds",
3075+
description="The limit on the fault time for the job (Status of Fault) before failing",
3076+
)
3077+
jobMaxRetryCount: Optional[int] = Field(default=None, alias="job_max_retry_count")
3078+
logMonitoringConfiguration: Optional[List[LogMonitoringConfiguration]] = Field(
3079+
default=None,
3080+
alias="log_monitoring_configuration",
3081+
description="LogMonitoringConfiguration defines the log monitoring rules for SLOW and HANGING job detection",
3082+
)
3083+
restartPolicy: Optional[RestartPolicy] = Field(
3084+
default=None,
3085+
alias="restart_policy",
3086+
description="Additional restart limiting configurations",
3087+
)
3088+
startupDeadlineSeconds: Optional[int] = Field(
3089+
default=None,
3090+
alias="startup_deadline_seconds",
3091+
description="The limit on the startup time for the job (Status of Staring) before failing",
3092+
)
3093+
suspend: Optional[bool] = Field(
3094+
default=None, description="Suspend suspends HyperPodPytorchJob when set to true"
3095+
)
3096+
ttlSecondsAfterFinished: Optional[int] = Field(
3097+
default=0,
3098+
alias="ttl_seconds_after_finished",
3099+
description="TTLSecondsAfterFinished is the TTL to clean up jobs. Set to -1 for infinite",
3100+
)
3101+
3102+
29583103
class PodSets(BaseModel):
29593104
model_config = ConfigDict(extra="forbid")
29603105

@@ -3081,3 +3226,23 @@ class HyperPodPytorchJobStatus(BaseModel):
30813226
alias="start_time",
30823227
description="The time when job is first acknowledged by the controller. When using kueue, the job is also admitted It is represented in RFC3339 form and is in UTC.",
30833228
)
3229+
3230+
3231+
class _HyperPodPytorchJob(BaseModel):
3232+
"""Config defines the desired state of HyperPodPytorchJob"""
3233+
3234+
model_config = ConfigDict(extra="ignore")
3235+
3236+
nprocPerNode: str = Field(
3237+
default="auto",
3238+
alias="nproc_per_node",
3239+
description="Number of workers per node; supported values: [auto, cpu, gpu, int].",
3240+
)
3241+
replicaSpecs: Optional[List[ReplicaSpec]] = Field(
3242+
default=None,
3243+
alias="replica_specs",
3244+
description="The replicas to include as part of the job",
3245+
)
3246+
runPolicy: Optional[RunPolicy] = Field(
3247+
default=None, alias="run_policy", description="RunPolicy"
3248+
)

src/sagemaker/hyperpod/training/hyperpod_pytorch_job.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
from pydantic import ConfigDict, Field
2-
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_config import (
3-
_HyperPodPytorchJob,
4-
)
5-
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_status import (
6-
HyperPodPytorchJobStatus,
2+
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import (
3+
_HyperPodPytorchJob, HyperPodPytorchJobStatus
74
)
85
from sagemaker.hyperpod.common.config.metadata import Metadata
96
from kubernetes import client, config

test/unit_tests/training/test_hyperpod_pytorch_job.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sagemaker.hyperpod.training import (
66
HyperPodPytorchJob,
77
HyperPodPytorchJobStatus,
8-
Container,
8+
Containers,
99
ReplicaSpec,
1010
Resources,
1111
RunPolicy,
@@ -27,7 +27,7 @@ def setUp(self):
2727
template=Template(
2828
spec=Spec(
2929
containers=[
30-
Container(
30+
Containers(
3131
name="test-container",
3232
image="test-image",
3333
resources=Resources(
@@ -137,7 +137,7 @@ def test_get_success(self, mock_load_job, mock_custom_api, mock_verify_config):
137137
template=Template(
138138
spec=Spec(
139139
containers=[
140-
Container(
140+
Containers(
141141
name="test-container",
142142
image="test-image",
143143
resources=Resources(

0 commit comments

Comments
 (0)