Skip to content

Commit 9d7acfd

Browse files
committed
Add recipe validation
Tested by running hyp recipe-init and hyp recipe-validate on files in recipe collection
1 parent 66e1f74 commit 9d7acfd

File tree

11 files changed

+909
-125
lines changed

11 files changed

+909
-125
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ __pycache__/
1818
/doc/_apidoc/
1919
doc/_build/
2020
/build
21+
/hyperpod-pytorch-job-template/build
2122

2223
/sagemaker-hyperpod/build
2324
/sagemaker-hyperpod/.coverage

hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_0/model.py

Lines changed: 90 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -9,75 +9,60 @@
99
Template,
1010
Metadata,
1111
Volumes,
12-
HostPath,
13-
PersistentVolumeClaim
12+
HostPath,
13+
PersistentVolumeClaim,
1414
)
1515

1616

1717
class VolumeConfig(BaseModel):
18-
name: str = Field(
19-
...,
20-
description="Volume name",
21-
min_length=1
22-
)
23-
type: Literal['hostPath', 'pvc'] = Field(..., description="Volume type")
24-
mount_path: str = Field(
25-
...,
26-
description="Mount path in container",
27-
min_length=1
28-
)
18+
name: str = Field(..., description="Volume name", min_length=1)
19+
type: Literal["hostPath", "pvc"] = Field(..., description="Volume type")
20+
mount_path: str = Field(..., description="Mount path in container", min_length=1)
2921
path: Optional[str] = Field(
30-
None,
31-
description="Host path (required for hostPath volumes)",
32-
min_length=1
22+
None, description="Host path (required for hostPath volumes)", min_length=1
3323
)
3424
claim_name: Optional[str] = Field(
35-
None,
36-
description="PVC claim name (required for pvc volumes)",
37-
min_length=1
25+
None, description="PVC claim name (required for pvc volumes)", min_length=1
26+
)
27+
read_only: Optional[Literal["true", "false"]] = Field(
28+
None, description="Read-only flag for pvc volumes"
3829
)
39-
read_only: Optional[Literal['true', 'false']] = Field(None, description="Read-only flag for pvc volumes")
40-
41-
@field_validator('mount_path', 'path')
30+
31+
@field_validator("mount_path", "path")
4232
@classmethod
4333
def paths_must_be_absolute(cls, v):
4434
"""Validate that paths are absolute (start with /)."""
45-
if v and not v.startswith('/'):
46-
raise ValueError('Path must be absolute (start with /)')
35+
if v and not v.startswith("/"):
36+
raise ValueError("Path must be absolute (start with /)")
4737
return v
48-
49-
@model_validator(mode='after')
38+
39+
@model_validator(mode="after")
5040
def validate_type_specific_fields(self):
5141
"""Validate that required fields are present based on volume type."""
52-
53-
if self.type == 'hostPath':
42+
43+
if self.type == "hostPath":
5444
if not self.path:
55-
raise ValueError('hostPath volumes require path field')
56-
elif self.type == 'pvc':
45+
raise ValueError("hostPath volumes require path field")
46+
elif self.type == "pvc":
5747
if not self.claim_name:
58-
raise ValueError('PVC volumes require claim_name field')
59-
48+
raise ValueError("PVC volumes require claim_name field")
49+
6050
return self
6151

6252

6353
class PyTorchJobConfig(BaseModel):
6454
model_config = ConfigDict(extra="forbid")
6555

6656
job_name: str = Field(
67-
alias="job_name",
57+
alias="job_name",
6858
description="Job name",
6959
min_length=1,
7060
max_length=63,
71-
pattern=r'^[a-z0-9]([-a-z0-9]*[a-z0-9])?$'
72-
)
73-
image: str = Field(
74-
description="Docker image for training",
75-
min_length=1
61+
pattern=r"^[a-z0-9]([-a-z0-9]*[a-z0-9])?$",
7662
)
63+
image: str = Field(description="Docker image for training", min_length=1)
7764
namespace: Optional[str] = Field(
78-
default=None,
79-
description="Kubernetes namespace",
80-
min_length=1
65+
default=None, description="Kubernetes namespace", min_length=1
8166
)
8267
command: Optional[List[str]] = Field(
8368
default=None, description="Command to run in the container"
@@ -89,28 +74,22 @@ class PyTorchJobConfig(BaseModel):
8974
default=None, description="Environment variables as key_value pairs"
9075
)
9176
pull_policy: Optional[str] = Field(
92-
default=None,
93-
alias="pull_policy",
94-
description="Image pull policy",
95-
min_length=1
77+
default=None, alias="pull_policy", description="Image pull policy", min_length=1
9678
)
9779
instance_type: Optional[str] = Field(
98-
default=None,
99-
alias="instance_type",
80+
default=None,
81+
alias="instance_type",
10082
description="Instance type for training",
101-
min_length=1
83+
min_length=1,
10284
)
10385
node_count: Optional[int] = Field(
104-
default=None,
105-
alias="node_count",
106-
description="Number of nodes",
107-
ge=1
86+
default=None, alias="node_count", description="Number of nodes", ge=1
10887
)
10988
tasks_per_node: Optional[int] = Field(
110-
default=None,
111-
alias="tasks_per_node",
89+
default=None,
90+
alias="tasks_per_node",
11291
description="Number of tasks per node",
113-
ge=1
92+
ge=1,
11493
)
11594
label_selector: Optional[Dict[str, str]] = Field(
11695
default=None,
@@ -123,114 +102,122 @@ class PyTorchJobConfig(BaseModel):
123102
description="Schedule pods only on nodes that passed deep health check",
124103
)
125104
scheduler_type: Optional[str] = Field(
126-
default=None,
127-
alias="scheduler_type",
128-
description="Scheduler type",
129-
min_length=1
105+
default=None, alias="scheduler_type", description="Scheduler type", min_length=1
130106
)
131107
queue_name: Optional[str] = Field(
132-
default=None,
133-
alias="queue_name",
108+
default=None,
109+
alias="queue_name",
134110
description="Queue name for job scheduling",
135111
min_length=1,
136112
max_length=63,
137-
pattern=r'^[a-z0-9]([-a-z0-9]*[a-z0-9])?$'
113+
pattern=r"^[a-z0-9]([-a-z0-9]*[a-z0-9])?$",
138114
)
139115
priority: Optional[str] = Field(
140-
default=None,
141-
description="Priority class for job scheduling",
142-
min_length=1
116+
default=None, description="Priority class for job scheduling", min_length=1
143117
)
144118
max_retry: Optional[int] = Field(
145-
default=None,
146-
alias="max_retry",
119+
default=None,
120+
alias="max_retry",
147121
description="Maximum number of job retries",
148-
ge=0
122+
ge=0,
149123
)
150124
volume: Optional[List[VolumeConfig]] = Field(
151-
default=None, description="List of volume configurations. \
125+
default=None,
126+
description="List of volume configurations. \
152127
Command structure: --volume name=<volume_name>,type=<volume_type>,mount_path=<mount_path>,<type-specific options> \
153128
For hostPath: --volume name=model-data,type=hostPath,mount_path=/data,path=/data \
154129
For persistentVolumeClaim: --volume name=training-output,type=pvc,mount_path=/mnt/output,claim_name=training-output-pvc,read_only=false \
155130
If multiple --volume flag if multiple volumes are needed \
156-
"
131+
",
157132
)
158133
service_account_name: Optional[str] = Field(
159-
default=None,
160-
alias="service_account_name",
134+
default=None,
135+
alias="service_account_name",
161136
description="Service account name",
162-
min_length=1
137+
min_length=1,
163138
)
164139

165-
@field_validator('volume')
140+
@field_validator("volume")
166141
def validate_no_duplicates(cls, v):
167142
"""Validate no duplicate volume names or mount paths."""
168143
if not v:
169144
return v
170-
145+
171146
# Check for duplicate volume names
172147
names = [vol.name for vol in v]
173148
if len(names) != len(set(names)):
174149
raise ValueError("Duplicate volume names found")
175-
150+
176151
# Check for duplicate mount paths
177152
mount_paths = [vol.mount_path for vol in v]
178153
if len(mount_paths) != len(set(mount_paths)):
179154
raise ValueError("Duplicate mount paths found")
180-
155+
181156
return v
182157

183-
@field_validator('command', 'args')
158+
@field_validator("command", "args")
184159
def validate_string_lists(cls, v):
185160
"""Validate that command and args contain non-empty strings."""
186161
if not v:
187162
return v
188-
163+
189164
for i, item in enumerate(v):
190165
if not isinstance(item, str) or not item.strip():
191-
field_name = cls.model_fields.get('command', {}).get('alias', 'command') if 'command' in str(v) else 'args'
166+
field_name = (
167+
cls.model_fields.get("command", {}).get("alias", "command")
168+
if "command" in str(v)
169+
else "args"
170+
)
192171
raise ValueError(f"{field_name}[{i}] must be a non-empty string")
193-
172+
194173
return v
195174

196-
@field_validator('environment')
175+
@field_validator("environment")
197176
def validate_environment_variable_names(cls, v):
198177
"""Validate environment variable names follow C_IDENTIFIER pattern."""
199178
if not v:
200179
return v
201-
180+
202181
import re
203-
c_identifier_pattern = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
204-
182+
183+
c_identifier_pattern = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
184+
205185
for key in v.keys():
206186
if not c_identifier_pattern.match(key):
207-
raise ValueError(f"Environment variable name '{key}' must be a valid C_IDENTIFIER")
208-
187+
raise ValueError(
188+
f"Environment variable name '{key}' must be a valid C_IDENTIFIER"
189+
)
190+
209191
return v
210192

211-
@field_validator('label_selector')
193+
@field_validator("label_selector")
212194
def validate_label_selector_keys(cls, v):
213195
"""Validate label selector keys follow Kubernetes label naming conventions."""
214196
if not v:
215197
return v
216-
198+
217199
import re
200+
218201
# Kubernetes label key pattern - allows namespaced labels like kubernetes.io/arch
219202
# Pattern: [prefix/]name where prefix and name follow DNS subdomain rules
220203
# Also reject double dots
221-
label_key_pattern = re.compile(r'^([a-zA-Z0-9]([a-zA-Z0-9\-_.]*[a-zA-Z0-9])?/)?[a-zA-Z0-9]([a-zA-Z0-9\-_.]*[a-zA-Z0-9])?$')
222-
204+
label_key_pattern = re.compile(
205+
r"^([a-zA-Z0-9]([a-zA-Z0-9\-_.]*[a-zA-Z0-9])?/)?[a-zA-Z0-9]([a-zA-Z0-9\-_.]*[a-zA-Z0-9])?$"
206+
)
207+
223208
for key in v.keys():
224-
if not key or not label_key_pattern.match(key) or '..' in key:
225-
raise ValueError(f"Label selector key '{key}' must follow Kubernetes label naming conventions")
226-
209+
if not key or not label_key_pattern.match(key) or ".." in key:
210+
raise ValueError(
211+
f"Label selector key '{key}' must follow Kubernetes label naming conventions"
212+
)
213+
227214
return v
228215

229216
def to_domain(self) -> Dict:
230217
"""
231218
Convert flat config to domain model (HyperPodPytorchJobSpec)
232219
"""
233-
220+
234221
# Create container with required fields
235222
container_kwargs = {
236223
"name": "container-name",
@@ -258,9 +245,8 @@ def to_domain(self) -> Dict:
258245
for i, vol in enumerate(self.volume):
259246
volume_mount = {"name": vol.name, "mount_path": vol.mount_path}
260247
volume_mounts.append(volume_mount)
261-
262-
container_kwargs["volume_mounts"] = volume_mounts
263248

249+
container_kwargs["volume_mounts"] = volume_mounts
264250

265251
# Create container object
266252
try:
@@ -280,14 +266,16 @@ def to_domain(self) -> Dict:
280266
volume_obj = Volumes(name=vol.name, host_path=host_path)
281267
elif vol.type == "pvc":
282268
pvc_config = PersistentVolumeClaim(
283-
claim_name=vol.claim_name,
284-
read_only=vol.read_only == "true" if vol.read_only else False
269+
claim_name=vol.claim_name,
270+
read_only=vol.read_only == "true" if vol.read_only else False,
271+
)
272+
volume_obj = Volumes(
273+
name=vol.name, persistent_volume_claim=pvc_config
285274
)
286-
volume_obj = Volumes(name=vol.name, persistent_volume_claim=pvc_config)
287275
volumes.append(volume_obj)
288-
276+
289277
spec_kwargs["volumes"] = volumes
290-
278+
291279
# Add node selector if any selector fields are present
292280
node_selector = {}
293281
if self.instance_type is not None:

hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_0/recipes/hf/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)