     HostPath,
     PersistentVolumeClaim
 )
+from sagemaker.hyperpod.training.hyperpod_pytorch_job import HyperPodPytorchJob


 class VolumeConfig(BaseModel):
@@ -228,133 +229,82 @@ def validate_label_selector_keys(cls, v):
         return v

-    def to_domain(self) -> Dict:
+    def to_domain(self) -> HyperPodPytorchJob:
231- """
232- Convert flat config to domain model (HyperPodPytorchJobSpec)
233- """
232+ """Convert flat config to domain model (HyperPodPytorchJobSpec)"""
234233
-        # Create container with required fields
-        container_kwargs = {
-            "name": "pytorch-job-container",
-            "image": self.image,
-            "resources": Resources(
-                requests={"nvidia.com/gpu": "0"},
-                limits={"nvidia.com/gpu": "0"},
-            ),
-        }
-
-        # Add optional container fields
-        if self.command is not None:
-            container_kwargs["command"] = self.command
-        if self.args is not None:
-            container_kwargs["args"] = self.args
-        if self.pull_policy is not None:
-            container_kwargs["image_pull_policy"] = self.pull_policy
-        if self.environment is not None:
-            container_kwargs["env"] = [
-                {"name": k, "value": v} for k, v in self.environment.items()
-            ]
-
-        if self.volume is not None:
-            volume_mounts = []
-            for i, vol in enumerate(self.volume):
-                volume_mount = {"name": vol.name, "mount_path": vol.mount_path}
-                volume_mounts.append(volume_mount)
-
-            container_kwargs["volume_mounts"] = volume_mounts
-
-
-        # Create container object
-        try:
-            container = Containers(**container_kwargs)
-        except Exception as e:
-            raise
-
-        # Create pod spec kwargs
-        spec_kwargs = {"containers": list([container])}
+        # Helper function to build dict with non-None values
+        def build_dict(**kwargs):
+            return {k: v for k, v in kwargs.items() if v is not None}
+
+        # Build container
+        container_kwargs = build_dict(
+            name="pytorch-job-container",
+            image=self.image,
+            resources=Resources(requests={"nvidia.com/gpu": "0"}, limits={"nvidia.com/gpu": "0"}),
+            command=self.command,
+            args=self.args,
+            image_pull_policy=self.pull_policy,
+            env=[{"name": k, "value": v} for k, v in self.environment.items()] if self.environment else None,
+            volume_mounts=[{"name": vol.name, "mount_path": vol.mount_path} for vol in self.volume] if self.volume else None
+        )
+
+        container = Containers(**container_kwargs)

-        # Add volumes to pod spec if present
-        if self.volume is not None:
+        # Build volumes
+        volumes = None
+        if self.volume:
             volumes = []
-            for i, vol in enumerate(self.volume):
+            for vol in self.volume:
                 if vol.type == "hostPath":
-                    host_path = HostPath(path=vol.path)
-                    volume_obj = Volumes(name=vol.name, host_path=host_path)
+                    volume_obj = Volumes(name=vol.name, host_path=HostPath(path=vol.path))
                 elif vol.type == "pvc":
-                    pvc_config = PersistentVolumeClaim(
-                        claim_name=vol.claim_name,
-                        read_only=vol.read_only == "true" if vol.read_only else False
-                    )
-                    volume_obj = Volumes(name=vol.name, persistent_volume_claim=pvc_config)
+                    volume_obj = Volumes(name=vol.name, persistent_volume_claim=PersistentVolumeClaim(
+                        claim_name=vol.claim_name,
+                        read_only=vol.read_only == "true" if vol.read_only else False
+                    ))
                 volumes.append(volume_obj)
-
-            spec_kwargs["volumes"] = volumes
-
-        # Add node selector if any selector fields are present
-        node_selector = {}
-        if self.instance_type is not None:
-            map = {"node.kubernetes.io/instance-type": self.instance_type}
-            node_selector.update(map)
-        if self.label_selector is not None:
-            node_selector.update(self.label_selector)
-        if self.deep_health_check_passed_nodes_only:
-            map = {"deep-health-check-passed": "true"}
-            node_selector.update(map)
-        if node_selector:
-            spec_kwargs.update({"node_selector": node_selector})
-
-        # Add other optional pod spec fields
-        if self.service_account_name is not None:
-            map = {"service_account_name": self.service_account_name}
-            spec_kwargs.update(map)
-
-        if self.scheduler_type is not None:
-            map = {"scheduler_name": self.scheduler_type}
-            spec_kwargs.update(map)
-
-        # Build metadata labels only if relevant fields are present
-        metadata_kwargs = {"name": self.job_name}
-        if self.namespace is not None:
-            metadata_kwargs["namespace"] = self.namespace
-
-        metadata_labels = {}
-        if self.queue_name is not None:
-            metadata_labels["kueue.x-k8s.io/queue-name"] = self.queue_name
-        if self.priority is not None:
-            metadata_labels["kueue.x-k8s.io/priority-class"] = self.priority
-
-        if metadata_labels:
-            metadata_kwargs["labels"] = metadata_labels

-        # Create replica spec with only non-None values
-        replica_kwargs = {
-            "name": "pod",
-            "template": Template(
-                metadata=Metadata(**metadata_kwargs), spec=Spec(**spec_kwargs)
-            ),
-        }
+        # Build node selector
+        node_selector = build_dict(
+            **{"node.kubernetes.io/instance-type": self.instance_type} if self.instance_type else {},
+            **self.label_selector if self.label_selector else {},
+            **{"deep-health-check-passed": "true"} if self.deep_health_check_passed_nodes_only else {}
+        )

-        if self.node_count is not None:
-            replica_kwargs["replicas"] = self.node_count
+        # Build spec
+        spec_kwargs = build_dict(
+            containers=[container],
+            volumes=volumes,
+            node_selector=node_selector if node_selector else None,
+            service_account_name=self.service_account_name,
+            scheduler_name=self.scheduler_type
+        )

-        replica_spec = ReplicaSpec(**replica_kwargs)
+        # Build metadata
+        metadata_labels = build_dict(
+            **{"kueue.x-k8s.io/queue-name": self.queue_name} if self.queue_name else {},
+            **{"kueue.x-k8s.io/priority-class": self.priority} if self.priority else {}
+        )

-        replica_specs = list([replica_spec])
+        metadata_kwargs = build_dict(
+            name=self.job_name,
+            namespace=self.namespace,
+            labels=metadata_labels if metadata_labels else None
+        )

-        job_kwargs = {"replica_specs": replica_specs}
-        # Add optional fields only if they exist
-        if self.tasks_per_node is not None:
-            job_kwargs["nproc_per_node"] = str(self.tasks_per_node)
+        # Build replica spec
+        replica_kwargs = build_dict(
+            name="pod",
+            template=Template(metadata=Metadata(**metadata_kwargs), spec=Spec(**spec_kwargs)),
+            replicas=self.node_count
+        )

-        if self.max_retry is not None:
-            job_kwargs["run_policy"] = RunPolicy(
-                clean_pod_policy="None", job_max_retry_count=self.max_retry
-            )
+        # Build job
+        job_kwargs = build_dict(
+            metadata=metadata_kwargs,
+            replica_specs=[ReplicaSpec(**replica_kwargs)],
+            nproc_per_node=str(self.tasks_per_node) if self.tasks_per_node else None,
+            run_policy=RunPolicy(clean_pod_policy="None", job_max_retry_count=self.max_retry) if self.max_retry else None
+        )

-        # Create base return dictionary
-        result = {
-            "name": self.job_name,
-            "namespace": self.namespace,
-            "labels": metadata_labels,
-            "spec": job_kwargs,
-        }
+        result = HyperPodPytorchJob(**job_kwargs)
         return result
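
The pivot of this refactor is the `build_dict` helper: it drops every `None`-valued keyword argument, so optional config fields never reach the domain-model constructors. A minimal standalone sketch of the pattern (the example values are illustrative, not from the PR):

def build_dict(**kwargs):
    # Keep only the keyword arguments that were actually provided.
    return {k: v for k, v in kwargs.items() if v is not None}

assert build_dict(name="pod", replicas=None) == {"name": "pod"}
# Only None is filtered out; falsy-but-meaningful values like 0 or "" survive.
assert build_dict(a=0, b="", c=None) == {"a": 0, "b": ""}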
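The node-selector and label blocks combine `build_dict` with conditional dict unpacking: each `**{...} if cond else {}` term contributes its keys only when the condition holds. A sketch of that idiom in isolation (the selector values below are made-up examples):

node_selector = dict(
    **({"node.kubernetes.io/instance-type": "ml.p5.48xlarge"} if True else {}),  # instance type was set
    **({"deep-health-check-passed": "true"} if False else {}),  # flag was not set
)
assert node_selector == {"node.kubernetes.io/instance-type": "ml.p5.48xlarge"}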