Skip to content

Commit 9794292

Browse files
authored
feat: add GPU node selector and UUID support (#23)
- Add UUID and NodeSelector fields to GPUStatus
- Update worker pod generation to use GPU node selector
- Add NVIDIA_VISIBLE_DEVICES env var with GPU UUID
- Simplify connection URL generation by removing unused GPU parameter
1 parent bc0e54b commit 9794292

File tree

5 files changed

+38
-8
lines changed

5 files changed

+38
-8
lines changed

api/v1/gpu_types.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ import (
2222

2323
// GPUStatus defines the observed state of GPU.
2424
type GPUStatus struct {
25-
Capacity Resource `json:"capacity"`
26-
Available Resource `json:"available"`
25+
// UUID identifies the GPU device; the worker generator in this commit
// passes it to the worker container as NVIDIA_VISIBLE_DEVICES.
// NOTE(review): no `omitempty`, and the generated CRD lists `uuid` as
// required — GPU objects created before this change may fail status
// validation until the field is populated; confirm migration path.
UUID string `json:"uuid"`
26+
// NodeSelector is copied onto the worker pod spec's NodeSelector by the
// worker generator in this commit.
// NOTE(review): also required in the generated CRD (no `omitempty`).
NodeSelector map[string]string `json:"nodeSelector"`
27+
// Capacity and Available are unchanged by this commit; only their
// position in the struct moved.
Capacity Resource `json:"capacity"`
28+
Available Resource `json:"available"`
2729
}
2830

2931
// +kubebuilder:object:root=true

api/v1/zz_generated.deepcopy.go

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

config/crd/bases/tensor-fusion.ai_gpus.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,17 @@ spec:
7575
- tflops
7676
- vram
7777
type: object
78+
nodeSelector:
79+
additionalProperties:
80+
type: string
81+
type: object
82+
uuid:
83+
type: string
7884
required:
7985
- available
8086
- capacity
87+
- nodeSelector
88+
- uuid
8189
type: object
8290
type: object
8391
served: true

internal/controller/tensorfusionconnection_controller.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,15 @@ func (r *TensorFusionConnectionReconciler) Reconcile(ctx context.Context, req ct
118118

119119
if connection.Status.Phase != tfv1.TensorFusionConnectionPending {
120120
// Start worker job
121-
workerPod, err := r.tryStartWorker(ctx, connection, types.NamespacedName{Name: connection.Name, Namespace: connection.Namespace})
121+
workerPod, err := r.tryStartWorker(ctx, gpu, connection, types.NamespacedName{Name: connection.Name, Namespace: connection.Namespace})
122122
if err != nil {
123123
log.Error(err, "Failed to start worker pod")
124124
return ctrl.Result{}, err
125125
}
126126

127127
if workerPod.Status.Phase == corev1.PodRunning {
128128
connection.Status.Phase = tfv1.TensorFusionConnectionRunning
129-
connection.Status.ConnectionURL = r.WorkerGenerator.GenerateConnectionURL(gpu, connection, workerPod)
129+
connection.Status.ConnectionURL = r.WorkerGenerator.GenerateConnectionURL(connection, workerPod)
130130
}
131131
// TODO: Handle PodFailure
132132
}
@@ -143,13 +143,13 @@ func (r *TensorFusionConnectionReconciler) Reconcile(ctx context.Context, req ct
143143
return ctrl.Result{}, nil
144144
}
145145

146-
func (r *TensorFusionConnectionReconciler) tryStartWorker(ctx context.Context, connection *tfv1.TensorFusionConnection, namespacedName types.NamespacedName) (*corev1.Pod, error) {
146+
func (r *TensorFusionConnectionReconciler) tryStartWorker(ctx context.Context, gpu *tfv1.GPU, connection *tfv1.TensorFusionConnection, namespacedName types.NamespacedName) (*corev1.Pod, error) {
147147
// Try to get the Pod
148148
pod := &corev1.Pod{}
149149
if err := r.Get(ctx, namespacedName, pod); err != nil {
150150
if errors.IsNotFound(err) {
151151
// Pod doesn't exist, create a new one
152-
pod = r.WorkerGenerator.GenerateWorkerPod(connection, namespacedName)
152+
pod = r.WorkerGenerator.GenerateWorkerPod(gpu, connection, namespacedName)
153153
if err := ctrl.SetControllerReference(connection, pod, r.Scheme); err != nil {
154154
return nil, fmt.Errorf("set owner reference %w", err)
155155
}

internal/worker/worker.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,32 @@ type WorkerGenerator struct {
1414
WorkerConfig *config.Worker
1515
}
1616

17-
func (wg *WorkerGenerator) GenerateConnectionURL(_gpu *tfv1.GPU, connection *tfv1.TensorFusionConnection, pod *corev1.Pod) string {
17+
// GenerateConnectionURL formats the worker endpoint as
// "native+<pod IP>+<port>", taking the IP from pod.Status.PodIP and the port
// from wg.WorkerConfig.Port.
//
// The connection argument is not read by the body. This commit drops the
// previously unused GPU parameter from the signature.
func (wg *WorkerGenerator) GenerateConnectionURL(connection *tfv1.TensorFusionConnection, pod *corev1.Pod) string {
1818
return fmt.Sprintf("native+%s+%d", pod.Status.PodIP, wg.WorkerConfig.Port)
1919
}
2020

2121
// GenerateWorkerPod builds the worker Pod for a connection. The pod takes its
// name and namespace from namespacedName and its spec from
// wg.WorkerConfig.Template.Spec, with the node selector replaced by
// gpu.Status.NodeSelector and an NVIDIA_VISIBLE_DEVICES env var (set to
// gpu.Status.UUID) appended to the first container.
//
// The connection argument is not read by the body.
func (wg *WorkerGenerator) GenerateWorkerPod(
22+
gpu *tfv1.GPU,
2223
connection *tfv1.TensorFusionConnection,
2324
namespacedName types.NamespacedName,
2425
) *corev1.Pod {
26+
27+
// NOTE(review): this is a shallow struct copy. The Containers slice in
// `spec` still shares its backing array with the template held by
// wg.WorkerConfig, so element writes below reach the shared template.
// Consider a deep copy of the template spec before mutating it.
spec := wg.WorkerConfig.Template.Spec
28+
// NOTE(review): the map created here is overwritten unconditionally two
// statements later, so this nil guard has no effect.
if spec.NodeSelector == nil {
29+
spec.NodeSelector = make(map[string]string)
30+
}
31+
// Replaces (does not merge with) any node selector set in the template.
// NOTE(review): the pod spec now aliases the GPU status map — confirm
// whether merging into the template's selector was the intent.
spec.NodeSelector = gpu.Status.NodeSelector
32+
33+
// NOTE(review): indexes Containers[0] without a length check; a template
// with no containers panics here. Because the container element is shared
// with the template (see the copy above), the appended Env slice is stored
// back into the template's first container, so successive calls can
// accumulate NVIDIA_VISIBLE_DEVICES entries from earlier GPUs.
spec.Containers[0].Env = append(spec.Containers[0].Env, corev1.EnvVar{
34+
Name: "NVIDIA_VISIBLE_DEVICES",
35+
Value: gpu.Status.UUID,
36+
})
37+
2538
return &corev1.Pod{
2639
ObjectMeta: metav1.ObjectMeta{
2740
Name: namespacedName.Name,
2841
Namespace: namespacedName.Namespace,
2942
},
30-
Spec: wg.WorkerConfig.Template.Spec,
43+
Spec: spec,
3144
}
3245
}

0 commit comments

Comments
 (0)