
Commit aaa252d

fix: calculate hash based on whole spec (#175)
* fix: calculate hash based on whole spec
* fix: pod template hash issue
1 parent 4aabf21 commit aaa252d
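In effect, the pod template hash that decides whether worker pods must be recreated is now derived from the entire workload spec rather than only its resource limits, and the reconciler computes it once and passes it down to pod generation. A self-contained sketch of the bug class this addresses follows; the spec fields and the hashing helper are illustrative assumptions, not the project's real types.

// Illustration only: hashing just the resource limits misses changes made
// elsewhere in the spec, so worker pods would never be rolled for them.
// The types and the hashOf helper are made up for this example.
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
)

type Resource struct {
	Tflops string `json:"tflops"`
	Vram   string `json:"vram"`
}

type Resources struct {
	Requests Resource `json:"requests"`
	Limits   Resource `json:"limits"`
}

type WorkloadSpec struct {
	Replicas  int       `json:"replicas"`
	PoolName  string    `json:"poolName"`
	Resources Resources `json:"resources"`
}

// hashOf hashes the JSON encoding of v, so any field change changes the hash.
func hashOf(v any) string {
	b, _ := json.Marshal(v)
	sum := sha256.Sum256(b)
	return hex.EncodeToString(sum[:8])
}

func main() {
	oldSpec := WorkloadSpec{Replicas: 1, PoolName: "pool-a",
		Resources: Resources{Limits: Resource{Tflops: "10", Vram: "8Gi"}}}
	newSpec := oldSpec
	newSpec.PoolName = "pool-b" // the spec changed, but the limits did not

	fmt.Println(hashOf(oldSpec.Resources.Limits) == hashOf(newSpec.Resources.Limits)) // true: change missed
	fmt.Println(hashOf(oldSpec) == hashOf(newSpec))                                   // false: change detected
}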

File tree

3 files changed: +13 -13 lines changed

internal/controller/tensorfusionworkload_controller.go
internal/webhook/v1/pod_webhook.go
internal/worker/worker.go

internal/controller/tensorfusionworkload_controller.go

Lines changed: 8 additions & 8 deletions
@@ -22,6 +22,7 @@ import (
 	"sort"
 
 	corev1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/api/equality"
 	"k8s.io/apimachinery/pkg/api/errors"
 	"k8s.io/apimachinery/pkg/runtime"
 	"k8s.io/client-go/tools/record"
@@ -31,8 +32,6 @@ import (
 
 	"slices"
 
-	"reflect"
-
 	tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
 	"github.com/NexusGPU/tensor-fusion/internal/config"
 	"github.com/NexusGPU/tensor-fusion/internal/constants"
@@ -132,7 +131,7 @@ func (r *TensorFusionWorkloadReconciler) Reconcile(ctx context.Context, req ctrl
 	// Create worker generator
 	workerGenerator := &worker.WorkerGenerator{WorkerConfig: pool.Spec.ComponentConfig.Worker, GpuInfos: r.GpuInfos}
 
-	podTemplateHash, err := workerGenerator.PodTemplateHash(workload.Spec.Resources.Limits)
+	podTemplateHash, err := workerGenerator.PodTemplateHash(workload.Spec)
 	if err != nil {
 		return ctrl.Result{}, fmt.Errorf("get pod template hash: %w", err)
 	}
@@ -188,7 +187,7 @@ func (r *TensorFusionWorkloadReconciler) Reconcile(ctx context.Context, req ctrl
 
 	// Calculate how many pods need to be added
 	podsToAdd := int(desiredReplicas - currentReplicas)
-	result, err := r.scaleUpWorkers(ctx, workerGenerator, workload, podsToAdd)
+	result, err := r.scaleUpWorkers(ctx, workerGenerator, workload, podsToAdd, podTemplateHash)
 	if err != nil {
 		return ctrl.Result{}, fmt.Errorf("scale up workers: %w", err)
 	}
@@ -222,9 +221,10 @@ func (r *TensorFusionWorkloadReconciler) tryStartWorker(
 	workerGenerator *worker.WorkerGenerator,
 	gpu *tfv1.GPU,
 	workload *tfv1.TensorFusionWorkload,
+	hash string,
 ) (*corev1.Pod, error) {
 	port := workerGenerator.AllocPort()
-	pod, hash, err := workerGenerator.GenerateWorkerPod(gpu, fmt.Sprintf("%s-tf-worker-", workload.Name), workload.Namespace, port, workload.Spec.Resources.Limits)
+	pod, hash, err := workerGenerator.GenerateWorkerPod(gpu, fmt.Sprintf("%s-tf-worker-", workload.Name), workload.Namespace, port, workload.Spec.Resources.Limits, hash)
 	if err != nil {
 		return nil, fmt.Errorf("generate worker pod %w", err)
 	}
@@ -334,7 +334,7 @@ func (r *TensorFusionWorkloadReconciler) deletePod(ctx context.Context, pod *cor
 	}
 
 // scaleUpWorkers handles the scaling up of worker pods
-func (r *TensorFusionWorkloadReconciler) scaleUpWorkers(ctx context.Context, workerGenerator *worker.WorkerGenerator, workload *tfv1.TensorFusionWorkload, count int) (ctrl.Result, error) {
+func (r *TensorFusionWorkloadReconciler) scaleUpWorkers(ctx context.Context, workerGenerator *worker.WorkerGenerator, workload *tfv1.TensorFusionWorkload, count int, hash string) (ctrl.Result, error) {
 	log := log.FromContext(ctx)
 
 	// Create worker pods
@@ -346,7 +346,7 @@ func (r *TensorFusionWorkloadReconciler) scaleUpWorkers(ctx context.Context, wor
 			return ctrl.Result{RequeueAfter: constants.PendingRequeueDuration}, nil
 		}
 
-		pod, err := r.tryStartWorker(ctx, workerGenerator, gpu, workload)
+		pod, err := r.tryStartWorker(ctx, workerGenerator, gpu, workload, hash)
 		if err != nil {
 			// Try to release the GPU resource if pod creation fails
 			releaseErr := r.Scheduler.Release(ctx, workload.Spec.Resources.Requests, gpu)
@@ -426,7 +426,7 @@ func (r *TensorFusionWorkloadReconciler) updateStatus(
 
 	// Check if we need to update status
 	statusChanged := workload.Status.ReadyReplicas != readyReplicas ||
-		!reflect.DeepEqual(workload.Status.WorkerStatuses, workerStatuses)
+		!equality.Semantic.DeepEqual(workload.Status.WorkerStatuses, workerStatuses)
 
 	if statusChanged {
 		log.Info("Updating workload status", "readyReplicas", readyReplicas, "workerCount", len(workerStatuses))

internal/webhook/v1/pod_webhook.go

Lines changed: 2 additions & 2 deletions
@@ -21,10 +21,10 @@ import (
 	"encoding/json"
 	"fmt"
 	"net/http"
-	"reflect"
 
 	"gomodules.xyz/jsonpatch/v2"
 	corev1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/api/equality"
 	"k8s.io/apimachinery/pkg/api/errors"
 	"k8s.io/apimachinery/pkg/runtime"
 	"k8s.io/apimachinery/pkg/util/strategicpatch"
@@ -204,7 +204,7 @@ func (m *TensorFusionPodMutator) createOrUpdateWorkload(ctx context.Context, pod
 	}
 
 	// Compare the entire spec at once
-	if !reflect.DeepEqual(workload.Spec, desiredSpec) {
+	if !equality.Semantic.DeepEqual(workload.Spec, desiredSpec) {
 		workload.Spec = desiredSpec
 		if err := m.Client.Update(ctx, workload); err != nil {
 			return fmt.Errorf("failed to update workload: %w", err)

internal/worker/worker.go

Lines changed: 3 additions & 3 deletions
@@ -46,13 +46,13 @@ func (wg *WorkerGenerator) AllocPort() int {
 	return rand.Intn(max-min+1) + min
 }
 
-func (wg *WorkerGenerator) PodTemplateHash(limits tfv1.Resource) (string, error) {
+func (wg *WorkerGenerator) PodTemplateHash(workloadSpec any) (string, error) {
 	podTmpl := &corev1.PodTemplate{}
 	err := json.Unmarshal(wg.WorkerConfig.PodTemplate.Raw, podTmpl)
 	if err != nil {
 		return "", fmt.Errorf("failed to unmarshal pod template: %w", err)
 	}
-	return utils.GetObjectHash(podTmpl, limits), nil
+	return utils.GetObjectHash(podTmpl, workloadSpec), nil
 }
 
 func (wg *WorkerGenerator) GenerateWorkerPod(
@@ -61,13 +61,13 @@ func (wg *WorkerGenerator) GenerateWorkerPod(
 	namespace string,
 	port int,
 	limits tfv1.Resource,
+	podTemplateHash string,
 ) (*corev1.Pod, string, error) {
 	podTmpl := &corev1.PodTemplate{}
 	err := json.Unmarshal(wg.WorkerConfig.PodTemplate.Raw, podTmpl)
 	if err != nil {
 		return nil, "", fmt.Errorf("failed to unmarshal pod template: %w", err)
 	}
-	podTemplateHash := utils.GetObjectHash(podTmpl, limits)
 	spec := podTmpl.Template.Spec
 	if spec.NodeSelector == nil {
 		spec.NodeSelector = make(map[string]string)
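PodTemplateHash now takes the whole workload spec (typed as any) and feeds it, together with the unmarshalled pod template, into utils.GetObjectHash, while GenerateWorkerPod stops computing the hash itself and simply receives the value the reconciler already computed. utils.GetObjectHash is not part of this diff; purely as an assumption about its shape, a minimal sketch of such a helper might look like:

// Hypothetical sketch: the real utils.GetObjectHash is not shown in this
// commit. The assumption is that it JSON-marshals every argument and hashes
// the concatenation, so a change in any input changes the result.
package utils

import (
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
)

func GetObjectHash(objs ...any) string {
	h := sha256.New()
	for _, obj := range objs {
		b, _ := json.Marshal(obj) // best effort; marshal errors ignored in this sketch
		h.Write(b)
	}
	// Shortened so the value fits comfortably in a label, as pod-template-hash
	// style values typically are.
	return hex.EncodeToString(h.Sum(nil))[:16]
}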
