Skip to content

Commit c465980

Browse files
authored
feat: create worker pod using GenerateName (#65)
1 parent 56a1d17 commit c465980

File tree

3 files changed

+45
-58
lines changed

3 files changed

+45
-58
lines changed

internal/constants/constants.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,12 @@ const (
4444
ConnectionNameEnv = "TENSOR_FUSION_CONNECTION_NAME"
4545
ConnectionNamespaceEnv = "TENSOR_FUSION_CONNECTION_NAMESPACE"
4646

47-
WorkerPortEnv = "TENSOR_FUSION_WORKER_PORT"
48-
WokerCudaUpLimitEnv = "TENSOR_FUSION_CUDA_UP_LIMIT"
49-
WokerCudaMemLimitEnv = "TENSOR_FUSION_CUDA_MEM_LIMIT"
50-
NamespaceEnv = "OPERATOR_NAMESPACE"
51-
NamespaceDefaultVal = "tensor-fusion"
47+
WorkerPortEnv = "TENSOR_FUSION_WORKER_PORT"
48+
WorkerCudaUpLimitEnv = "TENSOR_FUSION_CUDA_UP_LIMIT"
49+
WorkerCudaMemLimitEnv = "TENSOR_FUSION_CUDA_MEM_LIMIT"
50+
WorkerPodNameEnv = "POD_NAME"
51+
NamespaceEnv = "OPERATOR_NAMESPACE"
52+
NamespaceDefaultVal = "tensor-fusion"
5253
)
5354

5455
const (

internal/controller/tensorfusionworkload_controller.go

Lines changed: 22 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import (
2424
corev1 "k8s.io/api/core/v1"
2525
"k8s.io/apimachinery/pkg/api/errors"
2626
"k8s.io/apimachinery/pkg/runtime"
27-
"k8s.io/apimachinery/pkg/types"
2827
"k8s.io/client-go/tools/record"
2928
ctrl "sigs.k8s.io/controller-runtime"
3029
"sigs.k8s.io/controller-runtime/pkg/client"
@@ -129,7 +128,7 @@ func (r *TensorFusionWorkloadReconciler) Reconcile(ctx context.Context, req ctrl
129128

130129
// Calculate how many pods need to be added
131130
podsToAdd := int(desiredReplicas - currentReplicas)
132-
if err := r.scaleUpWorkers(ctx, workerGenerator, workload, podsToAdd, req.Namespace); err != nil {
131+
if err := r.scaleUpWorkers(ctx, workerGenerator, workload, podsToAdd); err != nil {
133132
return ctrl.Result{}, err
134133
}
135134
} else if currentReplicas > desiredReplicas {
@@ -159,38 +158,28 @@ func (r *TensorFusionWorkloadReconciler) tryStartWorker(
159158
workerGenerator *worker.WorkerGenerator,
160159
gpu *tfv1.GPU,
161160
workload *tfv1.TensorFusionWorkload,
162-
namespacedName types.NamespacedName,
163161
) (*corev1.Pod, error) {
164-
// Try to get the Pod
165-
pod := &corev1.Pod{}
166-
if err := r.Get(ctx, namespacedName, pod); err != nil {
167-
if errors.IsNotFound(err) {
168-
// Pod doesn't exist, create a new one
169-
port := workerGenerator.AllocPort()
170-
pod, err = workerGenerator.GenerateWorkerPod(gpu, namespacedName, port, workload.Spec.Resources.Limits)
171-
if err != nil {
172-
return nil, fmt.Errorf("generate worker pod %w", err)
173-
}
162+
port := workerGenerator.AllocPort()
163+
pod, err := workerGenerator.GenerateWorkerPod(gpu, workload.Name, workload.Namespace, port, workload.Spec.Resources.Limits)
164+
if err != nil {
165+
return nil, fmt.Errorf("generate worker pod %w", err)
166+
}
174167

175-
// Add labels to identify this pod as part of the workload
176-
if pod.Labels == nil {
177-
pod.Labels = make(map[string]string)
178-
}
179-
pod.Labels[constants.WorkloadKey] = workload.Name
180-
pod.Labels[constants.GpuKey] = gpu.Name
168+
// Add labels to identify this pod as part of the workload
169+
if pod.Labels == nil {
170+
pod.Labels = make(map[string]string)
171+
}
172+
pod.Labels[constants.WorkloadKey] = workload.Name
173+
pod.Labels[constants.GpuKey] = gpu.Name
181174

182-
// Add finalizer for GPU resource cleanup
183-
pod.Finalizers = append(pod.Finalizers, constants.Finalizer)
175+
// Add finalizer for GPU resource cleanup
176+
pod.Finalizers = append(pod.Finalizers, constants.Finalizer)
184177

185-
if err := ctrl.SetControllerReference(workload, pod, r.Scheme); err != nil {
186-
return nil, fmt.Errorf("set owner reference %w", err)
187-
}
188-
if err := r.Create(ctx, pod); err != nil {
189-
return nil, fmt.Errorf("create pod %w", err)
190-
}
191-
return pod, nil
192-
}
193-
return nil, err
178+
if err := ctrl.SetControllerReference(workload, pod, r.Scheme); err != nil {
179+
return nil, fmt.Errorf("set owner reference %w", err)
180+
}
181+
if err := r.Create(ctx, pod); err != nil {
182+
return nil, fmt.Errorf("create pod %w", err)
194183
}
195184
return pod, nil
196185
}
@@ -270,27 +259,19 @@ func (r *TensorFusionWorkloadReconciler) deletePod(ctx context.Context, pod *cor
270259
}
271260

272261
// scaleUpWorkers handles the scaling up of worker pods
273-
func (r *TensorFusionWorkloadReconciler) scaleUpWorkers(ctx context.Context, workerGenerator *worker.WorkerGenerator, workload *tfv1.TensorFusionWorkload, count int, namespace string) error {
262+
func (r *TensorFusionWorkloadReconciler) scaleUpWorkers(ctx context.Context, workerGenerator *worker.WorkerGenerator, workload *tfv1.TensorFusionWorkload, count int) error {
274263
log := log.FromContext(ctx)
275264

276265
// Create worker pods
277-
currentCount := int(workload.Status.Replicas)
278-
for i := range count {
266+
for range count {
279267
// Schedule GPU for the worker
280268
gpu, err := r.Scheduler.Schedule(ctx, workload.Spec.PoolName, workload.Spec.Resources.Requests)
281269
if err != nil {
282270
r.Recorder.Eventf(workload, corev1.EventTypeWarning, "ScheduleGPUFailed", "Failed to schedule GPU: %v", err)
283271
return fmt.Errorf("schedule GPU: %w", err)
284272
}
285273

286-
// Create worker pod
287-
workerName := fmt.Sprintf("%s-worker-%d", workload.Name, currentCount+i)
288-
namespacedName := types.NamespacedName{
289-
Namespace: namespace,
290-
Name: workerName,
291-
}
292-
293-
_, err = r.tryStartWorker(ctx, workerGenerator, gpu, workload, namespacedName)
274+
_, err = r.tryStartWorker(ctx, workerGenerator, gpu, workload)
294275
if err != nil {
295276
// Try to release the GPU resource if pod creation fails
296277
releaseErr := r.Scheduler.Release(ctx, workload.Spec.Resources.Requests, gpu)
@@ -299,8 +280,6 @@ func (r *TensorFusionWorkloadReconciler) scaleUpWorkers(ctx context.Context, wor
299280
}
300281
return fmt.Errorf("create worker pod: %w", err)
301282
}
302-
303-
log.Info("Created worker pod", "name", workerName)
304283
}
305284

306285
return nil

internal/worker/worker.go

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7-
"path"
87
"strconv"
98
"time"
109

@@ -14,7 +13,6 @@ import (
1413
"golang.org/x/exp/rand"
1514
corev1 "k8s.io/api/core/v1"
1615
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
17-
"k8s.io/apimachinery/pkg/types"
1816
"sigs.k8s.io/controller-runtime/pkg/client"
1917
)
2018

@@ -46,7 +44,8 @@ func (wg *WorkerGenerator) AllocPort() int {
4644

4745
func (wg *WorkerGenerator) GenerateWorkerPod(
4846
gpu *tfv1.GPU,
49-
namespacedName types.NamespacedName,
47+
generateName string,
48+
namespace string,
5049
port int,
5150
limits tfv1.Resource,
5251
) (*corev1.Pod, error) {
@@ -64,13 +63,14 @@ func (wg *WorkerGenerator) GenerateWorkerPod(
6463
Name: constants.DataVolumeName,
6564
VolumeSource: corev1.VolumeSource{
6665
HostPath: &corev1.HostPathVolumeSource{
67-
Path: path.Join(constants.TFDataPath, namespacedName.Name),
66+
Path: constants.TFDataPath,
6867
},
6968
},
7069
})
7170
spec.Containers[0].VolumeMounts = append(spec.Containers[0].VolumeMounts, corev1.VolumeMount{
72-
Name: constants.DataVolumeName,
73-
MountPath: constants.TFDataPath,
71+
Name: constants.DataVolumeName,
72+
MountPath: constants.TFDataPath,
73+
SubPathExpr: fmt.Sprintf("${%s}", constants.WorkerPodNameEnv),
7474
})
7575

7676
spec.Containers[0].Env = append(spec.Containers[0].Env, corev1.EnvVar{
@@ -80,19 +80,26 @@ func (wg *WorkerGenerator) GenerateWorkerPod(
8080
Name: constants.WorkerPortEnv,
8181
Value: strconv.Itoa(port),
8282
}, corev1.EnvVar{
83-
Name: constants.WokerCudaUpLimitEnv,
83+
Name: constants.WorkerCudaUpLimitEnv,
8484
// TODO: convert tflops to percent
8585
Value: "100",
8686
}, corev1.EnvVar{
87-
Name: constants.WokerCudaMemLimitEnv,
87+
Name: constants.WorkerCudaMemLimitEnv,
8888
// bytesize
8989
Value: strconv.FormatInt(limits.Vram.Value(), 10),
90+
}, corev1.EnvVar{
91+
Name: constants.WorkerPodNameEnv,
92+
ValueFrom: &corev1.EnvVarSource{
93+
FieldRef: &corev1.ObjectFieldSelector{
94+
FieldPath: "metadata.name",
95+
},
96+
},
9097
})
9198

9299
return &corev1.Pod{
93100
ObjectMeta: metav1.ObjectMeta{
94-
Name: namespacedName.Name,
95-
Namespace: namespacedName.Namespace,
101+
GenerateName: generateName,
102+
Namespace: namespace,
96103
},
97104
Spec: spec,
98105
}, nil

0 commit comments

Comments
 (0)