 from pydantic import ConfigDict, Field
+
+from sagemaker.hyperpod.cli.constants.command_constants import INSTANCE_TYPE_LABEL, NVIDIA_GPU_RESOURCE_LIMIT_KEY, \
+    NEURON_RESOURCE_LIMIT_KEY
 from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import (
     _HyperPodPytorchJob, HyperPodPytorchJobStatus
 )
 import yaml
 import logging
 
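+# Quota-allocation helpers used below to validate requested vCPU/memory/accelerator values and
+# to derive per-container Kubernetes requests/limits for the selected instance type.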
+from hyperpod_pytorch_job_template.quota_allocation_util import _is_valid, _get_resources_from_compute_quotas, _get_resources_from_instance, _get_limits
+
+
 
 TRAINING_GROUP = "sagemaker.amazonaws.com"
 API_VERSION = "v1"
@@ -52,6 +58,109 @@ def verify_kube_config(cls):
 
         # Verify Kubernetes version compatibility
         verify_kubernetes_version_compatibility(cls.get_logger())
+    @classmethod
+    def _extract_numeric_value(cls, value):
+        """Extract numeric value from strings like '1.5Gi' -> 1.5"""
+        if not value:
+            return None
+        import re
+        match = re.match(r'^([0-9]*\.?[0-9]+)', str(value))
+        return float(match.group(1)) if match else None
+
70+ @classmethod
71+ def sanitize_memory (cls , resource ):
72+ try :
73+ if 'memory' in resource :
74+ memory = resource ['memory' ]
75+ # Case when quotas have been already initialized in CLI layer
76+ # ToDo : Cleanup quota initialization in CLI layer and directly use SDK layer for init.
77+ memory .replace ('GiGi' , 'Gi' )
78+ resource ['memory' ] = memory
79+ return resource
80+ except Exception as e :
81+ return resource
82+
+
+    @classmethod
+    def _process_replica_resources(cls, data):
+        """Derive and validate per-container resource requests/limits for a replica spec (mutates data in place)."""
+        try:
+            node_count = data.get('replicas', None)
+
+            # Extract nested configuration with validation
+            template = data.get('template', {})
+            spec = template.get('spec', {})
+            node_selector = spec.get('nodeSelector', {})
+            instance_type = node_selector.get(INSTANCE_TYPE_LABEL) if node_selector else None
+
+            if not instance_type:
+                return None
+
+            containers = spec.get('containers', [])
+
+            if not containers:
+                raise ValueError("No containers found in template spec")
+
+            container = containers[0]
+            resources = container.get('resources', {})
+            requests = resources.get('requests', {})
+            limits = resources.get('limits', {})
+
+            # Extract requested resource values (missing keys fall back to None)
+            vcpu = float(requests.get('cpu')) if requests.get('cpu') else None
+            memory = cls._extract_numeric_value(requests.get('memory'))
+            accelerators = (int(requests.get(NVIDIA_GPU_RESOURCE_LIMIT_KEY, 0))
+                            or int(requests.get(NEURON_RESOURCE_LIMIT_KEY, 0))
+                            or None)
+            memory_limit = cls._extract_numeric_value(limits.get('memory'))
+            vcpu_limit = float(limits.get('cpu')) if limits.get('cpu') else None
+            accelerators_limit = (int(limits.get(NVIDIA_GPU_RESOURCE_LIMIT_KEY, 0))
+                                  or int(limits.get(NEURON_RESOURCE_LIMIT_KEY, 0))
+                                  or None)
+
+            # Validate configuration
+            valid, error = _is_valid(vcpu, memory, accelerators, node_count, instance_type)
+            if not valid:
+                raise ValueError(error)
+
+            # Calculate requests from explicit compute quotas if any were given,
+            # otherwise derive them from the instance type and node count
+            requests_value = (_get_resources_from_compute_quotas(instance_type, vcpu, memory, accelerators)
+                              or _get_resources_from_instance(instance_type, node_count))
+            limits_value = _get_limits(instance_type, vcpu_limit, memory_limit, accelerators_limit)
+
+            requests_value = cls.sanitize_memory(requests_value)
+            limits_value = cls.sanitize_memory(limits_value)
+
+            # Update data with the calculated values
+            data['template']['spec']['containers'][0]['resources']['requests'] = requests_value
+            data['template']['spec']['containers'][0]['resources']['limits'] = limits_value
+            return data
+        except KeyError as e:
+            raise ValueError(f"Missing required configuration key: {str(e)}")
+
+    @classmethod
+    def _get_container_resources(cls, replica_spec):
+        """Extract container resources from replica spec."""
+        container_resources = replica_spec['template']['spec']['containers'][0]['resources']
+        return container_resources['requests'], container_resources['limits']
+
+    @classmethod
+    def allocate_quotas_if_applicable(cls, spec):
+        """Rewrite the first replica spec's container requests/limits from its instance type, when one is set."""
+        logger = cls.get_logger()
+        logger = setup_logging(logger)
+        try:
+            spec_dict = spec.model_dump()
+            replica_spec = spec_dict['replicaSpecs'][0]
+            cls._process_replica_resources(replica_spec)
+
+            # Update the original spec object directly
+            requests, limits = cls._get_container_resources(replica_spec)
+            spec.replicaSpecs[0].template.spec.containers[0].resources.requests = requests
+            spec.replicaSpecs[0].template.spec.containers[0].resources.limits = limits
+
+            return spec
+        except ValueError as e:
+            logger.error(f"Error in quota allocation: {e}")
+            raise ValueError(e)
+        except Exception as e:
+            logger.info(f"Warning: quota allocation skipped ({e}); using default resources.")
+            return spec
 
     @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_pytorchjob")
     def create(self, debug=False):
@@ -65,6 +174,10 @@ def create(self, debug=False):
         if not self.metadata.namespace:
             self.metadata.namespace = get_default_namespace()
 
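+        # Fill in container requests/limits from the pinned instance type (if the replica spec
+        # sets one via nodeSelector); otherwise the spec is passed through unchanged.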
+        spec = self.allocate_quotas_if_applicable(spec)
+        if spec.replicaSpecs[0].replicas == 0:
+            spec.replicaSpecs[0].replicas = 1  # default value
+
         config = {
             "apiVersion": f"{TRAINING_GROUP}/{API_VERSION}",
             "kind": KIND,
@@ -91,6 +204,8 @@ def create(self, debug=False):
             logger.error(f"Failed to create HyperPodPytorchJob {self.metadata.name}!")
             handle_exception(e, self.metadata.name, self.metadata.namespace)
 
+
+
     @classmethod
     @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pytorchjobs")
     def list(cls, namespace=None) -> List["HyperPodPytorchJob"]:
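A minimal sketch of the replica-spec dict shape the new quota-allocation path consumes. The instance type, container name, and resource values are illustrative only, the import path for HyperPodPytorchJob is assumed, and the call succeeds only if _is_valid accepts the combination; the label and resource-key constants are the ones this change imports.

from sagemaker.hyperpod.cli.constants.command_constants import INSTANCE_TYPE_LABEL, NVIDIA_GPU_RESOURCE_LIMIT_KEY
from sagemaker.hyperpod.training import HyperPodPytorchJob  # assumed import path for the class edited here

# Shape mirrors what spec.model_dump()['replicaSpecs'][0] provides and what
# _process_replica_resources reads: replicas, nodeSelector, and the first container's resources.
replica_spec = {
    "replicas": 2,
    "template": {
        "spec": {
            "nodeSelector": {INSTANCE_TYPE_LABEL: "ml.g5.8xlarge"},  # illustrative instance type
            "containers": [{
                "name": "pytorch",
                "resources": {
                    "requests": {"cpu": "4", "memory": "16Gi", NVIDIA_GPU_RESOURCE_LIMIT_KEY: "1"},
                    "limits": {"memory": "32Gi"},
                },
            }],
        },
    },
}

# Validates (vcpu=4.0, memory=16.0, accelerators=1) against the instance type via _is_valid and,
# if valid, overwrites requests/limits in place with the values returned by
# _get_resources_from_compute_quotas / _get_limits (falling back to _get_resources_from_instance).
HyperPodPytorchJob._process_replica_resources(replica_spec)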