Commit f51c471

refactor(gpuallocator): use parameter struct for Alloc function (#227)
1 parent 3df147a commit f51c471
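
In short, the commit folds Alloc's allocation parameters (pool name, workload name/namespace, resource request, GPU count, GPU model) into a single AllocRequest struct, leaving only the context as a positional argument. The signatures before and after, as they appear in the diff below:

// before
func (s *GpuAllocator) Alloc(ctx context.Context, poolName string, workloadNameNamespace tfv1.NameNamespace, request tfv1.Resource, count uint, gpuModel string) ([]*tfv1.GPU, error)

// after
func (s *GpuAllocator) Alloc(ctx context.Context, req AllocRequest) ([]*tfv1.GPU, error)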

File tree

3 files changed (+44, -30 lines):
internal/controller/tensorfusionworkload_controller.go
internal/gpuallocator/gpuallocator.go
internal/gpuallocator/gpuallocator_test.go

internal/controller/tensorfusionworkload_controller.go

Lines changed: 9 additions & 8 deletions
@@ -357,16 +357,11 @@ func (r *TensorFusionWorkloadReconciler) handlePodGPUCleanup(ctx context.Context
         return types.NamespacedName{Name: gpuName}
     })
     // Release GPU resources
-    if err := r.Allocator.Dealloc(ctx,
-        tfv1.NameNamespace{Namespace: workload.Namespace, Name: workload.Name},
-        workload.Spec.Resources.Requests, gpus); err != nil {
+    if err := r.Allocator.Dealloc(ctx, tfv1.NameNamespace{Name: workload.Name, Namespace: workload.Namespace}, workload.Spec.Resources.Requests, gpus); err != nil {
         log.Error(err, "Failed to release GPU resources, will retry", "gpus", gpus, "pod", pod.Name)
         return false, err
     }
     log.Info("Released GPU resources via finalizer", "gpus", gpus, "pod", pod.Name)
-    if pod.Annotations == nil {
-        pod.Annotations = make(map[string]string)
-    }
 
     return true, nil
 }
@@ -391,7 +386,13 @@ func (r *TensorFusionWorkloadReconciler) scaleUpWorkers(ctx context.Context, wor
     // Create worker pods
     for range count {
         // Schedule GPU for the worker
-        gpus, err := r.Allocator.Alloc(ctx, workload.Spec.PoolName, workloadNameNs, workload.Spec.Resources.Requests, workload.Spec.GPUCount, workload.Spec.GPUModel)
+        gpus, err := r.Allocator.Alloc(ctx, gpuallocator.AllocRequest{
+            PoolName:              workload.Spec.PoolName,
+            WorkloadNameNamespace: workloadNameNs,
+            Request:               workload.Spec.Resources.Requests,
+            Count:                 workload.Spec.GPUCount,
+            GPUModel:              workload.Spec.GPUModel,
+        })
         if err != nil {
             r.Recorder.Eventf(workload, corev1.EventTypeWarning, "ScheduleGPUFailed", "Failed to schedule GPU: %v", err)
             return ctrl.Result{RequeueAfter: constants.PendingRequeueDuration}, nil
@@ -526,7 +527,7 @@ func (r *TensorFusionWorkloadReconciler) updateStatus(
 func (r *TensorFusionWorkloadReconciler) SetupWithManager(mgr ctrl.Manager) error {
     return ctrl.NewControllerManagedBy(mgr).
         For(&tfv1.TensorFusionWorkload{}).
-        Named("tensorfusionworkload").
         Owns(&corev1.Pod{}).
+        Named("tensorfusionworkload").
         Complete(r)
 }

internal/gpuallocator/gpuallocator.go

Lines changed: 28 additions & 21 deletions
@@ -53,28 +53,35 @@ type GpuAllocator struct {
     dirtyQueueLock sync.Mutex
 }
 
+// AllocRequest encapsulates all parameters needed for GPU allocation
+type AllocRequest struct {
+    // Name of the GPU pool to allocate from
+    PoolName string
+    // Namespace information for the workload
+    WorkloadNameNamespace tfv1.NameNamespace
+    // Resource requirements for the allocation
+    Request tfv1.Resource
+    // Number of GPUs to allocate
+    Count uint
+    // Specific GPU model to allocate, empty string means any model
+    GPUModel string
+}
+
 // Alloc allocates a request to a gpu or multiple gpus from the same node.
-func (s *GpuAllocator) Alloc(
-    ctx context.Context,
-    poolName string,
-    workloadNameNamespace tfv1.NameNamespace,
-    request tfv1.Resource,
-    count uint,
-    gpuModel string,
-) ([]*tfv1.GPU, error) {
+func (s *GpuAllocator) Alloc(ctx context.Context, req AllocRequest) ([]*tfv1.GPU, error) {
     // Get GPUs from the pool using the in-memory store
-    poolGPUs := s.listGPUsFromPool(poolName)
+    poolGPUs := s.listGPUsFromPool(req.PoolName)
 
     // Add SameNodeFilter if count > 1 to ensure GPUs are from the same node
-    filterRegistry := s.filterRegistry.With(filter.NewResourceFilter(request))
+    filterRegistry := s.filterRegistry.With(filter.NewResourceFilter(req.Request))
 
     // Add GPU model filter if specified
-    if gpuModel != "" {
-        filterRegistry = filterRegistry.With(filter.NewGPUModelFilter(gpuModel))
+    if req.GPUModel != "" {
+        filterRegistry = filterRegistry.With(filter.NewGPUModelFilter(req.GPUModel))
     }
 
-    if count > 1 {
-        filterRegistry = filterRegistry.With(filter.NewSameNodeFilter(count))
+    if req.Count > 1 {
+        filterRegistry = filterRegistry.With(filter.NewSameNodeFilter(req.Count))
     }
 
     // Apply the filters in sequence
@@ -84,12 +91,12 @@ func (s *GpuAllocator) Alloc(
     }
 
     if len(filteredGPUs) == 0 {
-        return nil, fmt.Errorf("no gpus available in pool %s after filtering", poolName)
+        return nil, fmt.Errorf("no gpus available in pool %s after filtering", req.PoolName)
     }
 
     pool := &tfv1.GPUPool{}
-    if err := s.Get(ctx, client.ObjectKey{Name: poolName}, pool); err != nil {
-        return nil, fmt.Errorf("get pool %s: %w", poolName, err)
+    if err := s.Get(ctx, client.ObjectKey{Name: req.PoolName}, pool); err != nil {
+        return nil, fmt.Errorf("get pool %s: %w", req.PoolName, err)
     }
 
     schedulingConfigTemplate := &tfv1.SchedulingConfigTemplate{}
@@ -100,7 +107,7 @@ func (s *GpuAllocator) Alloc(
     }
 
     strategy := NewStrategy(schedulingConfigTemplate.Spec.Placement.Mode)
-    selectedGPUs, err := strategy.SelectGPUs(filteredGPUs, count)
+    selectedGPUs, err := strategy.SelectGPUs(filteredGPUs, req.Count)
     if err != nil {
         return nil, fmt.Errorf("select GPU: %w", err)
     }
@@ -121,11 +128,11 @@ func (s *GpuAllocator) Alloc(
         }
 
         // reduce available resource on the GPU status
-        gpu.Status.Available.Tflops.Sub(request.Tflops)
-        gpu.Status.Available.Vram.Sub(request.Vram)
+        gpu.Status.Available.Tflops.Sub(req.Request.Tflops)
+        gpu.Status.Available.Vram.Sub(req.Request.Vram)
 
         if !appAdded {
-            addRunningApp(ctx, gpu, workloadNameNamespace)
+            addRunningApp(ctx, gpu, req.WorkloadNameNamespace)
             appAdded = true
         }
 
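
A minimal caller sketch of the new API (not part of the commit): it assumes the controller's tfv1 and gpuallocator imports, an already-initialized *gpuallocator.GpuAllocator, and that tfv1.Resource carries Tflops/Vram quantities as the diff implies; the pool and workload names below are made up.

// scheduleDemo is a hypothetical helper illustrating the parameter-struct call.
func scheduleDemo(ctx context.Context, allocator *gpuallocator.GpuAllocator) ([]*tfv1.GPU, error) {
    return allocator.Alloc(ctx, gpuallocator.AllocRequest{
        PoolName:              "shared-pool", // hypothetical pool name
        WorkloadNameNamespace: tfv1.NameNamespace{Name: "demo-workload", Namespace: "default"},
        Request: tfv1.Resource{
            Tflops: resource.MustParse("10"),   // resource = k8s.io/apimachinery/pkg/api/resource
            Vram:   resource.MustParse("16Gi"),
        },
        Count:    2,  // >1 makes Alloc add the same-node filter
        GPUModel: "", // empty string accepts any GPU model
    })
}

Because the fields are named, further allocation options can later be added to AllocRequest without touching every call site, which is the usual motivation for this kind of refactor.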

internal/gpuallocator/gpuallocator_test.go

Lines changed: 7 additions & 1 deletion
@@ -36,7 +36,13 @@ var _ = Describe("GPU Allocator", func() {
     var allocator *GpuAllocator
 
     allocateAndSync := func(poolName string, request tfv1.Resource, count uint, gpuModel string) ([]*tfv1.GPU, error) {
-        gpus, err := allocator.Alloc(ctx, poolName, workloadNameNs, request, count, gpuModel)
+        gpus, err := allocator.Alloc(ctx, AllocRequest{
+            PoolName:              poolName,
+            WorkloadNameNamespace: workloadNameNs,
+            Request:               request,
+            Count:                 count,
+            GPUModel:              gpuModel,
+        })
         allocator.syncToK8s(ctx)
         return gpus, err
     }
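
A hypothetical spec sketch (not part of this commit) showing how the helper above might be exercised; the pool name and quantities are invented, and Expect/HaveLen assume the suite's Gomega matchers.

It("allocates one GPU of any model", func() {
    gpus, err := allocateAndSync("test-pool", tfv1.Resource{
        Tflops: resource.MustParse("5"),
        Vram:   resource.MustParse("8Gi"),
    }, 1, "")
    Expect(err).NotTo(HaveOccurred())
    Expect(gpus).To(HaveLen(1))
})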
