Skip to content

Commit 36451d4

Browse files
authored
feat: add test for tensorfusion connection controller & fix deployment of controller (#29)
1 parent 6f455a8 commit 36451d4

File tree

5 files changed

+87 -30 lines changed

charts/tensor-fusion/templates/controller-deployment.yaml

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -45,8 +45,6 @@ spec:
4545
- name: cert
4646
readOnly: true
4747
mountPath: /tmp/k8s-webhook-server/serving-certs
48-
- name: config
49-
mountPath: /etc/tensor-fusion
5048
- name: vector
5149
image: docker.io/timberio/vector:nightly-2025-01-07-debian
5250
env:
@@ -74,10 +72,6 @@ spec:
7472
path: tls.crt
7573
- key: key
7674
path: tls.key
77-
- name: config
78-
configMap:
79-
name: {{ include "tensor-fusion.fullname" . }}-config
80-
defaultMode: 420
8175
- name: vector-config
8276
configMap:
8377
name: {{ include "tensor-fusion.fullname" . }}-vector-config

internal/controller/tensorfusionconnection_controller.go

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -173,7 +173,10 @@ func (r *TensorFusionConnectionReconciler) tryStartWorker(
173173
if errors.IsNotFound(err) {
174174
// Pod doesn't exist, create a new one
175175
port := workerGenerator.AllocPort()
176-
pod = workerGenerator.GenerateWorkerPod(gpu, connection, namespacedName, port)
176+
pod, err = workerGenerator.GenerateWorkerPod(gpu, connection, namespacedName, port)
177+
if err != nil {
178+
return nil, fmt.Errorf("generate worker pod %w", err)
179+
}
177180
if err := ctrl.SetControllerReference(connection, pod, r.Scheme); err != nil {
178181
return nil, fmt.Errorf("set owner reference %w", err)
179182
}

internal/controller/tensorfusionconnection_controller_test.go

Lines changed: 62 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -21,14 +21,18 @@ import (
2121

2222
. "github.com/onsi/ginkgo/v2"
2323
. "github.com/onsi/gomega"
24+
corev1 "k8s.io/api/core/v1"
2425
"k8s.io/apimachinery/pkg/api/errors"
26+
"k8s.io/apimachinery/pkg/api/resource"
2527
"k8s.io/apimachinery/pkg/types"
2628
"sigs.k8s.io/controller-runtime/pkg/reconcile"
2729

2830
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2931

3032
tfv1 "github.com/NexusGPU/tensor-fusion-operator/api/v1"
3133
"github.com/NexusGPU/tensor-fusion-operator/internal/config"
34+
"github.com/NexusGPU/tensor-fusion-operator/internal/constants"
35+
"github.com/NexusGPU/tensor-fusion-operator/internal/scheduler"
3236
)
3337

3438
var _ = Describe("TensorFusionConnection Controller", func() {
@@ -39,27 +43,61 @@ var _ = Describe("TensorFusionConnection Controller", func() {
3943

4044
typeNamespacedName := types.NamespacedName{
4145
Name: resourceName,
42-
Namespace: "default", // TODO(user):Modify as needed
46+
Namespace: "default",
47+
}
48+
scheduler := scheduler.NewNaiveScheduler()
49+
gpu := &tfv1.GPU{
50+
ObjectMeta: metav1.ObjectMeta{
51+
Name: "mock-gpu",
52+
},
4353
}
44-
tensorfusionconnection := &tfv1.TensorFusionConnection{}
45-
4654
BeforeEach(func() {
55+
connection := &tfv1.TensorFusionConnection{}
4756
By("creating the custom resource for the Kind TensorFusionConnection")
48-
err := k8sClient.Get(ctx, typeNamespacedName, tensorfusionconnection)
57+
err := k8sClient.Get(ctx, typeNamespacedName, connection)
4958
if err != nil && errors.IsNotFound(err) {
5059
resource := &tfv1.TensorFusionConnection{
5160
ObjectMeta: metav1.ObjectMeta{
5261
Name: resourceName,
5362
Namespace: "default",
5463
},
55-
// TODO(user): Specify other spec details if needed.
64+
Spec: tfv1.TensorFusionConnectionSpec{
65+
PoolName: "mock",
66+
Resources: tfv1.Resources{
67+
Requests: tfv1.Resource{
68+
Tflops: resource.MustParse("1"),
69+
Vram: resource.MustParse("1Gi"),
70+
},
71+
Limits: tfv1.Resource{
72+
Tflops: resource.MustParse("1"),
73+
Vram: resource.MustParse("1Gi"),
74+
},
75+
},
76+
},
5677
}
5778
Expect(k8sClient.Create(ctx, resource)).To(Succeed())
5879
}
80+
81+
scheduler.OnAdd(gpu)
82+
Expect(k8sClient.Create(ctx, gpu)).To(Succeed())
83+
gpu.Status = tfv1.GPUStatus{
84+
UUID: "mock-gpu",
85+
NodeSelector: map[string]string{
86+
"kubernetes.io/hostname": "mock-node",
87+
},
88+
Capacity: tfv1.Resource{
89+
Tflops: resource.MustParse("2"),
90+
Vram: resource.MustParse("2Gi"),
91+
},
92+
Available: tfv1.Resource{
93+
Tflops: resource.MustParse("2"),
94+
Vram: resource.MustParse("2Gi"),
95+
},
96+
}
97+
Expect(k8sClient.Status().Update(ctx, gpu)).To(Succeed())
5998
})
6099

61100
AfterEach(func() {
62-
// TODO(user): Cleanup logic after each test, like removing the resource instance.
63101
resource := &tfv1.TensorFusionConnection{}
64102
err := k8sClient.Get(ctx, typeNamespacedName, resource)
65103
Expect(err).NotTo(HaveOccurred())
@@ -74,13 +112,29 @@ var _ = Describe("TensorFusionConnection Controller", func() {
74112
Client: k8sClient,
75113
Scheme: k8sClient.Scheme(),
76114
GpuPoolState: gpuPoolState,
115+
Scheduler: scheduler,
77116
}
78117
_, err := controllerReconciler.Reconcile(ctx, reconcile.Request{
79118
NamespacedName: typeNamespacedName,
80119
})
81120
Expect(err).NotTo(HaveOccurred())
82-
// TODO(user): Add more specific assertions depending on your controller's reconciliation logic.
83-
// Example: If you expect a certain status condition after reconciliation, verify it here.
121+
connection := &tfv1.TensorFusionConnection{}
122+
Expect(k8sClient.Get(ctx, typeNamespacedName, connection)).NotTo(HaveOccurred())
123+
Expect(connection.Finalizers).Should(ConsistOf(constants.Finalizer))
124+
_, err = controllerReconciler.Reconcile(ctx, reconcile.Request{
125+
NamespacedName: typeNamespacedName,
126+
})
127+
Expect(err).NotTo(HaveOccurred())
128+
Expect(k8sClient.Get(ctx, typeNamespacedName, connection)).NotTo(HaveOccurred())
129+
Expect(connection.Status.Phase).To(Equal(tfv1.TensorFusionConnectionStarting))
130+
131+
workerPod := &corev1.Pod{}
132+
Expect(k8sClient.Get(ctx, typeNamespacedName, workerPod)).NotTo(HaveOccurred())
133+
Expect(workerPod.Spec.NodeSelector).To(Equal(gpu.Status.NodeSelector))
134+
135+
Expect(k8sClient.Get(ctx, types.NamespacedName{Name: "mock-gpu"}, gpu)).NotTo(HaveOccurred())
136+
Expect(gpu.Status.Available.Tflops).To(Equal(resource.MustParse("1")))
137+
Expect(gpu.Status.Available.Vram).To(Equal(resource.MustParse("1Gi")))
84138
})
85139
})
86140
})

internal/webhook/v1/pod_webhook.go

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -192,7 +192,7 @@ func (m *TensorFusionPodMutator) patchTFClient(pod *corev1.Pod, clientConfig *tf
192192
// Convert the current pod to JSON
193193
currentBytes, err := json.Marshal(pod)
194194
if err != nil {
195-
return nil, fmt.Errorf("marshal current pod: %v", err)
195+
return nil, fmt.Errorf("marshal current pod: %w", err)
196196
}
197197

198198
// Patch to Container
@@ -203,19 +203,19 @@ func (m *TensorFusionPodMutator) patchTFClient(pod *corev1.Pod, clientConfig *tf
203203
// patch from config
204204
containerJSON, err := json.Marshal(container)
205205
if err != nil {
206-
return nil, fmt.Errorf("marshal container: %v", err)
206+
return nil, fmt.Errorf("marshal container: %w", err)
207207
}
208208
patchJSON, err := json.Marshal(clientConfig.PatchToContainer)
209209
if err != nil {
210-
return nil, fmt.Errorf("marshal patchToContainer: %v", err)
210+
return nil, fmt.Errorf("marshal patchToContainer: %w", err)
211211
}
212212

213213
patchedJSON, err := strategicpatch.StrategicMergePatch(containerJSON, patchJSON, corev1.Container{})
214214
if err != nil {
215-
return nil, fmt.Errorf("apply strategic merge patch to container: %v", err)
215+
return nil, fmt.Errorf("apply strategic merge patch to container: %w", err)
216216
}
217217
if err := json.Unmarshal(patchedJSON, container); err != nil {
218-
return nil, fmt.Errorf("unmarshal patched container: %v", err)
218+
return nil, fmt.Errorf("unmarshal patched container: %w", err)
219219
}
220220

221221
// add connection env
@@ -240,35 +240,35 @@ func (m *TensorFusionPodMutator) patchTFClient(pod *corev1.Pod, clientConfig *tf
240240

241241
containerPatchedJSON, err := json.Marshal(pod)
242242
if err != nil {
243-
return nil, fmt.Errorf("marshal current pod: %v", err)
243+
return nil, fmt.Errorf("marshal current pod: %w", err)
244244
}
245245
patches, err := jsonpatch.CreatePatch(currentBytes, containerPatchedJSON)
246246
if err != nil {
247-
return nil, fmt.Errorf("patch to container: %v", err)
247+
return nil, fmt.Errorf("patch to container: %w", err)
248248
}
249249

250250
// Convert the strategic merge patch to JSON
251251
patchBytes, err := json.Marshal(clientConfig.PatchToPod)
252252

253253
if err != nil {
254-
return nil, fmt.Errorf("marshal patch: %v", err)
254+
return nil, fmt.Errorf("marshal patch: %w", err)
255255
}
256256

257257
// Apply the strategic merge patch
258258
resultBytes, err := strategicpatch.StrategicMergePatch(currentBytes, patchBytes, corev1.Pod{})
259259
if err != nil {
260-
return nil, fmt.Errorf("apply strategic merge patch: %v", err)
260+
return nil, fmt.Errorf("apply strategic merge patch: %w", err)
261261
}
262262

263263
// Generate JSON patch operations by comparing original and patched pod
264264
strategicpatches, err := jsonpatch.CreatePatch(currentBytes, resultBytes)
265265
if err != nil {
266-
return nil, fmt.Errorf("create json patch: %v", err)
266+
return nil, fmt.Errorf("create json patch: %w", err)
267267
}
268268

269269
// Unmarshal the result back into the pod
270270
if err := json.Unmarshal(resultBytes, pod); err != nil {
271-
return nil, fmt.Errorf("unmarshal patched pod: %v", err)
271+
return nil, fmt.Errorf("unmarshal patched pod: %w", err)
272272
}
273273

274274
patches = append(patches, strategicpatches...)

internal/worker/worker.go

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
package worker
22

33
import (
4+
"encoding/json"
45
"fmt"
56
"strconv"
67
"time"
@@ -44,8 +45,13 @@ func (wg *WorkerGenerator) GenerateWorkerPod(
4445
connection *tfv1.TensorFusionConnection,
4546
namespacedName types.NamespacedName,
4647
port int,
47-
) *corev1.Pod {
48-
spec := wg.WorkerConfig.PodTemplate.Object.(*corev1.PodTemplate).Template.Spec.DeepCopy()
48+
) (*corev1.Pod, error) {
49+
podTmpl := &corev1.PodTemplate{}
50+
err := json.Unmarshal(wg.WorkerConfig.PodTemplate.Raw, podTmpl)
51+
if err != nil {
52+
return nil, fmt.Errorf("failed to unmarshal pod template: %w", err)
53+
}
54+
spec := podTmpl.Template.Spec
4955
if spec.NodeSelector == nil {
5056
spec.NodeSelector = make(map[string]string)
5157
}
@@ -64,6 +70,6 @@ func (wg *WorkerGenerator) GenerateWorkerPod(
6470
Name: namespacedName.Name,
6571
Namespace: namespacedName.Namespace,
6672
},
67-
Spec: *spec,
68-
}
73+
Spec: spec,
74+
}, nil
6975
}

0 commit comments

Comments
 (0)