Skip to content

Commit 4b57ac2

Browse files
authored
fix(webhook): enabledReplicas hash key should be returned even if annotations is nil and add test (#139)
1 parent 1339592 commit 4b57ac2

File tree

3 files changed

+145
-7
lines changed

3 files changed

+145
-7
lines changed

internal/webhook/v1/pod_counter.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ func getOrGenerateKey(pod *corev1.Pod) string {
2626
}
2727
// Try to use pod-template-hash if present
2828
if hash, ok := pod.Labels["pod-template-hash"]; ok && hash != "" {
29-
return hash
29+
return fmt.Sprintf("%s/tf-counter-%s", constants.Domain, hash)
3030
}
3131

3232
// Fallback to object hash
33-
return utils.GetObjectHash(pod)
33+
return fmt.Sprintf("%s/tf-counter-%s", constants.Domain, utils.GetObjectHash(pod))
3434
}
3535

3636
// Get gets the counter value from the owner annotation by key
@@ -49,11 +49,11 @@ func (c *TensorFusionPodCounter) Get(ctx context.Context, pod *corev1.Pod) (int3
4949
}
5050
annotations := ownerObj.GetAnnotations()
5151
if annotations == nil {
52-
return 0, "", nil
52+
return 0, key, nil
5353
}
5454
val, ok := annotations[key]
5555
if !ok || val == "" {
56-
return 0, "", nil
56+
return 0, key, nil
5757
}
5858
count, err := strconv.ParseInt(val, 10, 32)
5959
if err != nil {

internal/webhook/v1/pod_webhook_test.go

Lines changed: 141 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package v1
1919
import (
2020
"context"
2121
"encoding/json"
22+
"fmt"
2223
"net/http"
2324

2425
tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
@@ -27,6 +28,7 @@ import (
2728
. "github.com/onsi/ginkgo/v2"
2829
. "github.com/onsi/gomega"
2930
admissionv1 "k8s.io/api/admission/v1"
31+
appsv1 "k8s.io/api/apps/v1"
3032
corev1 "k8s.io/api/core/v1"
3133
"k8s.io/apimachinery/pkg/api/resource"
3234
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -283,6 +285,141 @@ var _ = Describe("TensorFusionPodMutator", func() {
283285
})
284286
})
285287

288+
Context("Handle with EnabledReplicas", func() {
289+
It("should only patch enabledReplicas pods", func() {
290+
// Create a ReplicaSet as the owner for the pod
291+
replicaSet := &appsv1.ReplicaSet{
292+
ObjectMeta: metav1.ObjectMeta{
293+
Name: "test-rs",
294+
Namespace: "default",
295+
},
296+
Spec: appsv1.ReplicaSetSpec{
297+
Selector: &metav1.LabelSelector{
298+
MatchLabels: map[string]string{
299+
"app": "test-app",
300+
},
301+
},
302+
Template: corev1.PodTemplateSpec{
303+
ObjectMeta: metav1.ObjectMeta{
304+
Labels: map[string]string{
305+
"app": "test-app",
306+
},
307+
},
308+
Spec: corev1.PodSpec{
309+
Containers: []corev1.Container{
310+
{
311+
Name: "test-container",
312+
Image: "test-image",
313+
},
314+
},
315+
},
316+
},
317+
},
318+
}
319+
320+
Expect(k8sclient.Create(ctx, replicaSet)).To(Succeed())
321+
322+
// Get the ReplicaSet to obtain its UID
323+
createdReplicaSet := &appsv1.ReplicaSet{}
324+
Expect(k8sclient.Get(ctx, client.ObjectKey{Namespace: "default", Name: "test-rs"}, createdReplicaSet)).To(Succeed())
325+
replicaSetUID := createdReplicaSet.GetUID()
326+
327+
// Create a workload profile
328+
workloadProfile := &tfv1.WorkloadProfile{
329+
ObjectMeta: metav1.ObjectMeta{
330+
Name: "test-profile-enabled-replicas",
331+
Namespace: "default",
332+
},
333+
Spec: tfv1.WorkloadProfileSpec{
334+
PoolName: "mock",
335+
Resources: tfv1.Resources{
336+
Requests: tfv1.Resource{
337+
Tflops: resource.MustParse("10"),
338+
Vram: resource.MustParse("1Gi"),
339+
},
340+
Limits: tfv1.Resource{
341+
Tflops: resource.MustParse("100"),
342+
Vram: resource.MustParse("16Gi"),
343+
},
344+
},
345+
},
346+
}
347+
Expect(k8sclient.Create(ctx, workloadProfile)).To(Succeed())
348+
349+
// Create a pod with TF resources and owner reference
350+
trueVal := true
351+
enabledReplicas := int32(1)
352+
353+
pod := &corev1.Pod{
354+
ObjectMeta: metav1.ObjectMeta{
355+
Namespace: "default",
356+
GenerateName: "test-pod-enabled-replicas-",
357+
Labels: map[string]string{
358+
constants.TensorFusionEnabledLabelKey: "true",
359+
"pod-template-hash": "test-hash",
360+
},
361+
Annotations: map[string]string{
362+
constants.GpuPoolKey: "mock",
363+
constants.WorkloadProfileAnnotation: "test-profile-enabled-replicas",
364+
constants.InjectContainerAnnotation: "main",
365+
constants.WorkloadKey: "test-workload",
366+
constants.TensorFusionEnabledReplicasAnnotation: fmt.Sprintf("%d", enabledReplicas), // Using the correct constant
367+
},
368+
OwnerReferences: []metav1.OwnerReference{
369+
{
370+
APIVersion: "apps/v1",
371+
Kind: "ReplicaSet",
372+
Name: "test-rs",
373+
UID: replicaSetUID,
374+
Controller: &trueVal,
375+
},
376+
},
377+
},
378+
Spec: corev1.PodSpec{
379+
Containers: []corev1.Container{
380+
{
381+
Name: "main",
382+
Image: "test-image",
383+
},
384+
},
385+
},
386+
}
387+
388+
podBytes, err := json.Marshal(pod)
389+
Expect(err).NotTo(HaveOccurred())
390+
391+
req := admission.Request{
392+
AdmissionRequest: admissionv1.AdmissionRequest{
393+
Object: runtime.RawExtension{
394+
Raw: podBytes,
395+
},
396+
Operation: admissionv1.Create,
397+
},
398+
}
399+
400+
resp := mutator.Handle(ctx, req)
401+
// First call: Pod mutation should occur since enabledReplicas is 1,
402+
// so the response should be allowed and contain patches
403+
Expect(resp.Allowed).To(BeTrue())
404+
Expect(resp.Patches).NotTo(BeEmpty())
405+
406+
counter := &TensorFusionPodCounter{Client: k8sclient}
407+
count, _, err := counter.Get(ctx, pod)
408+
Expect(err).NotTo(HaveOccurred())
409+
Expect(count).To(Equal(int32(1)))
410+
411+
resp = mutator.Handle(ctx, req)
412+
// Second call: Pod should be ignored since it's been processed already,
413+
// so the response should be allowed but patches should be empty
414+
Expect(resp.Allowed).To(BeTrue())
415+
Expect(resp.Patches).To(BeEmpty())
416+
417+
// Clean up
418+
Expect(k8sclient.Delete(ctx, replicaSet)).To(Succeed())
419+
Expect(k8sclient.Delete(ctx, workloadProfile)).To(Succeed())
420+
})
421+
})
422+
286423
Context("ParseTensorFusionInfo", func() {
287424
It("should correctly parse TF requirements from pod annotations", func() {
288425
// Set up a workload profile for testing
@@ -316,8 +453,9 @@ var _ = Describe("TensorFusionPodMutator", func() {
316453
constants.WorkloadProfileAnnotation: "test-profile-parse-tf-resources",
317454
constants.WorkloadKey: "test-workload",
318455
// override tflops request
319-
constants.TFLOPSRequestAnnotation: "20",
320-
constants.InjectContainerAnnotation: "test-container",
456+
constants.TFLOPSRequestAnnotation: "20",
457+
constants.InjectContainerAnnotation: "test-container",
458+
constants.TensorFusionEnabledReplicasAnnotation: "3",
321459
},
322460
},
323461
Spec: corev1.PodSpec{
@@ -337,6 +475,7 @@ var _ = Describe("TensorFusionPodMutator", func() {
337475
Expect(tfInfo.Profile.Resources.Requests.Vram.String()).To(Equal("1Gi"))
338476
Expect(tfInfo.Profile.Resources.Limits.Tflops.String()).To(Equal("100"))
339477
Expect(tfInfo.Profile.Resources.Limits.Vram.String()).To(Equal("16Gi"))
478+
Expect(*tfInfo.EnabledReplicas).To(Equal(int32(3)))
340479
})
341480
})
342481

internal/webhook/v1/webhook_suite_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,6 @@ var _ = BeforeSuite(func() {
153153
if err != nil {
154154
return err
155155
}
156-
157156
return conn.Close()
158157
}).Should(Succeed())
159158
})

0 commit comments

Comments
 (0)