@@ -69,13 +69,13 @@ func (m *TensorFusionPodMutator) Handle(ctx context.Context, req admission.Reque
6969 log := log .FromContext (ctx )
7070 log .Info ("Mutating pod" , "generateName" , pod .GenerateName , "namespace" , pod .Namespace )
7171
72- reqs := ParseTFReq (pod )
73- if len (reqs ) == 0 {
72+ resources := ParseTFResources (pod )
73+ if len (resources ) == 0 {
7474 return admission .Allowed ("no tensor fusion requirements found" )
7575 }
7676
7777 // 1. Inject initContainer and env variables
78- patches , err := m .patchTFClient (pod , reqs )
78+ patches , err := m .patchTFClient (pod , resources )
7979 if err != nil {
8080 return admission .Errored (http .StatusInternalServerError , err )
8181 }
@@ -89,36 +89,43 @@ func (m *TensorFusionPodMutator) InjectDecoder(d admission.Decoder) error {
8989 return nil
9090}
9191
92- type TFReq struct {
// TFResource holds the Tensor Fusion settings that ParseTFResources collects
// for one container of a pod.
//
// ContainerName is the container's name. ConnectionName and
// ConnectionNamespace come from the container's environment variables
// (NOTE(review): the env lookup is only partly visible here; confirm which
// variables are read). TflopsRequest, VramRequest, TflopsLimit and VramLimit
// are quantities parsed from per-container pod annotations; a field stays at
// its zero value when its annotation is absent or cannot be parsed.
92+ type TFResource struct {
9393 ContainerName string
9494 ConnectionName string
9595 ConnectionNamespace string
96- Tflops resource.Quantity
97- Vram resource.Quantity
96+ TflopsRequest resource.Quantity
97+ VramRequest resource.Quantity
98+ TflopsLimit resource.Quantity
99+ VramLimit resource.Quantity
98100}
99101
100- func ParseTFReq (pod * corev1.Pod ) []TFReq {
102+ func ParseTFResources (pod * corev1.Pod ) []TFResource {
101103 if pod .Annotations == nil {
102104 return nil
103105 }
104106
105- reqs := make ([]TFReq , 0 , len (pod .Spec .Containers ))
107+ reqs := make ([]TFResource , 0 , len (pod .Spec .Containers ))
106108
107109 for _ , container := range pod .Spec .Containers {
108110 containerName := container .Name
109111
110112 // Check if TF requirements exist for this container
111- tflopsKey := fmt .Sprintf (constants .TFLOPSContainerAnnotationFormat , containerName )
112- vramKey := fmt .Sprintf (constants .VRAMContainerAnnotationFormat , containerName )
113+ tflopsReqKey := fmt .Sprintf (constants .TFLOPSRequestAnnotationFormat , containerName )
114+ vramReqKey := fmt .Sprintf (constants .VRAMRequestAnnotationFormat , containerName )
115+ tflopsLimitKey := fmt .Sprintf (constants .TFLOPSLimitAnnotationFormat , containerName )
116+ vramLimitKey := fmt .Sprintf (constants .VRAMLimitAnnotationFormat , containerName )
113117
114- tflopsStr , hasTflops := pod .Annotations [tflopsKey ]
115- vramStr , hasVram := pod .Annotations [vramKey ]
118+ tflopsReqStr , hasTflopsReq := pod .Annotations [tflopsReqKey ]
119+ vramReqStr , hasVramReq := pod .Annotations [vramReqKey ]
116120
117- if ! hasTflops && ! hasVram {
121+ tflopsLimitStr , hasTflopsLimit := pod .Annotations [tflopsLimitKey ]
122+ vramLimitStr , hasVramLimit := pod .Annotations [vramLimitKey ]
123+
124+ if ! hasTflopsReq && ! hasVramReq && ! hasTflopsLimit && ! hasVramLimit {
118125 continue
119126 }
120127
121- req := TFReq {
128+ req := TFResource {
122129 ContainerName : containerName ,
123130 }
124131 connectionNameEnv , ok := lo .Find (container .Env , func (e corev1.EnvVar ) bool {
@@ -133,19 +140,35 @@ func ParseTFReq(pod *corev1.Pod) []TFReq {
133140 if ok {
134141 req .ConnectionNamespace = connectionNamespaceEnv .Value
135142 }
136- // Parse TFLOPS requirement
137- if hasTflops {
138- tflops , err := resource .ParseQuantity (tflopsStr )
143+ // Parse TFLOPS request
144+ if hasTflopsReq {
145+ tflops , err := resource .ParseQuantity (tflopsReqStr )
146+ if err == nil {
147+ req .TflopsRequest = tflops
148+ }
149+ }
150+
151+ // Parse VRAM request
152+ if hasVramReq {
153+ vram , err := resource .ParseQuantity (vramReqStr )
154+ if err == nil {
155+ req .VramRequest = vram
156+ }
157+ }
158+
159+ // Parse TFLOPS limit
160+ if hasTflopsLimit {
161+ tflops , err := resource .ParseQuantity (tflopsLimitStr )
139162 if err == nil {
140- req .Tflops = tflops
163+ req .TflopsLimit = tflops
141164 }
142165 }
143166
144- // Parse VRAM requirement
145- if hasVram {
146- vram , err := resource .ParseQuantity (vramStr )
167+ // Parse VRAM limit
168+ if hasVramLimit {
169+ vram , err := resource .ParseQuantity (vramLimitStr )
147170 if err == nil {
148- req .Vram = vram
171+ req .VramLimit = vram
149172 }
150173 }
151174
@@ -155,7 +178,7 @@ func ParseTFReq(pod *corev1.Pod) []TFReq {
155178 return reqs
156179}
157180
158- func (m * TensorFusionPodMutator ) patchTFClient (pod * corev1.Pod , tfReq []TFReq ) ([]jsonpatch.JsonPatchOperation , error ) {
181+ func (m * TensorFusionPodMutator ) patchTFClient (pod * corev1.Pod , tfReq []TFResource ) ([]jsonpatch.JsonPatchOperation , error ) {
159182 // Convert the current pod to JSON
160183 currentBytes , err := json .Marshal (pod )
161184 if err != nil {
0 commit comments