@@ -19,6 +19,7 @@ package v1
1919import (
2020 "context"
2121 "encoding/json"
22+ "fmt"
2223 "net/http"
2324
2425 tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
@@ -27,6 +28,7 @@ import (
2728 . "github.com/onsi/ginkgo/v2"
2829 . "github.com/onsi/gomega"
2930 admissionv1 "k8s.io/api/admission/v1"
31+ appsv1 "k8s.io/api/apps/v1"
3032 corev1 "k8s.io/api/core/v1"
3133 "k8s.io/apimachinery/pkg/api/resource"
3234 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -283,6 +285,141 @@ var _ = Describe("TensorFusionPodMutator", func() {
283285 })
284286 })
285287
288+ Context ("Handle with EnabledReplicas" , func () {
289+ It ("should only patch enabledReplicas pods" , func () {
290+ // Create a ReplicaSet as the owner for the pod
291+ replicaSet := & appsv1.ReplicaSet {
292+ ObjectMeta : metav1.ObjectMeta {
293+ Name : "test-rs" ,
294+ Namespace : "default" ,
295+ },
296+ Spec : appsv1.ReplicaSetSpec {
297+ Selector : & metav1.LabelSelector {
298+ MatchLabels : map [string ]string {
299+ "app" : "test-app" ,
300+ },
301+ },
302+ Template : corev1.PodTemplateSpec {
303+ ObjectMeta : metav1.ObjectMeta {
304+ Labels : map [string ]string {
305+ "app" : "test-app" ,
306+ },
307+ },
308+ Spec : corev1.PodSpec {
309+ Containers : []corev1.Container {
310+ {
311+ Name : "test-container" ,
312+ Image : "test-image" ,
313+ },
314+ },
315+ },
316+ },
317+ },
318+ }
319+
320+ Expect (k8sclient .Create (ctx , replicaSet )).To (Succeed ())
321+
322+ // Get the ReplicaSet to obtain its UID
323+ createdReplicaSet := & appsv1.ReplicaSet {}
324+ Expect (k8sclient .Get (ctx , client.ObjectKey {Namespace : "default" , Name : "test-rs" }, createdReplicaSet )).To (Succeed ())
325+ replicaSetUID := createdReplicaSet .GetUID ()
326+
327+ // Create a workload profile
328+ workloadProfile := & tfv1.WorkloadProfile {
329+ ObjectMeta : metav1.ObjectMeta {
330+ Name : "test-profile-enabled-replicas" ,
331+ Namespace : "default" ,
332+ },
333+ Spec : tfv1.WorkloadProfileSpec {
334+ PoolName : "mock" ,
335+ Resources : tfv1.Resources {
336+ Requests : tfv1.Resource {
337+ Tflops : resource .MustParse ("10" ),
338+ Vram : resource .MustParse ("1Gi" ),
339+ },
340+ Limits : tfv1.Resource {
341+ Tflops : resource .MustParse ("100" ),
342+ Vram : resource .MustParse ("16Gi" ),
343+ },
344+ },
345+ },
346+ }
347+ Expect (k8sclient .Create (ctx , workloadProfile )).To (Succeed ())
348+
349+ // Create a pod with TF resources and owner reference
350+ trueVal := true
351+ enabledReplicas := int32 (1 )
352+
353+ pod := & corev1.Pod {
354+ ObjectMeta : metav1.ObjectMeta {
355+ Namespace : "default" ,
356+ GenerateName : "test-pod-enabled-replicas-" ,
357+ Labels : map [string ]string {
358+ constants .TensorFusionEnabledLabelKey : "true" ,
359+ "pod-template-hash" : "test-hash" ,
360+ },
361+ Annotations : map [string ]string {
362+ constants .GpuPoolKey : "mock" ,
363+ constants .WorkloadProfileAnnotation : "test-profile-enabled-replicas" ,
364+ constants .InjectContainerAnnotation : "main" ,
365+ constants .WorkloadKey : "test-workload" ,
366+ constants .TensorFusionEnabledReplicasAnnotation : fmt .Sprintf ("%d" , enabledReplicas ), // Using the correct constant
367+ },
368+ OwnerReferences : []metav1.OwnerReference {
369+ {
370+ APIVersion : "apps/v1" ,
371+ Kind : "ReplicaSet" ,
372+ Name : "test-rs" ,
373+ UID : replicaSetUID ,
374+ Controller : & trueVal ,
375+ },
376+ },
377+ },
378+ Spec : corev1.PodSpec {
379+ Containers : []corev1.Container {
380+ {
381+ Name : "main" ,
382+ Image : "test-image" ,
383+ },
384+ },
385+ },
386+ }
387+
388+ podBytes , err := json .Marshal (pod )
389+ Expect (err ).NotTo (HaveOccurred ())
390+
391+ req := admission.Request {
392+ AdmissionRequest : admissionv1.AdmissionRequest {
393+ Object : runtime.RawExtension {
394+ Raw : podBytes ,
395+ },
396+ Operation : admissionv1 .Create ,
397+ },
398+ }
399+
400+ resp := mutator .Handle (ctx , req )
401+ // First call: Pod mutation should occur since enabledReplicas is 1,
402+ // so the response should be allowed and contain patches
403+ Expect (resp .Allowed ).To (BeTrue ())
404+ Expect (resp .Patches ).NotTo (BeEmpty ())
405+
406+ counter := & TensorFusionPodCounter {Client : k8sclient }
407+ count , _ , err := counter .Get (ctx , pod )
408+ Expect (err ).NotTo (HaveOccurred ())
409+ Expect (count ).To (Equal (int32 (1 )))
410+
411+ resp = mutator .Handle (ctx , req )
412+ // Second call: Pod should be ignored since it's been processed already,
413+ // so the response should be allowed but patches should be empty
414+ Expect (resp .Allowed ).To (BeTrue ())
415+ Expect (resp .Patches ).To (BeEmpty ())
416+
417+ // Clean up
418+ Expect (k8sclient .Delete (ctx , replicaSet )).To (Succeed ())
419+ Expect (k8sclient .Delete (ctx , workloadProfile )).To (Succeed ())
420+ })
421+ })
422+
286423 Context ("ParseTensorFusionInfo" , func () {
287424 It ("should correctly parse TF requirements from pod annotations" , func () {
288425 // Set up a workload profile for testing
@@ -316,8 +453,9 @@ var _ = Describe("TensorFusionPodMutator", func() {
316453 constants .WorkloadProfileAnnotation : "test-profile-parse-tf-resources" ,
317454 constants .WorkloadKey : "test-workload" ,
318455 // override tflops request
319- constants .TFLOPSRequestAnnotation : "20" ,
320- constants .InjectContainerAnnotation : "test-container" ,
456+ constants .TFLOPSRequestAnnotation : "20" ,
457+ constants .InjectContainerAnnotation : "test-container" ,
458+ constants .TensorFusionEnabledReplicasAnnotation : "3" ,
321459 },
322460 },
323461 Spec : corev1.PodSpec {
@@ -337,6 +475,7 @@ var _ = Describe("TensorFusionPodMutator", func() {
337475 Expect (tfInfo .Profile .Resources .Requests .Vram .String ()).To (Equal ("1Gi" ))
338476 Expect (tfInfo .Profile .Resources .Limits .Tflops .String ()).To (Equal ("100" ))
339477 Expect (tfInfo .Profile .Resources .Limits .Vram .String ()).To (Equal ("16Gi" ))
478+ Expect (* tfInfo .EnabledReplicas ).To (Equal (int32 (3 )))
340479 })
341480 })
342481
0 commit comments