Skip to content

Commit ec8800d

Browse files
Update volume flag to support hostPath and pvc (#171)
* update help text to avoid truncation * update volume flag to support hostPath and pvc, before e2e testing * clean up and e2e working * Minor updates after PR * update * Added unit tests for volume, all cli unit tests passed
1 parent 9f534b4 commit ec8800d

File tree

4 files changed

+765
-176
lines changed

4 files changed

+765
-176
lines changed

hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_0/model.py

Lines changed: 91 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from pydantic import BaseModel, ConfigDict, Field
2-
from typing import Optional, List, Dict, Union
1+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
2+
from typing import Optional, List, Dict, Union, Literal
33
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import (
44
Containers,
55
ReplicaSpec,
@@ -8,9 +8,42 @@
88
Spec,
99
Template,
1010
Metadata,
11+
Volumes,
12+
HostPath,
13+
PersistentVolumeClaim
1114
)
1215

1316

17+
class VolumeConfig(BaseModel):
18+
name: str = Field(..., description="Volume name")
19+
type: Literal['hostPath', 'pvc'] = Field(..., description="Volume type")
20+
mount_path: str = Field(..., description="Mount path in container")
21+
path: Optional[str] = Field(None, description="Host path (required for hostPath volumes)")
22+
claim_name: Optional[str] = Field(None, description="PVC claim name (required for pvc volumes)")
23+
read_only: Optional[Literal['true', 'false']] = Field(None, description="Read-only flag for pvc volumes")
24+
25+
@field_validator('mount_path', 'path')
26+
@classmethod
27+
def paths_must_be_absolute(cls, v):
28+
"""Validate that paths are absolute (start with /)."""
29+
if v and not v.startswith('/'):
30+
raise ValueError('Path must be absolute (start with /)')
31+
return v
32+
33+
@model_validator(mode='after')
34+
def validate_type_specific_fields(self):
35+
"""Validate that required fields are present based on volume type."""
36+
37+
if self.type == 'hostPath':
38+
if not self.path:
39+
raise ValueError('hostPath volumes require path field')
40+
elif self.type == 'pvc':
41+
if not self.claim_name:
42+
raise ValueError('PVC volumes require claim_name field')
43+
44+
return self
45+
46+
1447
class PyTorchJobConfig(BaseModel):
1548
model_config = ConfigDict(extra="forbid")
1649

@@ -60,22 +93,41 @@ class PyTorchJobConfig(BaseModel):
6093
max_retry: Optional[int] = Field(
6194
default=None, alias="max_retry", description="Maximum number of job retries"
6295
)
63-
volumes: Optional[List[str]] = Field(
64-
default=None, description="List of volumes to mount"
65-
)
66-
persistent_volume_claims: Optional[List[str]] = Field(
67-
default=None,
68-
alias="persistent_volume_claims",
69-
description="List of persistent volume claims",
96+
volume: Optional[List[VolumeConfig]] = Field(
97+
default=None, description="List of volume configurations. \
98+
Command structure: --volume name=<volume_name>,type=<volume_type>,mount_path=<mount_path>,<type-specific options> \
99+
For hostPath: --volume name=model-data,type=hostPath,mount_path=/data,path=/data \
100+
For persistentVolumeClaim: --volume name=training-output,type=pvc,mount_path=/mnt/output,claim_name=training-output-pvc,read_only=false \
101+
If multiple --volume flag if multiple volumes are needed \
102+
"
70103
)
71104
service_account_name: Optional[str] = Field(
72105
default=None, alias="service_account_name", description="Service account name"
73106
)
74107

108+
@field_validator('volume')
109+
def validate_no_duplicates(cls, v):
110+
"""Validate no duplicate volume names or mount paths."""
111+
if not v:
112+
return v
113+
114+
# Check for duplicate volume names
115+
names = [vol.name for vol in v]
116+
if len(names) != len(set(names)):
117+
raise ValueError("Duplicate volume names found")
118+
119+
# Check for duplicate mount paths
120+
mount_paths = [vol.mount_path for vol in v]
121+
if len(mount_paths) != len(set(mount_paths)):
122+
raise ValueError("Duplicate mount paths found")
123+
124+
return v
125+
75126
def to_domain(self) -> Dict:
76127
"""
77128
Convert flat config to domain model (HyperPodPytorchJobSpec)
78129
"""
130+
79131
# Create container with required fields
80132
container_kwargs = {
81133
"name": "container-name",
@@ -97,17 +149,42 @@ def to_domain(self) -> Dict:
97149
container_kwargs["env"] = [
98150
{"name": k, "value": v} for k, v in self.environment.items()
99151
]
100-
if self.volumes is not None:
101-
container_kwargs["volume_mounts"] = [
102-
{"name": v, "mount_path": f"/mnt/{v}"} for v in self.volumes
103-
]
152+
153+
if self.volume is not None:
154+
volume_mounts = []
155+
for i, vol in enumerate(self.volume):
156+
volume_mount = {"name": vol.name, "mount_path": vol.mount_path}
157+
volume_mounts.append(volume_mount)
158+
159+
container_kwargs["volume_mounts"] = volume_mounts
160+
104161

105162
# Create container object
106-
container = Containers(**container_kwargs)
163+
try:
164+
container = Containers(**container_kwargs)
165+
except Exception as e:
166+
raise
107167

108168
# Create pod spec kwargs
109169
spec_kwargs = {"containers": list([container])}
110170

171+
# Add volumes to pod spec if present
172+
if self.volume is not None:
173+
volumes = []
174+
for i, vol in enumerate(self.volume):
175+
if vol.type == "hostPath":
176+
host_path = HostPath(path=vol.path)
177+
volume_obj = Volumes(name=vol.name, host_path=host_path)
178+
elif vol.type == "pvc":
179+
pvc_config = PersistentVolumeClaim(
180+
claim_name=vol.claim_name,
181+
read_only=vol.read_only == "true" if vol.read_only else False
182+
)
183+
volume_obj = Volumes(name=vol.name, persistent_volume_claim=pvc_config)
184+
volumes.append(volume_obj)
185+
186+
spec_kwargs["volumes"] = volumes
187+
111188
# Add node selector if any selector fields are present
112189
node_selector = {}
113190
if self.instance_type is not None:
@@ -175,5 +252,4 @@ def to_domain(self) -> Dict:
175252
"namespace": self.namespace,
176253
"spec": job_kwargs,
177254
}
178-
179255
return result

0 commit comments

Comments
 (0)