Skip to content

Commit 72db6f6

Browse files
authored
random endpoint pick on tie break in max score picker (#1205)
Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent 976f385 commit 72db6f6

File tree

3 files changed

+35
-6
lines changed

3 files changed

+35
-6
lines changed

pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ import (
2020
"context"
2121
"encoding/json"
2222
"fmt"
23+
"math/rand"
2324
"slices"
25+
"time"
2426

2527
"sigs.k8s.io/controller-runtime/pkg/log"
2628

@@ -58,13 +60,15 @@ func NewMaxScorePicker(maxNumOfEndpoints int) *MaxScorePicker {
5860
return &MaxScorePicker{
5961
typedName: plugins.TypedName{Type: MaxScorePickerType, Name: MaxScorePickerType},
6062
maxNumOfEndpoints: maxNumOfEndpoints,
63+
randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())),
6164
}
6265
}
6366

6467
// MaxScorePicker picks pod(s) with the maximum score from the list of candidates.
6568
type MaxScorePicker struct {
6669
typedName plugins.TypedName
67-
maxNumOfEndpoints int // maximum number of endpoints to pick
70+
maxNumOfEndpoints int // maximum number of endpoints to pick
71+
randomGenerator *rand.Rand // randomGenerator for randomly pick endpoint on tie-break
6872
}
6973

7074
// WithName sets the picker's name
@@ -83,6 +87,11 @@ func (p *MaxScorePicker) Pick(ctx context.Context, cycleState *types.CycleState,
8387
log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates sorted by max score: %+v", p.maxNumOfEndpoints,
8488
len(scoredPods), scoredPods))
8589

90+
// Shuffle in-place - needed for random tie break when scores are equal
91+
p.randomGenerator.Shuffle(len(scoredPods), func(i, j int) {
92+
scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i]
93+
})
94+
8695
slices.SortStableFunc(scoredPods, func(i, j *types.ScoredPod) int { // highest score first
8796
if i.Score > j.Score {
8897
return -1

pkg/epp/scheduling/framework/plugins/picker/picker_test.go

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"testing"
2222

2323
"github.com/google/go-cmp/cmp"
24+
"github.com/google/go-cmp/cmp/cmpopts"
2425
k8stypes "k8s.io/apimachinery/pkg/types"
2526

2627
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
@@ -34,10 +35,11 @@ func TestPickMaxScorePicker(t *testing.T) {
3435
pod3 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}
3536

3637
tests := []struct {
37-
name string
38-
picker framework.Picker
39-
input []*types.ScoredPod
40-
output []types.Pod
38+
name string
39+
picker framework.Picker
40+
input []*types.ScoredPod
41+
output []types.Pod
42+
tieBreakCandidates int // tie break is random, specify how many candidate with max score
4143
}{
4244
{
4345
name: "Single max score",
@@ -63,6 +65,7 @@ func TestPickMaxScorePicker(t *testing.T) {
6365
&types.ScoredPod{Pod: pod1, Score: 50},
6466
&types.ScoredPod{Pod: pod2, Score: 50},
6567
},
68+
tieBreakCandidates: 2,
6669
},
6770
{
6871
name: "Multiple results sorted by highest score, more pods than needed",
@@ -104,6 +107,7 @@ func TestPickMaxScorePicker(t *testing.T) {
104107
&types.ScoredPod{Pod: pod3, Score: 30},
105108
&types.ScoredPod{Pod: pod2, Score: 25},
106109
},
110+
tieBreakCandidates: 2,
107111
},
108112
}
109113

@@ -112,6 +116,19 @@ func TestPickMaxScorePicker(t *testing.T) {
112116
result := test.picker.Pick(context.Background(), types.NewCycleState(), test.input)
113117
got := result.TargetPods
114118

119+
if test.tieBreakCandidates > 0 {
120+
testMaxScoredPods := test.output[:test.tieBreakCandidates]
121+
gotMaxScoredPods := got[:test.tieBreakCandidates]
122+
diff := cmp.Diff(testMaxScoredPods, gotMaxScoredPods, cmpopts.SortSlices(func(a, b types.Pod) bool {
123+
return a.String() < b.String() // predictable order within the pods with equal scores
124+
}))
125+
if diff != "" {
126+
t.Errorf("Unexpected output (-want +got): %v", diff)
127+
}
128+
test.output = test.output[test.tieBreakCandidates:]
129+
got = got[test.tieBreakCandidates:]
130+
}
131+
115132
if diff := cmp.Diff(test.output, got); diff != "" {
116133
t.Errorf("Unexpected output (-want +got): %v", diff)
117134
}

pkg/epp/scheduling/framework/plugins/picker/random_picker.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"encoding/json"
2222
"fmt"
2323
"math/rand"
24+
"time"
2425

2526
"sigs.k8s.io/controller-runtime/pkg/log"
2627

@@ -57,13 +58,15 @@ func NewRandomPicker(maxNumOfEndpoints int) *RandomPicker {
5758
return &RandomPicker{
5859
typedName: plugins.TypedName{Type: RandomPickerType, Name: RandomPickerType},
5960
maxNumOfEndpoints: maxNumOfEndpoints,
61+
randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())),
6062
}
6163
}
6264

6365
// RandomPicker picks random pod(s) from the list of candidates.
6466
type RandomPicker struct {
6567
typedName plugins.TypedName
6668
maxNumOfEndpoints int
69+
randomGenerator *rand.Rand // randomGenerator for randomly pick endpoint on tie-break
6770
}
6871

6972
// WithName sets the name of the picker.
@@ -83,7 +86,7 @@ func (p *RandomPicker) Pick(ctx context.Context, _ *types.CycleState, scoredPods
8386
len(scoredPods), scoredPods))
8487

8588
// Shuffle in-place
86-
rand.Shuffle(len(scoredPods), func(i, j int) {
89+
p.randomGenerator.Shuffle(len(scoredPods), func(i, j int) {
8790
scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i]
8891
})
8992

0 commit comments

Comments
 (0)