Skip to content

Commit c14f67d

Browse files
authored
fix: helm typo and patch gpu node bug (#301)
* fix: helm typo
* fix: node discovery patch gpu node status issue
1 parent d338a41 commit c14f67d

File tree

5 files changed

+313
-28
lines changed

5 files changed

+313
-28
lines changed

api/v1/gpunode_types.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,15 @@ type GPUNodeStatus struct {
6363
TotalTFlops resource.Quantity `json:"totalTFlops"`
6464
TotalVRAM resource.Quantity `json:"totalVRAM"`
6565

66-
VirtualTFlops resource.Quantity `json:"virtualTFlops"`
67-
VirtualVRAM resource.Quantity `json:"virtualVRAM"`
66+
// +optional
67+
VirtualTFlops resource.Quantity `json:"virtualTFlops,omitempty"`
68+
// +optional
69+
VirtualVRAM resource.Quantity `json:"virtualVRAM,omitempty"`
6870

69-
AvailableTFlops resource.Quantity `json:"availableTFlops"`
70-
AvailableVRAM resource.Quantity `json:"availableVRAM"`
71+
// +optional
72+
AvailableTFlops resource.Quantity `json:"availableTFlops,omitempty"`
73+
// +optional
74+
AvailableVRAM resource.Quantity `json:"availableVRAM,omitempty"`
7175

7276
// +optional
7377
VirtualAvailableTFlops *resource.Quantity `json:"virtualAvailableTFlops,omitempty"`

charts/tensor-fusion/crds/tensor-fusion.ai_gpunodes.yaml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,15 +261,11 @@ spec:
261261
pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$
262262
x-kubernetes-int-or-string: true
263263
required:
264-
- availableTFlops
265-
- availableVRAM
266264
- managedGPUs
267265
- phase
268266
- totalGPUs
269267
- totalTFlops
270268
- totalVRAM
271-
- virtualTFlops
272-
- virtualVRAM
273269
type: object
274270
type: object
275271
served: true

cmd/nodediscovery/main.go

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -169,29 +169,27 @@ func main() {
169169
availableVRAM.Add(gpu.Status.Available.Vram)
170170
}
171171

172-
// Use proper patch-based update with retry on conflict
173172
err = retry.RetryOnConflict(retry.DefaultBackoff, func() error {
174-
// Get the latest version of the resource
175-
currentGPUNode := &tfv1.GPUNode{}
176-
if err := k8sClient.Get(ctx, client.ObjectKeyFromObject(gpunode), currentGPUNode); err != nil {
177-
return err
178-
}
179-
180-
// Create a patch from the original to the desired state
181-
patch := client.MergeFrom(currentGPUNode.DeepCopy())
182-
183-
// Update status fields conditionally
184-
updateGPUNodeStatus(&currentGPUNode.Status, totalTFlops, totalVRAM, int32(count), allDeviceIDs)
185-
186-
// Apply the patch using the status subresource
187-
return k8sClient.Status().Patch(ctx, currentGPUNode, patch)
173+
return patchGPUNodeStatus(k8sClient, ctx, gpunode, totalTFlops, totalVRAM, int32(count), allDeviceIDs)
188174
})
189175
if err != nil {
190-
ctrl.Log.Error(err, "failed to update status of GPUNode after retries")
176+
ctrl.Log.Error(err, "failed to patch status of GPUNode after retries")
191177
os.Exit(1)
192178
}
193179
}
194180

181+
// patchGPUNodeStatus fetches the current GPUNode and merge-patches its status subresource; the caller wraps it in retry.RetryOnConflict.
182+
func patchGPUNodeStatus(k8sClient client.Client, ctx context.Context, gpunode *tfv1.GPUNode, totalTFlops resource.Quantity, totalVRAM resource.Quantity, count int32, allDeviceIDs []string) error {
183+
184+
currentGPUNode := &tfv1.GPUNode{}
185+
if err := k8sClient.Get(ctx, client.ObjectKeyFromObject(gpunode), currentGPUNode); err != nil {
186+
return err
187+
}
188+
patch := client.MergeFrom(currentGPUNode.DeepCopy())
189+
updateGPUNodeStatus(&currentGPUNode.Status, totalTFlops, totalVRAM, int32(count), allDeviceIDs)
190+
return k8sClient.Status().Patch(ctx, currentGPUNode, patch)
191+
}
192+
195193
func createOrUpdateTensorFusionGPU(
196194
k8sClient client.Client, ctx context.Context, k8sNodeName string, gpunode *tfv1.GPUNode,
197195
uuid string, deviceName string, memInfo nvml.Memory_v2, tflops resource.Quantity) *tfv1.GPU {

cmd/nodediscovery/main_test.go

Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"context"
5+
"fmt"
56
"testing"
67
"time"
78

@@ -12,6 +13,7 @@ import (
1213
"k8s.io/apimachinery/pkg/api/resource"
1314
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1415
"k8s.io/apimachinery/pkg/runtime"
16+
"sigs.k8s.io/controller-runtime/pkg/client"
1517
"sigs.k8s.io/controller-runtime/pkg/client/fake"
1618
)
1719

@@ -106,3 +108,292 @@ func TestGPUControllerReference(t *testing.T) {
106108
assert.True(t, metav1.IsControlledBy(gpu, newGpuNode))
107109
assert.False(t, metav1.IsControlledBy(gpu, gpuNode))
108110
}
111+
112+
func TestPatchGPUNodeStatus(t *testing.T) {
113+
tests := []struct {
114+
name string
115+
setupGPUNode func() *tfv1.GPUNode
116+
totalTFlops resource.Quantity
117+
totalVRAM resource.Quantity
118+
count int32
119+
allDeviceIDs []string
120+
expectError bool
121+
validateResult func(t *testing.T, originalNode, patchedNode *tfv1.GPUNode)
122+
}{
123+
{
124+
name: "successful patch with empty phase",
125+
setupGPUNode: func() *tfv1.GPUNode {
126+
return &tfv1.GPUNode{
127+
ObjectMeta: metav1.ObjectMeta{
128+
Name: "test-gpu-node",
129+
Namespace: "default",
130+
},
131+
Status: tfv1.GPUNodeStatus{
132+
Phase: "", // Empty phase should be set to pending
133+
TotalTFlops: resource.MustParse("50"),
134+
TotalVRAM: resource.MustParse("8Gi"),
135+
TotalGPUs: 2,
136+
},
137+
}
138+
},
139+
totalTFlops: resource.MustParse("100"),
140+
totalVRAM: resource.MustParse("16Gi"),
141+
count: 4,
142+
allDeviceIDs: []string{"gpu-0", "gpu-1", "gpu-2", "gpu-3"},
143+
expectError: false,
144+
validateResult: func(t *testing.T, originalNode, patchedNode *tfv1.GPUNode) {
145+
// Verify status fields were updated
146+
assert.Equal(t, resource.MustParse("100"), patchedNode.Status.TotalTFlops)
147+
assert.Equal(t, resource.MustParse("16Gi"), patchedNode.Status.TotalVRAM)
148+
assert.Equal(t, int32(4), patchedNode.Status.TotalGPUs)
149+
assert.Equal(t, int32(4), patchedNode.Status.ManagedGPUs)
150+
assert.Equal(t, []string{"gpu-0", "gpu-1", "gpu-2", "gpu-3"}, patchedNode.Status.ManagedGPUDeviceIDs)
151+
assert.Equal(t, tfv1.TensorFusionGPUNodePhasePending, patchedNode.Status.Phase)
152+
// Verify NodeInfo was updated
153+
assert.True(t, patchedNode.Status.NodeInfo.RAMSize.Value() > 0)
154+
assert.True(t, patchedNode.Status.NodeInfo.DataDiskSize.Value() > 0)
155+
},
156+
},
157+
{
158+
name: "successful patch with existing phase preserved",
159+
setupGPUNode: func() *tfv1.GPUNode {
160+
return &tfv1.GPUNode{
161+
ObjectMeta: metav1.ObjectMeta{
162+
Name: "test-gpu-node-running",
163+
Namespace: "default",
164+
},
165+
Status: tfv1.GPUNodeStatus{
166+
Phase: tfv1.TensorFusionGPUNodePhaseRunning,
167+
TotalTFlops: resource.MustParse("200"),
168+
TotalVRAM: resource.MustParse("32Gi"),
169+
TotalGPUs: 8,
170+
},
171+
}
172+
},
173+
totalTFlops: resource.MustParse("150"),
174+
totalVRAM: resource.MustParse("24Gi"),
175+
count: 6,
176+
allDeviceIDs: []string{"gpu-0", "gpu-1", "gpu-2", "gpu-3", "gpu-4", "gpu-5"},
177+
expectError: false,
178+
validateResult: func(t *testing.T, originalNode, patchedNode *tfv1.GPUNode) {
179+
// Verify status fields were updated
180+
assert.Equal(t, resource.MustParse("150"), patchedNode.Status.TotalTFlops)
181+
assert.Equal(t, resource.MustParse("24Gi"), patchedNode.Status.TotalVRAM)
182+
assert.Equal(t, int32(6), patchedNode.Status.TotalGPUs)
183+
assert.Equal(t, int32(6), patchedNode.Status.ManagedGPUs)
184+
assert.Equal(t, []string{"gpu-0", "gpu-1", "gpu-2", "gpu-3", "gpu-4", "gpu-5"}, patchedNode.Status.ManagedGPUDeviceIDs)
185+
// Verify existing phase was preserved
186+
assert.Equal(t, tfv1.TensorFusionGPUNodePhaseRunning, patchedNode.Status.Phase)
187+
},
188+
},
189+
{
190+
name: "zero resources handled correctly",
191+
setupGPUNode: func() *tfv1.GPUNode {
192+
return &tfv1.GPUNode{
193+
ObjectMeta: metav1.ObjectMeta{
194+
Name: "test-gpu-node-zero",
195+
Namespace: "default",
196+
},
197+
Status: tfv1.GPUNodeStatus{
198+
Phase: "",
199+
},
200+
}
201+
},
202+
totalTFlops: resource.MustParse("0"),
203+
totalVRAM: resource.MustParse("0"),
204+
count: 0,
205+
allDeviceIDs: []string{},
206+
expectError: false,
207+
validateResult: func(t *testing.T, originalNode, patchedNode *tfv1.GPUNode) {
208+
assert.Equal(t, resource.MustParse("0"), patchedNode.Status.TotalTFlops)
209+
assert.Equal(t, resource.MustParse("0"), patchedNode.Status.TotalVRAM)
210+
assert.Equal(t, int32(0), patchedNode.Status.TotalGPUs)
211+
assert.Equal(t, int32(0), patchedNode.Status.ManagedGPUs)
212+
assert.Empty(t, patchedNode.Status.ManagedGPUDeviceIDs)
213+
assert.Equal(t, tfv1.TensorFusionGPUNodePhasePending, patchedNode.Status.Phase)
214+
},
215+
},
216+
}
217+
218+
for _, tt := range tests {
219+
t.Run(tt.name, func(t *testing.T) {
220+
ctx := context.Background()
221+
gpuNode := tt.setupGPUNode()
222+
223+
// Setup fake client with the GPUNode
224+
scheme := runtime.NewScheme()
225+
_ = tfv1.AddToScheme(scheme)
226+
k8sClient := fake.NewClientBuilder().
227+
WithScheme(scheme).
228+
WithStatusSubresource(&tfv1.GPUNode{}).
229+
WithObjects(gpuNode).
230+
Build()
231+
232+
// Store original state for comparison
233+
originalNode := gpuNode.DeepCopy()
234+
235+
// Call the function under test
236+
err := patchGPUNodeStatus(k8sClient, ctx, gpuNode, tt.totalTFlops, tt.totalVRAM, tt.count, tt.allDeviceIDs)
237+
238+
// Verify error expectation
239+
if tt.expectError {
240+
assert.Error(t, err, "Expected an error but got none")
241+
return
242+
}
243+
assert.NoError(t, err, "Unexpected error")
244+
245+
// Get the updated GPUNode from the client to verify the patch was applied
246+
updatedNode := &tfv1.GPUNode{}
247+
err = k8sClient.Get(ctx, client.ObjectKeyFromObject(gpuNode), updatedNode)
248+
assert.NoError(t, err, "Failed to get updated GPUNode")
249+
250+
// Run custom validation
251+
if tt.validateResult != nil {
252+
tt.validateResult(t, originalNode, updatedNode)
253+
}
254+
})
255+
}
256+
}
257+
258+
func TestPatchGPUNodeStatus_ErrorScenarios(t *testing.T) {
259+
tests := []struct {
260+
name string
261+
setupClient func() client.Client
262+
setupGPUNode func() *tfv1.GPUNode
263+
expectedErr string
264+
}{
265+
{
266+
name: "GPUNode not found error",
267+
setupClient: func() client.Client {
268+
// Create client without the GPUNode object
269+
scheme := runtime.NewScheme()
270+
_ = tfv1.AddToScheme(scheme)
271+
return fake.NewClientBuilder().
272+
WithScheme(scheme).
273+
WithStatusSubresource(&tfv1.GPUNode{}).
274+
Build()
275+
},
276+
setupGPUNode: func() *tfv1.GPUNode {
277+
return &tfv1.GPUNode{
278+
ObjectMeta: metav1.ObjectMeta{
279+
Name: "nonexistent-gpu-node",
280+
Namespace: "default",
281+
},
282+
}
283+
},
284+
expectedErr: "not found",
285+
},
286+
}
287+
288+
for _, tt := range tests {
289+
t.Run(tt.name, func(t *testing.T) {
290+
ctx := context.Background()
291+
k8sClient := tt.setupClient()
292+
gpuNode := tt.setupGPUNode()
293+
294+
// Call the function under test
295+
err := patchGPUNodeStatus(k8sClient, ctx, gpuNode,
296+
resource.MustParse("100"),
297+
resource.MustParse("16Gi"),
298+
4,
299+
[]string{"gpu-0", "gpu-1", "gpu-2", "gpu-3"})
300+
301+
// Verify the expected error occurred
302+
assert.Error(t, err, "Expected an error but got none")
303+
assert.Contains(t, err.Error(), tt.expectedErr, "Error message should contain expected text")
304+
})
305+
}
306+
}
307+
308+
func TestPatchGPUNodeStatus_Integration(t *testing.T) {
309+
// Integration test that verifies the complete flow
310+
ctx := context.Background()
311+
312+
// Setup initial GPUNode
313+
gpuNode := &tfv1.GPUNode{
314+
ObjectMeta: metav1.ObjectMeta{
315+
Name: "integration-test-node",
316+
Namespace: "default",
317+
},
318+
Status: tfv1.GPUNodeStatus{
319+
Phase: "",
320+
TotalTFlops: resource.MustParse("10"),
321+
TotalVRAM: resource.MustParse("2Gi"),
322+
TotalGPUs: 1,
323+
ManagedGPUs: 0, // Different from TotalGPUs to test sync
324+
ManagedGPUDeviceIDs: []string{"old-device"},
325+
NodeInfo: tfv1.GPUNodeInfo{
326+
RAMSize: resource.MustParse("1Gi"),
327+
DataDiskSize: resource.MustParse("1Gi"),
328+
},
329+
},
330+
}
331+
332+
// Setup fake client
333+
scheme := runtime.NewScheme()
334+
_ = tfv1.AddToScheme(scheme)
335+
k8sClient := fake.NewClientBuilder().
336+
WithScheme(scheme).
337+
WithStatusSubresource(&tfv1.GPUNode{}).
338+
WithObjects(gpuNode).
339+
Build()
340+
341+
// Test multiple sequential patches to verify state consistency
342+
updates := []struct {
343+
totalTFlops resource.Quantity
344+
totalVRAM resource.Quantity
345+
count int32
346+
allDeviceIDs []string
347+
}{
348+
{
349+
totalTFlops: resource.MustParse("100"),
350+
totalVRAM: resource.MustParse("16Gi"),
351+
count: 4,
352+
allDeviceIDs: []string{"gpu-0", "gpu-1", "gpu-2", "gpu-3"},
353+
},
354+
{
355+
totalTFlops: resource.MustParse("200"),
356+
totalVRAM: resource.MustParse("32Gi"),
357+
count: 8,
358+
allDeviceIDs: []string{"gpu-0", "gpu-1", "gpu-2", "gpu-3", "gpu-4", "gpu-5", "gpu-6", "gpu-7"},
359+
},
360+
{
361+
totalTFlops: resource.MustParse("50"),
362+
totalVRAM: resource.MustParse("8Gi"),
363+
count: 2,
364+
allDeviceIDs: []string{"gpu-0", "gpu-1"},
365+
},
366+
}
367+
368+
for i, update := range updates {
369+
t.Run(fmt.Sprintf("update_%d", i+1), func(t *testing.T) {
370+
// Apply the patch
371+
err := patchGPUNodeStatus(k8sClient, ctx, gpuNode, update.totalTFlops, update.totalVRAM, update.count, update.allDeviceIDs)
372+
assert.NoError(t, err, "Patch should succeed")
373+
374+
// Verify the update was applied
375+
updatedNode := &tfv1.GPUNode{}
376+
err = k8sClient.Get(ctx, client.ObjectKeyFromObject(gpuNode), updatedNode)
377+
assert.NoError(t, err, "Should be able to get updated node")
378+
379+
// Verify all fields were updated correctly
380+
assert.Equal(t, update.totalTFlops, updatedNode.Status.TotalTFlops)
381+
assert.Equal(t, update.totalVRAM, updatedNode.Status.TotalVRAM)
382+
assert.Equal(t, update.count, updatedNode.Status.TotalGPUs)
383+
assert.Equal(t, update.count, updatedNode.Status.ManagedGPUs)
384+
assert.Equal(t, update.allDeviceIDs, updatedNode.Status.ManagedGPUDeviceIDs)
385+
386+
// The node starts with an empty phase, so the first update sets it to pending
387+
if i == 0 {
388+
assert.Equal(t, tfv1.TensorFusionGPUNodePhasePending, updatedNode.Status.Phase)
389+
} else {
390+
// Should remain pending on subsequent updates
391+
assert.Equal(t, tfv1.TensorFusionGPUNodePhasePending, updatedNode.Status.Phase)
392+
}
393+
394+
// NodeInfo should be updated with system values
395+
assert.True(t, updatedNode.Status.NodeInfo.RAMSize.Value() > 0)
396+
assert.True(t, updatedNode.Status.NodeInfo.DataDiskSize.Value() > 0)
397+
})
398+
}
399+
}

0 commit comments

Comments
 (0)