
Commit 9591b78

refactor: optimize the order of Pods when scaling down (#324)
1 parent 649c72d commit 9591b78

File tree

1 file changed: +36 -33 lines changed


internal/controller/tensorfusionworkload_controller.go

Lines changed: 36 additions & 33 deletions
@@ -26,6 +26,7 @@ import (
     metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
     "k8s.io/apimachinery/pkg/runtime"
     "k8s.io/client-go/tools/record"
+    "k8s.io/kubernetes/pkg/controller"
     ctrl "sigs.k8s.io/controller-runtime"
     "sigs.k8s.io/controller-runtime/pkg/client"
     "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
@@ -38,7 +39,6 @@ import (
     "github.com/NexusGPU/tensor-fusion/internal/portallocator"
     "github.com/NexusGPU/tensor-fusion/internal/utils"
     "github.com/NexusGPU/tensor-fusion/internal/worker"
-    "github.com/samber/lo"
 )

 // TensorFusionWorkloadReconciler reconciles a TensorFusionWorkload object
@@ -78,9 +78,7 @@ func (r *TensorFusionWorkloadReconciler) Reconcile(ctx context.Context, req ctrl
         return ctrl.Result{}, fmt.Errorf("list pods: %w", err)
     }
     // only calculate state based on not deleted pods, otherwise will cause wrong total replica count
-    podList.Items = lo.Filter(podList.Items, func(pod corev1.Pod, _ int) bool {
-        return pod.DeletionTimestamp.IsZero()
-    })
+    activePods := filterActivePods(podList)

     // handle finalizer
     shouldReturn, err := utils.HandleFinalizer(ctx, workload, r.Client, func(ctx context.Context, workload *tfv1.TensorFusionWorkload) (bool, error) {
@@ -91,10 +89,10 @@ func (r *TensorFusionWorkloadReconciler) Reconcile(ctx context.Context, req ctrl

         // fixed replica mode which created by user, should trigger pod deletion and stop scale up
         // when all pods are deleted, finalizer will be removed
-        if len(podList.Items) == 0 {
+        if len(activePods) == 0 {
             return true, nil
         }
-        if err := r.scaleDownWorkers(ctx, workload, podList.Items); err != nil {
+        if err := r.scaleDownWorkers(ctx, workload, activePods); err != nil {
             return false, err
         }
         return false, nil
@@ -140,13 +138,13 @@ func (r *TensorFusionWorkloadReconciler) Reconcile(ctx context.Context, req ctrl
     // In this mode, allow any Pod select connection to connect to any worker,
     // to achieve a sub-pool for lower costs when CPU side scaling frequency is high
     if !workload.Spec.IsDynamicReplica() {
-        err := r.reconcileScaling(ctx, workload, podList, workerGenerator, podTemplateHash)
+        err := r.reconcileScaling(ctx, workload, activePods, workerGenerator, podTemplateHash)
         if err != nil {
             return ctrl.Result{}, err
         }
     }

-    if err := r.updateStatus(ctx, workload, podList.Items); err != nil {
+    if err := r.updateStatus(ctx, workload, activePods); err != nil {
         return ctrl.Result{}, err
     }

@@ -157,23 +155,22 @@ func (r *TensorFusionWorkloadReconciler) Reconcile(ctx context.Context, req ctrl
 func (r *TensorFusionWorkloadReconciler) reconcileScaling(
     ctx context.Context,
     workload *tfv1.TensorFusionWorkload,
-    podList *corev1.PodList,
+    activePods []*corev1.Pod,
     workerGenerator *worker.WorkerGenerator,
     podTemplateHash string,
 ) error {
     log := log.FromContext(ctx)
     // Check if there are any Pods using the old podTemplateHash and delete them if any
-    if len(podList.Items) > 0 {
+    if len(activePods) > 0 {
         // make oldest pod first, to delete from oldest to latest outdated pod
-        sort.Slice(podList.Items, func(i, j int) bool {
-            return podList.Items[i].CreationTimestamp.Before(&podList.Items[j].CreationTimestamp)
+        sort.Slice(activePods, func(i, j int) bool {
+            return activePods[i].CreationTimestamp.Before(&activePods[j].CreationTimestamp)
         })

-        var outdatedPods []corev1.Pod
-        for i := range podList.Items {
-            pod := &podList.Items[i]
+        var outdatedPods []*corev1.Pod
+        for _, pod := range activePods {
             if pod.Labels[constants.LabelKeyPodTemplateHash] != podTemplateHash {
-                outdatedPods = append(outdatedPods, *pod)
+                outdatedPods = append(outdatedPods, pod)
             }
         }

@@ -194,7 +191,7 @@ func (r *TensorFusionWorkloadReconciler) reconcileScaling(
     }

     // Count current replicas
-    currentReplicas := int32(len(podList.Items))
+    currentReplicas := int32(len(activePods))
     log.Info("Current replicas", "count", currentReplicas, "desired", desiredReplicas)

     // Update workload status
@@ -205,26 +202,23 @@ func (r *TensorFusionWorkloadReconciler) reconcileScaling(
         }
     }

+    diff := currentReplicas - desiredReplicas
     // Scale up if needed
-    if currentReplicas < desiredReplicas {
+    if diff < 0 {
         log.Info("Scaling up workers", "from", currentReplicas, "to", desiredReplicas)

-        // Calculate how many pods need to be added
-        podsToAdd := int(desiredReplicas - currentReplicas)
-        if err := r.scaleUpWorkers(ctx, workerGenerator, workload, podsToAdd, podTemplateHash); err != nil {
+        if err := r.scaleUpWorkers(ctx, workerGenerator, workload, int(-diff), podTemplateHash); err != nil {
             return fmt.Errorf("scale up workers: %w", err)
         }
-    } else if currentReplicas > desiredReplicas {
+    } else if diff > 0 {
         log.Info("Scaling down workers", "from", currentReplicas, "to", desiredReplicas)

-        // Sort pods by creation time (oldest first)
-        sort.Slice(podList.Items, func(i, j int) bool {
-            return podList.Items[i].CreationTimestamp.Before(&podList.Items[j].CreationTimestamp)
-        })
+        // No need to sort if we are about to delete all pods
+        if diff < int32(len(activePods)) {
+            sort.Sort(controller.ActivePods(activePods))
+        }

-        // Calculate how many pods need to be removed
-        podsToRemove := int(currentReplicas - desiredReplicas)
-        if err := r.scaleDownWorkers(ctx, workload, podList.Items[:podsToRemove]); err != nil {
+        if err := r.scaleDownWorkers(ctx, workload, activePods[:diff]); err != nil {
             return err
         }
     }
@@ -259,10 +253,9 @@ func (r *TensorFusionWorkloadReconciler) tryStartWorker(
 }

 // scaleDownWorkers handles the scaling down of worker pods
-func (r *TensorFusionWorkloadReconciler) scaleDownWorkers(ctx context.Context, workload *tfv1.TensorFusionWorkload, pods []corev1.Pod) error {
+func (r *TensorFusionWorkloadReconciler) scaleDownWorkers(ctx context.Context, workload *tfv1.TensorFusionWorkload, pods []*corev1.Pod) error {
     log := log.FromContext(ctx)
-    for i := range pods {
-        podToDelete := &pods[i]
+    for _, podToDelete := range pods {
         log.Info("Scaling down worker pod", "name", podToDelete.Name, "workload", workload.Name)

         // If it's already being deleting, should avoid call delete multiple times
@@ -316,7 +309,7 @@ func (r *TensorFusionWorkloadReconciler) scaleUpWorkers(ctx context.Context, wor
 func (r *TensorFusionWorkloadReconciler) updateStatus(
     ctx context.Context,
     workload *tfv1.TensorFusionWorkload,
-    pods []corev1.Pod,
+    pods []*corev1.Pod,
 ) error {
     log := log.FromContext(ctx)
     readyReplicas := int32(0)
@@ -396,6 +389,16 @@ func (r *TensorFusionWorkloadReconciler) updateStatus(
     return nil
 }

+func filterActivePods(podList *corev1.PodList) []*corev1.Pod {
+    var activePods []*corev1.Pod
+    for _, pod := range podList.Items {
+        if pod.DeletionTimestamp.IsZero() {
+            activePods = append(activePods, &pod)
+        }
+    }
+    return activePods
+}
+
 // SetupWithManager sets up the controller with the Manager.
 func (r *TensorFusionWorkloadReconciler) SetupWithManager(mgr ctrl.Manager) error {
     return ctrl.NewControllerManagedBy(mgr).
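
Why the order changed: the old code always deleted the oldest pods first, while the new code ranks scale-down candidates with controller.ActivePods from k8s.io/kubernetes/pkg/controller, the sort order Kubernetes' own workload controllers use when picking deletion victims, so unscheduled, pending, not-ready, and newer pods are removed before healthy long-running ones. Below is a minimal sketch of that ordering; it is an illustration only, the pod names and statuses are invented, and it assumes a module that already resolves k8s.io/kubernetes (as this controller does once the import above is in place).

package main

import (
    "fmt"
    "sort"
    "time"

    corev1 "k8s.io/api/core/v1"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    "k8s.io/kubernetes/pkg/controller"
)

func main() {
    now := metav1.Now()
    older := metav1.NewTime(now.Add(-time.Hour))

    pods := []*corev1.Pod{
        {
            // Scheduled, running pod created an hour ago.
            ObjectMeta: metav1.ObjectMeta{Name: "running-old", CreationTimestamp: older},
            Spec:       corev1.PodSpec{NodeName: "node-a"},
            Status:     corev1.PodStatus{Phase: corev1.PodRunning},
        },
        {
            // Unscheduled, still-pending pod created just now.
            ObjectMeta: metav1.ObjectMeta{Name: "pending-new", CreationTimestamp: now},
            Status:     corev1.PodStatus{Phase: corev1.PodPending},
        },
    }

    // ActivePods sorts "less useful" pods to the front: unscheduled before
    // scheduled, pending/unknown before running, not ready before ready,
    // higher restart counts and newer creation times before the rest.
    // scaleDownWorkers then deletes from the front of the slice.
    sort.Sort(controller.ActivePods(pods))

    for _, p := range pods {
        fmt.Println(p.Name) // prints "pending-new", then "running-old"
    }
}

The diff < int32(len(activePods)) guard in the hunk above simply skips the sort when every active pod is about to be deleted anyway, since the order no longer matters in that case.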
