from pydantic import ConfigDict, Field
+
+ from sagemaker.hyperpod.cli.constants.command_constants import INSTANCE_TYPE_LABEL, NVIDIA_GPU_RESOURCE_LIMIT_KEY, \
+     NEURON_RESOURCE_LIMIT_KEY
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import (
    _HyperPodPytorchJob, HyperPodPytorchJobStatus
)
import yaml
import logging

+ from hyperpod_pytorch_job_template.quota_allocation_util import _is_valid, _get_resources_from_compute_quotas, _get_resources_from_instance, _get_limits
+
+

TRAINING_GROUP = "sagemaker.amazonaws.com"
API_VERSION = "v1"
@@ -52,6 +58,88 @@ def verify_kube_config(cls):

        # Verify Kubernetes version compatibility
        verify_kubernetes_version_compatibility(cls.get_logger())
+     @classmethod
+     def sanitize_memory(cls, resource):
+         """Collapse a duplicated 'GiGi' memory suffix to 'Gi' and return the resource dict."""
+         if resource and 'memory' in resource:
+             memory = resource['memory']
+             # Case when quotas have already been initialized in the CLI layer.
+             # ToDo: Clean up quota initialization in the CLI layer and directly use the SDK layer for init.
+             resource['memory'] = memory.replace('GiGi', 'Gi')
+         return resource
+
+     @classmethod
+     def _process_replica_resources(cls, data):
+         """Process and validate replica resource configuration."""
+         try:
+             node_count = data['replicas']
+
+             # Extract nested configuration with validation
+             template = data.get('template', {})
+             spec = template.get('spec', {})
+             node_selector = spec.get('nodeSelector', {})
+             containers = spec.get('containers', [])
+
+             if not containers:
+                 raise ValueError("No containers found in template spec")
+
+             instance_type = node_selector.get(INSTANCE_TYPE_LABEL, None)
+             if not instance_type:
+                 raise ValueError("Instance type not found in node selector")
+
+             container = containers[0]
+             resources = container.get('resources', {})
+             requests = resources.get('requests', {})
+             limits = resources.get('limits', {})
+
+             # Extract resource values
+             vcpu = requests.get('vcpu', None)
+             memory = requests.get('memory', None)
+             accelerators = requests.get(NVIDIA_GPU_RESOURCE_LIMIT_KEY) or requests.get(NEURON_RESOURCE_LIMIT_KEY) or None
+             memory_limit = limits.get('memory', None)
+             vcpu_limit = limits.get('vcpu', None)
+             accelerators_limit = limits.get(NVIDIA_GPU_RESOURCE_LIMIT_KEY) or limits.get(NEURON_RESOURCE_LIMIT_KEY) or None
+
+             # Validate configuration
+             valid, error = _is_valid(vcpu, memory, accelerators, node_count, instance_type)
+             if not valid:
+                 raise ValueError(error)
+
+             # Calculate resource values
+             requests_value = (_get_resources_from_compute_quotas(instance_type, vcpu, memory, accelerators)
+                               or _get_resources_from_instance(instance_type, node_count))
+             limits_value = _get_limits(instance_type, vcpu_limit, memory_limit, accelerators_limit)
+             requests_value = cls.sanitize_memory(requests_value)
+             limits_value = cls.sanitize_memory(limits_value)
+
+             # Update data with calculated values
+             data['template']['spec']['containers'][0]['resources']['requests'] = requests_value
+             data['template']['spec']['containers'][0]['resources']['limits'] = limits_value
+             return data
+         except KeyError as e:
+             raise ValueError(f"Missing required configuration key: {str(e)}")
+     @classmethod
+     def _get_container_resources(cls, replica_spec):
+         """Extract container resources from replica spec."""
+         container_resources = replica_spec['template']['spec']['containers'][0]['resources']
+         return container_resources['requests'], container_resources['limits']
+
+     @classmethod
+     def allocate_quotas_if_applicable(cls, spec):
+         """Fill in container resource requests/limits from quotas; fall back to the unmodified spec on error."""
+         try:
+             spec_dict = spec.model_dump()
+             replica_spec = spec_dict['replicaSpecs'][0]
+             cls._process_replica_resources(replica_spec)
+
+             # Update the original spec object directly
+             requests, limits = cls._get_container_resources(replica_spec)
+             spec.replicaSpecs[0].template.spec.containers[0].resources.requests = requests
+             spec.replicaSpecs[0].template.spec.containers[0].resources.limits = limits
+
+             return spec
+         except Exception as e:
+             print(f"Warning: quota allocation failed: {e}. Using default resources.")
+             return spec

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_pytorchjob")
    def create(self, debug=False):
@@ -65,6 +153,10 @@ def create(self, debug=False):
        if not self.metadata.namespace:
            self.metadata.namespace = get_default_namespace()

+         spec = self.allocate_quotas_if_applicable(spec)
+         if spec.replicaSpecs[0].replicas == 0:
+             spec.replicaSpecs[0].replicas = 1  # default value
+
        config = {
            "apiVersion": f"{TRAINING_GROUP}/{API_VERSION}",
            "kind": KIND,
@@ -91,6 +183,8 @@ def create(self, debug=False):
            logger.error(f"Failed to create HyperPodPytorchJob {self.metadata.name}!")
            handle_exception(e, self.metadata.name, self.metadata.namespace)

+
+
    @classmethod
    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pytorchjobs")
    def list(cls, namespace=None) -> List["HyperPodPytorchJob"]:
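
For reference, a minimal sketch (not part of the commit) of the replica-spec dict shape that _process_replica_resources walks; the instance type and resource quantities are illustrative assumptions, and the requests/limits written back depend on what _get_resources_from_compute_quotas, _get_resources_from_instance, and _get_limits return for that instance type.

# Illustrative input only -- the instance type and quantities below are assumptions, not SDK defaults.
from sagemaker.hyperpod.cli.constants.command_constants import INSTANCE_TYPE_LABEL, NVIDIA_GPU_RESOURCE_LIMIT_KEY

replica_spec = {
    "replicas": 2,
    "template": {
        "spec": {
            "nodeSelector": {INSTANCE_TYPE_LABEL: "ml.g5.8xlarge"},
            "containers": [{
                "resources": {
                    "requests": {"vcpu": 4, "memory": "16Gi", NVIDIA_GPU_RESOURCE_LIMIT_KEY: 1},
                    "limits": {"memory": "32Gi"},
                },
            }],
        }
    },
}

# _process_replica_resources(replica_spec) validates the requested vcpu/memory/accelerators
# against the instance type, then overwrites the container's 'requests' and 'limits' with the
# computed values (raising ValueError if the combination is invalid); allocate_quotas_if_applicable
# wraps this and returns the original spec unchanged if anything goes wrong.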