Skip to content

Commit 91e3047

Browse files
authored
handle picking multiple destinations in scheduling layer (#1059)
* implement multiple destination as the output of the scheduler
  Signed-off-by: Nir Rozenbaum <[email protected]>
* updated max score picker unit tests to cover multiple pods
  Signed-off-by: Nir Rozenbaum <[email protected]>
* imports
  Signed-off-by: Nir Rozenbaum <[email protected]>
* unit-test fix
  Signed-off-by: Nir Rozenbaum <[email protected]>
---------
Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent c204c89 commit 91e3047

File tree

15 files changed

+273
-168
lines changed

15 files changed

+273
-168
lines changed

cmd/epp/runner/runner.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ func (r *Runner) initializeScheduler() (*scheduling.Scheduler, error) {
294294
schedulerProfile := framework.NewSchedulerProfile().
295295
WithScorers(framework.NewWeightedScorer(scorer.NewQueueScorer(), queueScorerWeight),
296296
framework.NewWeightedScorer(scorer.NewKVCacheScorer(), kvCacheScorerWeight)).
297-
WithPicker(picker.NewMaxScorePicker())
297+
WithPicker(picker.NewMaxScorePicker(picker.DefaultMaxNumOfEndpoints))
298298

299299
if prefixCacheScheduling {
300300
prefixScorerWeight := envutil.GetEnvInt("PREFIX_CACHE_SCORE_WEIGHT", prefix.DefaultScorerWeight, setupLog)

conformance/testing-epp/scheduler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import (
3030
func NewReqHeaderBasedScheduler() *scheduling.Scheduler {
3131
predicatableSchedulerProfile := framework.NewSchedulerProfile().
3232
WithFilters(filter.NewHeaderBasedTestingFilter()).
33-
WithPicker(picker.NewMaxScorePicker())
33+
WithPicker(picker.NewMaxScorePicker(picker.DefaultMaxNumOfEndpoints))
3434

3535
return scheduling.NewSchedulerWithConfig(scheduling.NewSchedulerConfig(
3636
profile.NewSingleProfileHandler(), map[string]*framework.SchedulerProfile{"req-header-based-profile": predicatableSchedulerProfile}))

conformance/testing-epp/scheduler_test.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,13 @@ func TestSchedule(t *testing.T) {
8282
wantRes: &types.SchedulingResult{
8383
ProfileResults: map[string]*types.ProfileRunResult{
8484
"req-header-based-profile": {
85-
TargetPod: &types.ScoredPod{
86-
Pod: &types.PodMetrics{
87-
Pod: &backend.Pod{
88-
Address: "matched-endpoint",
89-
Labels: map[string]string{},
85+
TargetPods: []types.Pod{
86+
&types.ScoredPod{
87+
Pod: &types.PodMetrics{
88+
Pod: &backend.Pod{
89+
Address: "matched-endpoint",
90+
Labels: map[string]string{},
91+
},
9092
},
9193
},
9294
},

pkg/epp/requestcontrol/director.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,8 @@ func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestC
238238
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "results must be greater than zero"}
239239
}
240240
// primary profile is used to set destination
241-
targetPod := result.ProfileResults[result.PrimaryProfileName].TargetPod.GetPod()
241+
// TODO: use multiple destinations according to the EPP protocol; the current code assumes a single target.
242+
targetPod := result.ProfileResults[result.PrimaryProfileName].TargetPods[0].GetPod()
242243

243244
pool, err := d.datastore.PoolGet()
244245
if err != nil {

pkg/epp/requestcontrol/director_test.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,13 @@ func TestDirector_HandleRequest(t *testing.T) {
131131
defaultSuccessfulScheduleResults := &schedulingtypes.SchedulingResult{
132132
ProfileResults: map[string]*schedulingtypes.ProfileRunResult{
133133
"testProfile": {
134-
TargetPod: &schedulingtypes.ScoredPod{
135-
Pod: &schedulingtypes.PodMetrics{
136-
Pod: &backend.Pod{
137-
Address: "192.168.1.100",
138-
NamespacedName: k8stypes.NamespacedName{Name: "pod1", Namespace: "default"},
134+
TargetPods: []schedulingtypes.Pod{
135+
&schedulingtypes.ScoredPod{
136+
Pod: &schedulingtypes.PodMetrics{
137+
Pod: &backend.Pod{
138+
Address: "192.168.1.100",
139+
NamespacedName: k8stypes.NamespacedName{Name: "pod1", Namespace: "default"},
140+
},
139141
},
140142
},
141143
},

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
196196

197197
// PostCycle records in the plugin cache the result of the scheduling selection.
198198
func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult) {
199-
targetPod := res.TargetPod.GetPod()
199+
targetPod := res.TargetPods[0].GetPod()
200200
state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType)
201201
if err != nil {
202202
log.FromContext(ctx).Error(err, "failed to read prefix plugin cycle state")

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func TestPrefixPlugin(t *testing.T) {
6161
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
6262

6363
// Simulate pod1 was picked.
64-
plugin.PostCycle(context.Background(), cycleState1, &types.ProfileRunResult{TargetPod: pod1})
64+
plugin.PostCycle(context.Background(), cycleState1, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
6565

6666
// Second request doesn't share any prefix with first one. It should be added to the cache but
6767
// the pod score should be 0.
@@ -82,7 +82,7 @@ func TestPrefixPlugin(t *testing.T) {
8282
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
8383

8484
// Simulate pod2 was picked.
85-
plugin.PostCycle(context.Background(), cycleState2, &types.ProfileRunResult{TargetPod: pod2})
85+
plugin.PostCycle(context.Background(), cycleState2, &types.ProfileRunResult{TargetPods: []types.Pod{pod2}})
8686

8787
// Third request shares partial prefix with first one.
8888
req3 := &types.LLMRequest{
@@ -101,7 +101,7 @@ func TestPrefixPlugin(t *testing.T) {
101101
assert.Equal(t, float64(2)/float64(3), scores[pod1], "score should be 2/3 - the model and the first prefix block match")
102102
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
103103

104-
plugin.PostCycle(context.Background(), cycleState3, &types.ProfileRunResult{TargetPod: pod1})
104+
plugin.PostCycle(context.Background(), cycleState3, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
105105

106106
// 4th request is same as req3 except the model is different, still no match.
107107
req4 := &types.LLMRequest{
@@ -120,7 +120,7 @@ func TestPrefixPlugin(t *testing.T) {
120120
assert.Equal(t, float64(0), scores[pod1], "score for pod1")
121121
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
122122

123-
plugin.PostCycle(context.Background(), cycleState4, &types.ProfileRunResult{TargetPod: pod1})
123+
plugin.PostCycle(context.Background(), cycleState4, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
124124

125125
// 5th request shares partial prefix with 3rd one.
126126
req5 := &types.LLMRequest{
@@ -139,7 +139,7 @@ func TestPrefixPlugin(t *testing.T) {
139139
assert.Equal(t, 0.75, scores[pod1], "score should be 0.75 - the model and the first 2 prefix blocks match")
140140
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
141141

142-
plugin.PostCycle(context.Background(), cycleState5, &types.ProfileRunResult{TargetPod: pod1})
142+
plugin.PostCycle(context.Background(), cycleState5, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
143143
}
144144

145145
// TestPrefixPluginStress is a stress test for the prefix scoring plugin, using prompts of increasing length.
@@ -180,7 +180,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
180180
// First cycle: simulate scheduling and insert prefix info into the cache
181181
cycleState := types.NewCycleState()
182182
plugin.Score(context.Background(), cycleState, req, pods)
183-
plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPod: pod})
183+
plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPods: []types.Pod{pod}})
184184

185185
// Second cycle: validate internal state
186186
state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package picker
18+
19+
const (
20+
DefaultMaxNumOfEndpoints = 1 // common default to all pickers
21+
)
22+
23+
// pickerParameters defines the common parameters for all pickers
24+
type pickerParameters struct {
25+
MaxNumOfEndpoints int `json:"maxNumOfEndpoints"`
26+
}

pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go

Lines changed: 47 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ import (
2020
"context"
2121
"encoding/json"
2222
"fmt"
23+
"slices"
2324

2425
"sigs.k8s.io/controller-runtime/pkg/log"
26+
2527
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
2628
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
2729
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
@@ -36,53 +38,71 @@ const (
3638
var _ framework.Picker = &MaxScorePicker{}
3739

3840
// MaxScorePickerFactory defines the factory function for MaxScorePicker.
39-
func MaxScorePickerFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
40-
return NewMaxScorePicker().WithName(name), nil
41+
func MaxScorePickerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
42+
parameters := pickerParameters{MaxNumOfEndpoints: DefaultMaxNumOfEndpoints}
43+
if rawParameters != nil {
44+
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
45+
return nil, fmt.Errorf("failed to parse the parameters of the '%s' picker - %w", MaxScorePickerType, err)
46+
}
47+
}
48+
49+
return NewMaxScorePicker(parameters.MaxNumOfEndpoints).WithName(name), nil
4150
}
4251

4352
// NewMaxScorePicker initializes a new MaxScorePicker and returns its pointer.
44-
func NewMaxScorePicker() *MaxScorePicker {
53+
func NewMaxScorePicker(maxNumOfEndpoints int) *MaxScorePicker {
54+
if maxNumOfEndpoints <= 0 {
55+
maxNumOfEndpoints = DefaultMaxNumOfEndpoints // on invalid configuration value, fallback to default value
56+
}
57+
4558
return &MaxScorePicker{
46-
tn: plugins.TypedName{Type: MaxScorePickerType, Name: MaxScorePickerType},
47-
random: NewRandomPicker(),
59+
typedName: plugins.TypedName{Type: MaxScorePickerType, Name: MaxScorePickerType},
60+
maxNumOfEndpoints: maxNumOfEndpoints,
4861
}
4962
}
5063

51-
// MaxScorePicker picks the pod with the maximum score from the list of candidates.
64+
// MaxScorePicker picks pod(s) with the maximum score from the list of candidates.
5265
type MaxScorePicker struct {
53-
tn plugins.TypedName
54-
random *RandomPicker
55-
}
56-
57-
// TypedName returns the type and name tuple of this plugin instance.
58-
func (p *MaxScorePicker) TypedName() plugins.TypedName {
59-
return p.tn
66+
typedName plugins.TypedName
67+
maxNumOfEndpoints int // maximum number of endpoints to pick
6068
}
6169

6270
// WithName sets the picker's name
6371
func (p *MaxScorePicker) WithName(name string) *MaxScorePicker {
64-
p.tn.Name = name
72+
p.typedName.Name = name
6573
return p
6674
}
6775

76+
// TypedName returns the type and name tuple of this plugin instance.
77+
func (p *MaxScorePicker) TypedName() plugins.TypedName {
78+
return p.typedName
79+
}
80+
6881
// Pick sorts the candidates by score (highest first) and selects up to maxNumOfEndpoints of the top-scored pods.
6982
func (p *MaxScorePicker) Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult {
70-
log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a pod with the max score from %d candidates: %+v", len(scoredPods), scoredPods))
71-
72-
highestScorePods := []*types.ScoredPod{}
73-
maxScore := -1.0 // pods min score is 0, putting value lower than 0 in order to find at least one pod as highest
74-
for _, pod := range scoredPods {
75-
if pod.Score > maxScore {
76-
maxScore = pod.Score
77-
highestScorePods = []*types.ScoredPod{pod}
78-
} else if pod.Score == maxScore {
79-
highestScorePods = append(highestScorePods, pod)
83+
log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates sorted by max score: %+v", p.maxNumOfEndpoints,
84+
len(scoredPods), scoredPods))
85+
86+
slices.SortStableFunc(scoredPods, func(i, j *types.ScoredPod) int { // highest score first
87+
if i.Score > j.Score {
88+
return -1
89+
}
90+
if i.Score < j.Score {
91+
return 1
8092
}
93+
return 0
94+
})
95+
96+
// if there are more candidates than maxNumOfEndpoints, keep only the maxNumOfEndpoints highest-scored pods
97+
if p.maxNumOfEndpoints < len(scoredPods) {
98+
scoredPods = scoredPods[:p.maxNumOfEndpoints]
8199
}
82100

83-
if len(highestScorePods) > 1 {
84-
return p.random.Pick(ctx, cycleState, highestScorePods) // pick randomly from the highest score pods
101+
targetPods := make([]types.Pod, len(scoredPods))
102+
for i, scoredPod := range scoredPods {
103+
targetPods[i] = scoredPod
85104
}
86105

87-
return &types.ProfileRunResult{TargetPod: highestScorePods[0]}
106+
return &types.ProfileRunResult{TargetPods: targetPods}
107+
88108
}

0 commit comments

Comments
 (0)