Skip to content

Commit 02f6fcb

Browse files
authored
feat: support canary (gray) TensorFusion-enabled Pods (#133)
* feat: support canary (gray) TensorFusion-enabled Pods * fix comments * fix: delete annotation when count reaches zero and add counter key
1 parent d47a725 commit 02f6fcb

File tree

8 files changed

+360
-8
lines changed

8 files changed

+360
-8
lines changed

internal/constants/constants.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@ const (
3737
WorkloadProfileAnnotation = Domain + "/client-profile"
3838
InjectContainerAnnotation = Domain + "/inject-container"
3939
ReplicasAnnotation = Domain + "/replicas"
40-
GenWorkload = Domain + "/generate-workload"
40+
GenWorkloadAnnotation = Domain + "/generate-workload"
41+
42+
TensorFusionPodCounterKeyAnnotation = Domain + "/pod-counter-key"
43+
TensorFusionPodCountAnnotation = Domain + "/tf-pod-count"
44+
TensorFusionEnabledReplicasAnnotation = Domain + "/enabled-replicas"
4145

4246
PendingRequeueDuration = time.Second * 3
4347
StatusCheckInterval = time.Second * 6

internal/controller/pod_controller.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import (
2222

2323
tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
2424
"github.com/NexusGPU/tensor-fusion/internal/constants"
25+
"github.com/NexusGPU/tensor-fusion/internal/utils"
26+
v1 "github.com/NexusGPU/tensor-fusion/internal/webhook/v1"
2527
"github.com/samber/lo"
2628
corev1 "k8s.io/api/core/v1"
2729
"k8s.io/apimachinery/pkg/api/errors"
@@ -51,6 +53,22 @@ func (r *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R
5153
log := log.FromContext(ctx)
5254
pod := &corev1.Pod{}
5355

56+
if _, ok := pod.Annotations[constants.TensorFusionEnabledReplicasAnnotation]; ok {
57+
deleted, err := utils.HandleFinalizer(ctx, pod, r.Client, func(context context.Context, pod *corev1.Pod) (bool, error) {
58+
counter := &v1.TensorFusionPodCounter{Client: r.Client}
59+
if err := counter.Decrease(ctx, pod); err != nil {
60+
return false, err
61+
}
62+
return true, nil
63+
})
64+
if err != nil {
65+
return ctrl.Result{}, err
66+
}
67+
if deleted {
68+
return ctrl.Result{}, nil
69+
}
70+
}
71+
5472
if err := r.Get(ctx, req.NamespacedName, pod); err != nil {
5573
if errors.IsNotFound(err) {
5674
return ctrl.Result{}, nil

internal/webhook/v1/pod_counter.go

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package v1
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"strconv"
7+
8+
"github.com/NexusGPU/tensor-fusion/internal/constants"
9+
"github.com/NexusGPU/tensor-fusion/internal/utils"
10+
corev1 "k8s.io/api/core/v1"
11+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
12+
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
13+
"sigs.k8s.io/controller-runtime/pkg/client"
14+
)
15+
16+
// TensorFusionPodCounter reads and updates a per-key pod count that is stored
// as an annotation on a pod's controller owner object.
type TensorFusionPodCounter struct {
17+
// Client is used to fetch and update the owner object.
Client client.Client
18+
}
19+
20+
// getOrGenerateKey returns the pod's counter key from annotation if present, otherwise generates one from pod template labels (e.g. pod-template-hash or fallback to object hash)
21+
func getOrGenerateKey(pod *corev1.Pod) string {
22+
if pod.Annotations != nil {
23+
if key, ok := pod.Annotations[constants.TensorFusionPodCounterKeyAnnotation]; ok && key != "" {
24+
return key
25+
}
26+
}
27+
// Try to use pod-template-hash if present
28+
if hash, ok := pod.Labels["pod-template-hash"]; ok && hash != "" {
29+
return hash
30+
}
31+
32+
// Fallback to object hash
33+
return utils.GetObjectHash(pod)
34+
}
35+
36+
// Get gets the counter value from the owner annotation by key
37+
func (c *TensorFusionPodCounter) Get(ctx context.Context, pod *corev1.Pod) (int32, string, error) {
38+
ownerRef := getControllerOwnerRef(pod)
39+
if ownerRef == nil {
40+
return 0, "", fmt.Errorf("no controller owner reference found for pod %s/%s", pod.Namespace, pod.Name)
41+
}
42+
key := getOrGenerateKey(pod)
43+
ownerObj := &unstructured.Unstructured{}
44+
ownerObj.SetAPIVersion(ownerRef.APIVersion)
45+
ownerObj.SetKind(ownerRef.Kind)
46+
objKey := client.ObjectKey{Name: ownerRef.Name, Namespace: pod.Namespace}
47+
if err := c.Client.Get(ctx, objKey, ownerObj); err != nil {
48+
return 0, "", fmt.Errorf("failed to get owner object: %w", err)
49+
}
50+
annotations := ownerObj.GetAnnotations()
51+
if annotations == nil {
52+
return 0, "", nil
53+
}
54+
val, ok := annotations[key]
55+
if !ok || val == "" {
56+
return 0, "", nil
57+
}
58+
count, err := strconv.ParseInt(val, 10, 32)
59+
if err != nil {
60+
return 0, "", fmt.Errorf("invalid count annotation: %s, err: %w", val, err)
61+
}
62+
return int32(count), key, nil
63+
}
64+
65+
// Increase increases the counter in owner annotation by key
66+
func (c *TensorFusionPodCounter) Increase(ctx context.Context, pod *corev1.Pod) error {
67+
ownerRef := getControllerOwnerRef(pod)
68+
if ownerRef == nil {
69+
return fmt.Errorf("no controller owner reference found for pod %s/%s", pod.Namespace, pod.Name)
70+
}
71+
key := getOrGenerateKey(pod)
72+
ownerObj := &unstructured.Unstructured{}
73+
ownerObj.SetAPIVersion(ownerRef.APIVersion)
74+
ownerObj.SetKind(ownerRef.Kind)
75+
objKey := client.ObjectKey{Name: ownerRef.Name, Namespace: pod.Namespace}
76+
if err := c.Client.Get(ctx, objKey, ownerObj); err != nil {
77+
return fmt.Errorf("failed to get owner object: %w", err)
78+
}
79+
annotations := ownerObj.GetAnnotations()
80+
if annotations == nil {
81+
annotations = map[string]string{}
82+
}
83+
val := annotations[key]
84+
if val == "" {
85+
val = "0"
86+
}
87+
count, err := strconv.ParseInt(val, 10, 32)
88+
if err != nil {
89+
return fmt.Errorf("invalid count annotation: %s, err: %w", val, err)
90+
}
91+
count++
92+
annotations[key] = fmt.Sprintf("%d", count)
93+
ownerObj.SetAnnotations(annotations)
94+
if err := c.Client.Update(ctx, ownerObj); err != nil {
95+
return fmt.Errorf("failed to update owner annotation: %w", err)
96+
}
97+
return nil
98+
}
99+
100+
// Decrease decreases the counter in owner annotation by key
101+
func (c *TensorFusionPodCounter) Decrease(ctx context.Context, pod *corev1.Pod) error {
102+
ownerRef := getControllerOwnerRef(pod)
103+
if ownerRef == nil {
104+
return fmt.Errorf("no controller owner reference found for pod %s/%s", pod.Namespace, pod.Name)
105+
}
106+
key := getOrGenerateKey(pod)
107+
ownerObj := &unstructured.Unstructured{}
108+
ownerObj.SetAPIVersion(ownerRef.APIVersion)
109+
ownerObj.SetKind(ownerRef.Kind)
110+
objKey := client.ObjectKey{Name: ownerRef.Name, Namespace: pod.Namespace}
111+
if err := c.Client.Get(ctx, objKey, ownerObj); err != nil {
112+
return fmt.Errorf("failed to get owner object: %w", err)
113+
}
114+
annotations := ownerObj.GetAnnotations()
115+
if annotations == nil {
116+
annotations = map[string]string{}
117+
}
118+
val := annotations[key]
119+
if val == "" {
120+
val = "0"
121+
}
122+
count, err := strconv.ParseInt(val, 10, 32)
123+
if err != nil {
124+
return fmt.Errorf("invalid count annotation: %s, err: %w", val, err)
125+
}
126+
count--
127+
if count <= 0 {
128+
delete(annotations, key)
129+
} else {
130+
annotations[key] = fmt.Sprintf("%d", count)
131+
}
132+
ownerObj.SetAnnotations(annotations)
133+
if err := c.Client.Update(ctx, ownerObj); err != nil {
134+
return fmt.Errorf("failed to update owner annotation: %w", err)
135+
}
136+
return nil
137+
}
138+
139+
// getControllerOwnerRef returns the controller owner reference of a pod
140+
func getControllerOwnerRef(pod *corev1.Pod) *metav1.OwnerReference {
141+
for i, ref := range pod.OwnerReferences {
142+
if ref.Controller != nil && *ref.Controller {
143+
return &pod.OwnerReferences[i]
144+
}
145+
}
146+
return nil
147+
}
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package v1
2+
3+
import (
4+
"context"
5+
6+
"github.com/NexusGPU/tensor-fusion/internal/constants"
7+
. "github.com/onsi/ginkgo/v2"
8+
. "github.com/onsi/gomega"
9+
appsv1 "k8s.io/api/apps/v1"
10+
corev1 "k8s.io/api/core/v1"
11+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
12+
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
13+
"k8s.io/utils/ptr"
14+
"sigs.k8s.io/controller-runtime/pkg/client"
15+
)
16+
17+
var _ = Describe("TensorFusionPodCounter", func() {
18+
var (
19+
counter *TensorFusionPodCounter
20+
ctx context.Context
21+
pod *corev1.Pod
22+
owner *appsv1.Deployment
23+
)
24+
25+
BeforeEach(func() {
26+
ctx = context.Background()
27+
counter = &TensorFusionPodCounter{Client: k8sClient}
28+
pod = &corev1.Pod{
29+
ObjectMeta: metav1.ObjectMeta{
30+
Name: "test-pod",
31+
Namespace: "default",
32+
Annotations: map[string]string{
33+
constants.TensorFusionPodCounterKeyAnnotation: "my-key",
34+
},
35+
Labels: map[string]string{
36+
"pod-template-hash": "hash123",
37+
},
38+
OwnerReferences: []metav1.OwnerReference{{
39+
APIVersion: "apps/v1",
40+
Kind: "Deployment",
41+
Name: "owner",
42+
Controller: ptr.To(true),
43+
}},
44+
},
45+
}
46+
owner = &appsv1.Deployment{
47+
ObjectMeta: metav1.ObjectMeta{
48+
Name: "owner",
49+
Namespace: "default",
50+
Annotations: map[string]string{},
51+
},
52+
Spec: appsv1.DeploymentSpec{
53+
Selector: &metav1.LabelSelector{
54+
MatchLabels: map[string]string{"app": "dummy"},
55+
},
56+
Template: corev1.PodTemplateSpec{
57+
ObjectMeta: metav1.ObjectMeta{
58+
Labels: map[string]string{"app": "dummy"},
59+
},
60+
Spec: corev1.PodSpec{
61+
Containers: []corev1.Container{{
62+
Name: "dummy",
63+
Image: "busybox",
64+
Command: []string{"sleep", "3600"},
65+
}},
66+
},
67+
},
68+
},
69+
}
70+
Expect(k8sClient.Create(ctx, owner)).To(Succeed())
71+
})
72+
73+
AfterEach(func() {
74+
Expect(k8sClient.Delete(ctx, owner)).To(Succeed())
75+
})
76+
77+
It("should get 0 if annotation not set", func() {
78+
val, _, err := counter.Get(ctx, pod)
79+
Expect(err).NotTo(HaveOccurred())
80+
Expect(val).To(Equal(int32(0)))
81+
})
82+
83+
It("should increase and get the counter", func() {
84+
Expect(counter.Increase(ctx, pod)).To(Succeed())
85+
val, _, err := counter.Get(ctx, pod)
86+
Expect(err).NotTo(HaveOccurred())
87+
Expect(val).To(Equal(int32(1)))
88+
})
89+
90+
It("should increase twice and get the correct value", func() {
91+
Expect(counter.Increase(ctx, pod)).To(Succeed())
92+
Expect(counter.Increase(ctx, pod)).To(Succeed())
93+
val, _, err := counter.Get(ctx, pod)
94+
Expect(err).NotTo(HaveOccurred())
95+
Expect(val).To(Equal(int32(2)))
96+
})
97+
98+
It("should decrease the counter", func() {
99+
Expect(counter.Increase(ctx, pod)).To(Succeed())
100+
Expect(counter.Decrease(ctx, pod)).To(Succeed())
101+
val, _, err := counter.Get(ctx, pod)
102+
Expect(err).NotTo(HaveOccurred())
103+
Expect(val).To(Equal(int32(0)))
104+
})
105+
106+
It("should not go below zero", func() {
107+
Expect(counter.Decrease(ctx, pod)).To(Succeed())
108+
val, _, err := counter.Get(ctx, pod)
109+
Expect(err).NotTo(HaveOccurred())
110+
Expect(val).To(Equal(int32(0)))
111+
})
112+
113+
It("should return error if owner not found", func() {
114+
pod.OwnerReferences[0].Name = "notfound"
115+
_, _, err := counter.Get(ctx, pod)
116+
Expect(err).To(HaveOccurred())
117+
})
118+
119+
It("should delete annotation key when count reaches zero", func() {
120+
// Increase
121+
Expect(counter.Increase(ctx, pod)).To(Succeed())
122+
// Decrease to 0
123+
Expect(counter.Decrease(ctx, pod)).To(Succeed())
124+
125+
// Get owner object
126+
ownerRef := getControllerOwnerRef(pod)
127+
ownerObj := &unstructured.Unstructured{}
128+
ownerObj.SetAPIVersion(ownerRef.APIVersion)
129+
ownerObj.SetKind(ownerRef.Kind)
130+
objKey := client.ObjectKey{Name: ownerRef.Name, Namespace: pod.Namespace}
131+
Expect(counter.Client.Get(ctx, objKey, ownerObj)).To(Succeed())
132+
annotations := ownerObj.GetAnnotations()
133+
key := getOrGenerateKey(pod)
134+
_, exists := annotations[key]
135+
Expect(exists).To(BeFalse())
136+
})
137+
})

internal/webhook/v1/pod_webhook.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,22 @@ func (m *TensorFusionPodMutator) Handle(ctx context.Context, req admission.Reque
7373
if err != nil {
7474
return admission.Errored(http.StatusInternalServerError, fmt.Errorf("parse tf resources: %w", err))
7575
}
76+
counter := &TensorFusionPodCounter{Client: m.Client}
77+
enabledReplicas := tfInfo.EnabledReplicas
78+
79+
var podCounterAnnotationKey string
80+
if enabledReplicas != nil {
81+
// Get `tf-pod-count` by querying the owner's annotation
82+
// and then decide whether to patch the current pod
83+
podCount, podCounterKey, err := counter.Get(ctx, pod)
84+
if err != nil {
85+
return admission.Errored(http.StatusInternalServerError, fmt.Errorf("get tf pod count: %w", err))
86+
}
87+
if podCount >= *enabledReplicas {
88+
return admission.Allowed("tf pod count exceeds enabled replicas")
89+
}
90+
podCounterAnnotationKey = podCounterKey
91+
}
7692

7793
workload := &tfv1.TensorFusionWorkload{}
7894
if tfInfo.GenWorkload {
@@ -108,6 +124,19 @@ func (m *TensorFusionPodMutator) Handle(ctx context.Context, req admission.Reque
108124
return admission.Errored(http.StatusInternalServerError, err)
109125
}
110126

127+
if podCounterAnnotationKey != "" {
128+
if err := counter.Increase(ctx, pod); err != nil {
129+
return admission.Errored(http.StatusInternalServerError, fmt.Errorf("increase tf pod count: %w", err))
130+
}
131+
// Patch annotation for pod counter
132+
patch := jsonpatch.JsonPatchOperation{
133+
Operation: "add",
134+
Path: "/metadata/annotations/" + constants.TensorFusionPodCounterKeyAnnotation,
135+
Value: podCounterAnnotationKey,
136+
}
137+
patches = append(patches, patch)
138+
}
139+
111140
return admission.Patched("tensor fusion component patched", patches...)
112141
}
113142

internal/webhook/v1/pod_webhook_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ var _ = Describe("TensorFusionPodMutator", func() {
9595
constants.WorkloadProfileAnnotation: "test-profile-handle",
9696
constants.InjectContainerAnnotation: "main",
9797
constants.WorkloadKey: "test-workload",
98-
constants.GenWorkload: "true",
98+
constants.GenWorkloadAnnotation: "true",
9999
},
100100
},
101101
Spec: corev1.PodSpec{

0 commit comments

Comments
 (0)