Skip to content

Commit 5f6b101

Browse files
authored
fix(webhook): set the workload's owner reference to the pod's root owner (#158)
- Enhance pod_webhook_test.go to verify workload owner reference propagation in tests
1 parent 2b75da8 commit 5f6b101

File tree

3 files changed

+64
-0
lines changed

3 files changed

+64
-0
lines changed

internal/utils/owner_ref_utils.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package utils
2+
3+
import (
4+
context "context"
5+
6+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
7+
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
8+
"sigs.k8s.io/controller-runtime/pkg/client"
9+
)
10+
11+
// FindRootOwnerReference recursively finds the root owner reference for a given object (e.g. Pod).
12+
func FindRootOwnerReference(ctx context.Context, c client.Client, namespace string, obj metav1.Object) (*metav1.OwnerReference, error) {
13+
current := obj
14+
for {
15+
owners := current.GetOwnerReferences()
16+
if len(owners) == 0 {
17+
return nil, nil // no owner, this is root
18+
}
19+
ownerRef := owners[0]
20+
// Try to get the owner object as unstructured
21+
unObj := &unstructured.Unstructured{}
22+
unObj.SetAPIVersion(ownerRef.APIVersion)
23+
unObj.SetKind(ownerRef.Kind)
24+
key := client.ObjectKey{Name: ownerRef.Name, Namespace: namespace}
25+
err := c.Get(ctx, key, unObj)
26+
if err != nil {
27+
// If not found, treat this ownerRef as root
28+
return &ownerRef, nil
29+
}
30+
// Cast back to metav1.Object if possible
31+
if metaObj, ok := any(unObj).(metav1.Object); ok {
32+
current = metaObj
33+
} else {
34+
return &ownerRef, nil
35+
}
36+
}
37+
}

internal/webhook/v1/pod_webhook.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,12 @@ func (m *TensorFusionPodMutator) createOrUpdateWorkload(ctx context.Context, pod
160160
if !errors.IsNotFound(err) {
161161
return fmt.Errorf("failed to get workload: %w", err)
162162
}
163+
// find root owner references of pod
164+
rootOwnerRef, err := utils.FindRootOwnerReference(ctx, m.Client, pod.Namespace, pod)
165+
if err != nil {
166+
return fmt.Errorf("failed to find root owner reference: %w", err)
167+
}
168+
163169
// Create a new workload
164170
replicas := tfInfo.Replicas
165171
workload = &tfv1.TensorFusionWorkload{
@@ -176,6 +182,10 @@ func (m *TensorFusionPodMutator) createOrUpdateWorkload(ctx context.Context, pod
176182
},
177183
}
178184

185+
if rootOwnerRef != nil {
186+
workload.OwnerReferences = []metav1.OwnerReference{*rootOwnerRef}
187+
}
188+
179189
if err := m.Client.Create(ctx, workload); err != nil {
180190
return fmt.Errorf("failed to create workload: %w", err)
181191
}

internal/webhook/v1/pod_webhook_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import (
3333
"k8s.io/apimachinery/pkg/api/resource"
3434
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
3535
"k8s.io/apimachinery/pkg/runtime"
36+
"k8s.io/utils/ptr"
3637
"sigs.k8s.io/controller-runtime/pkg/client"
3738
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
3839
)
@@ -133,6 +134,13 @@ var _ = Describe("TensorFusionPodMutator", func() {
133134
Labels: map[string]string{
134135
constants.TensorFusionEnabledLabelKey: "true",
135136
},
137+
OwnerReferences: []metav1.OwnerReference{{
138+
APIVersion: "apps/v1",
139+
Kind: "Deployment",
140+
Name: "owner",
141+
UID: "owner-uid",
142+
Controller: ptr.To(true),
143+
}},
136144
Annotations: map[string]string{
137145
constants.GpuPoolKey: "mock",
138146
constants.WorkloadProfileAnnotation: "test-profile-handle",
@@ -179,6 +187,15 @@ var _ = Describe("TensorFusionPodMutator", func() {
179187
resp := mutator.Handle(ctx, req)
180188
Expect(resp.Allowed).To(BeTrue())
181189
Expect(resp.Patches).NotTo(BeEmpty())
190+
191+
// Check workload created
192+
workload := &tfv1.TensorFusionWorkload{}
193+
err = k8sClient.Get(ctx, client.ObjectKey{Name: "test-workload", Namespace: "default"}, workload)
194+
Expect(err).NotTo(HaveOccurred())
195+
Expect(*workload.Spec.Replicas).To(Equal(int32(1)))
196+
// check workload owner reference
197+
Expect(workload.OwnerReferences).To(HaveLen(1))
198+
Expect(workload.OwnerReferences[0].Name).To(Equal("owner"))
182199
})
183200

184201
It("should handle pods without TF requirements", func() {

0 commit comments

Comments
 (0)