1- from pydantic import BaseModel , ConfigDict , Field
2- from typing import Optional , List , Dict , Union
1+ from pydantic import BaseModel , ConfigDict , Field , field_validator , model_validator
2+ from typing import Optional , List , Dict , Union , Literal
33from sagemaker .hyperpod .training .config .hyperpod_pytorch_job_unified_config import (
44 Containers ,
55 ReplicaSpec ,
88 Spec ,
99 Template ,
1010 Metadata ,
11+ Volumes ,
12+ HostPath ,
13+ PersistentVolumeClaim
1114)
1215
1316
17+ class VolumeConfig (BaseModel ):
18+ name : str = Field (..., description = "Volume name" )
19+ type : Literal ['hostPath' , 'pvc' ] = Field (..., description = "Volume type" )
20+ mount_path : str = Field (..., description = "Mount path in container" )
21+ path : Optional [str ] = Field (None , description = "Host path (required for hostPath volumes)" )
22+ claim_name : Optional [str ] = Field (None , description = "PVC claim name (required for pvc volumes)" )
23+ read_only : Optional [Literal ['true' , 'false' ]] = Field (None , description = "Read-only flag for pvc volumes" )
24+
25+ @field_validator ('mount_path' , 'path' )
26+ @classmethod
27+ def paths_must_be_absolute (cls , v ):
28+ """Validate that paths are absolute (start with /)."""
29+ if v and not v .startswith ('/' ):
30+ raise ValueError ('Path must be absolute (start with /)' )
31+ return v
32+
33+ @model_validator (mode = 'after' )
34+ def validate_type_specific_fields (self ):
35+ """Validate that required fields are present based on volume type."""
36+
37+ if self .type == 'hostPath' :
38+ if not self .path :
39+ raise ValueError ('hostPath volumes require path field' )
40+ elif self .type == 'pvc' :
41+ if not self .claim_name :
42+ raise ValueError ('PVC volumes require claim_name field' )
43+
44+ return self
45+
46+
1447class PyTorchJobConfig (BaseModel ):
1548 model_config = ConfigDict (extra = "forbid" )
1649
@@ -60,22 +93,41 @@ class PyTorchJobConfig(BaseModel):
6093 max_retry : Optional [int ] = Field (
6194 default = None , alias = "max_retry" , description = "Maximum number of job retries"
6295 )
63- volumes : Optional [List [str ]] = Field (
64- default = None , description = "List of volumes to mount"
65- )
66- persistent_volume_claims : Optional [ List [ str ]] = Field (
67- default = None ,
68- alias = "persistent_volume_claims" ,
69- description = "List of persistent volume claims" ,
96+ volume : Optional [List [VolumeConfig ]] = Field (
97+ default = None , description = "List of volume configurations. \
98+ Command structure: --volume name=<volume_name>,type=<volume_type>,mount_path=<mount_path>,<type-specific options> \
99+ For hostPath: --volume name=model-data,type=hostPath,mount_path=/data,path=/data \
100+ For persistentVolumeClaim: --volume name=training-output,type=pvc,mount_path=/mnt/output,claim_name=training-output-pvc,read_only=false \
101+ If multiple --volume flag if multiple volumes are needed \
102+ "
70103 )
71104 service_account_name : Optional [str ] = Field (
72105 default = None , alias = "service_account_name" , description = "Service account name"
73106 )
74107
108+ @field_validator ('volume' )
109+ def validate_no_duplicates (cls , v ):
110+ """Validate no duplicate volume names or mount paths."""
111+ if not v :
112+ return v
113+
114+ # Check for duplicate volume names
115+ names = [vol .name for vol in v ]
116+ if len (names ) != len (set (names )):
117+ raise ValueError ("Duplicate volume names found" )
118+
119+ # Check for duplicate mount paths
120+ mount_paths = [vol .mount_path for vol in v ]
121+ if len (mount_paths ) != len (set (mount_paths )):
122+ raise ValueError ("Duplicate mount paths found" )
123+
124+ return v
125+
75126 def to_domain (self ) -> Dict :
76127 """
77128 Convert flat config to domain model (HyperPodPytorchJobSpec)
78129 """
130+
79131 # Create container with required fields
80132 container_kwargs = {
81133 "name" : "container-name" ,
@@ -97,17 +149,42 @@ def to_domain(self) -> Dict:
97149 container_kwargs ["env" ] = [
98150 {"name" : k , "value" : v } for k , v in self .environment .items ()
99151 ]
100- if self .volumes is not None :
101- container_kwargs ["volume_mounts" ] = [
102- {"name" : v , "mount_path" : f"/mnt/{ v } " } for v in self .volumes
103- ]
152+
153+ if self .volume is not None :
154+ volume_mounts = []
155+ for i , vol in enumerate (self .volume ):
156+ volume_mount = {"name" : vol .name , "mount_path" : vol .mount_path }
157+ volume_mounts .append (volume_mount )
158+
159+ container_kwargs ["volume_mounts" ] = volume_mounts
160+
104161
105162 # Create container object
106- container = Containers (** container_kwargs )
163+ try :
164+ container = Containers (** container_kwargs )
165+ except Exception as e :
166+ raise
107167
108168 # Create pod spec kwargs
109169 spec_kwargs = {"containers" : list ([container ])}
110170
171+ # Add volumes to pod spec if present
172+ if self .volume is not None :
173+ volumes = []
174+ for i , vol in enumerate (self .volume ):
175+ if vol .type == "hostPath" :
176+ host_path = HostPath (path = vol .path )
177+ volume_obj = Volumes (name = vol .name , host_path = host_path )
178+ elif vol .type == "pvc" :
179+ pvc_config = PersistentVolumeClaim (
180+ claim_name = vol .claim_name ,
181+ read_only = vol .read_only == "true" if vol .read_only else False
182+ )
183+ volume_obj = Volumes (name = vol .name , persistent_volume_claim = pvc_config )
184+ volumes .append (volume_obj )
185+
186+ spec_kwargs ["volumes" ] = volumes
187+
111188 # Add node selector if any selector fields are present
112189 node_selector = {}
113190 if self .instance_type is not None :
@@ -175,5 +252,4 @@ def to_domain(self) -> Dict:
175252 "namespace" : self .namespace ,
176253 "spec" : job_kwargs ,
177254 }
178-
179255 return result
0 commit comments