Skip to content

Commit 287f962

Browse files
committed
deprecate post cycle from scheduling framework
Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent 5b4fbb9 commit 287f962

File tree

9 files changed

+112
-128
lines changed

9 files changed

+112
-128
lines changed

pkg/epp/config/loader/configloader_test.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,6 @@ func (f *test1) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMReque
663663

664664
// compile-time type validation
665665
var _ framework.Scorer = &test2{}
666-
var _ framework.PostCycle = &test2{}
667666

668667
type test2 struct {
669668
typedName plugins.TypedName
@@ -683,8 +682,6 @@ func (m *test2) Score(_ context.Context, _ *types.CycleState, _ *types.LLMReques
683682
return map[types.Pod]float64{}
684683
}
685684

686-
func (m *test2) PostCycle(_ context.Context, _ *types.CycleState, _ *types.ProfileRunResult) {}
687-
688685
// compile-time type validation
689686
var _ framework.Picker = &testPicker{}
690687

pkg/epp/config/loader/defaults.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func setDefaultsPhaseTwo(cfg *configapi.EndpointPickerConfig, handle plugins.Han
6969
thePlugins := []configapi.SchedulingPlugin{}
7070
for pluginName, plugin := range allPlugins {
7171
switch plugin.(type) {
72-
case framework.Filter, framework.Picker, framework.PostCycle, framework.Scorer:
72+
case framework.Filter, framework.Scorer, framework.Picker:
7373
thePlugins = append(thePlugins, configapi.SchedulingPlugin{PluginRef: pluginName})
7474
}
7575
}

pkg/epp/scheduling/framework/plugins.go

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ const (
2828
FilterExtensionPoint = "Filter"
2929
ScorerExtensionPoint = "Scorer"
3030
PickerExtensionPoint = "Picker"
31-
PostCycleExtensionPoint = "PostCycle"
3231
ProcessProfilesResultsExtensionPoint = "ProcessProfilesResults"
3332
)
3433

@@ -69,10 +68,3 @@ type Picker interface {
6968
plugins.Plugin
7069
Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult
7170
}
72-
73-
// PostCycle is called by the scheduler after it selects a targetPod for the request in the SchedulerProfile cycle.
74-
// DEPRECATED - do not use PostCycle. this is in the process of deprecation.
75-
type PostCycle interface {
76-
plugins.Plugin
77-
PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult)
78-
}

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828

2929
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
3030
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
31+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
3132
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
3233
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
3334
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
@@ -73,9 +74,10 @@ type Config struct {
7374
}
7475

7576
type Plugin struct {
76-
Config
77-
typedName plugins.TypedName
78-
indexer Indexer
77+
typedName plugins.TypedName
78+
config Config
79+
pluginState *plugins.PluginState
80+
indexer Indexer
7981
}
8082

8183
// podSet holds the set of pod servers that may have a specific prefix hash.
@@ -122,10 +124,10 @@ func (s *SchedulingContextState) Clone() plugins.StateData {
122124

123125
// compile-time type assertion
124126
var _ framework.Scorer = &Plugin{}
125-
var _ framework.PostCycle = &Plugin{}
127+
var _ requestcontrol.PreRequest = &Plugin{}
126128

127129
// PrefixCachePluginFactory defines the factory function for Prefix plugin.
128-
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
130+
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
129131
parameters := Config{
130132
HashBlockSize: DefaultHashBlockSize,
131133
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
@@ -138,11 +140,11 @@ func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, _ plug
138140
}
139141
}
140142

141-
return New(parameters).WithName(name), nil
143+
return New(handle.Context(), parameters).WithName(name), nil
142144
}
143145

144146
// New initializes a new prefix Plugin and returns its pointer.
145-
func New(config Config) *Plugin {
147+
func New(ctx context.Context, config Config) *Plugin {
146148
capacity := config.LRUCapacityPerServer
147149
if capacity <= 0 {
148150
capacity = DefaultLRUCapacityPerServer
@@ -153,34 +155,35 @@ func New(config Config) *Plugin {
153155
}
154156

155157
return &Plugin{
156-
typedName: plugins.TypedName{Type: PrefixCachePluginType, Name: PrefixCachePluginType},
157-
Config: config,
158-
indexer: newIndexer(capacity),
158+
typedName: plugins.TypedName{Type: PrefixCachePluginType, Name: PrefixCachePluginType},
159+
config: config,
160+
pluginState: plugins.NewPluginState(ctx),
161+
indexer: newIndexer(capacity),
159162
}
160163
}
161164

162165
// TypedName returns the type and name tuple of this plugin instance.
163-
func (m *Plugin) TypedName() plugins.TypedName {
164-
return m.typedName
166+
func (p *Plugin) TypedName() plugins.TypedName {
167+
return p.typedName
165168
}
166169

167170
// WithName sets the name of the plugin.
168-
func (m *Plugin) WithName(name string) *Plugin {
169-
m.typedName.Name = name
170-
return m
171+
func (p *Plugin) WithName(name string) *Plugin {
172+
p.typedName.Name = name
173+
return p
171174
}
172175

173176
// Score returns the scoring result for the given list of pods based on context.
174-
func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
177+
func (p *Plugin) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
175178
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
176179
// pre-score step: hash the prompt and find the longest prefix match.
177-
hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch)
180+
hashes := hashPrompt(ctx, request, p.config.HashBlockSize, p.config.MaxPrefixBlocksToMatch)
178181
state := &SchedulingContextState{
179182
PrefixHashes: hashes,
180-
PrefixCacheServers: m.matchLongestPrefix(ctx, hashes),
183+
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
181184
}
182185

183-
cycleState.Write(plugins.StateKey(m.TypedName().Type), state)
186+
p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().Type), state)
184187
loggerTrace.Info(fmt.Sprintf("cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes)
185188
// calculate the scores of pods
186189
scores := make(map[types.Pod]float64, len(pods))
@@ -200,31 +203,33 @@ func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
200203
return scores
201204
}
202205

203-
// PostCycle records in the plugin cache the result of the scheduling selection.
204-
func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult) {
205-
targetPod := res.TargetPods[0].GetPod()
206-
state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType)
206+
// PreRequest records in the plugin cache the result of the scheduling selection.
207+
func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult, _ int) {
208+
primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
209+
targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile
210+
211+
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, PrefixCachePluginType)
207212
if err != nil {
208-
log.FromContext(ctx).Error(err, "failed to read prefix plugin cycle state")
213+
log.FromContext(ctx).Error(err, "failed to read prefix plugin state for request %s", request.RequestId)
209214
return
210215
}
211216

212-
m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName))
217+
p.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName))
213218

214219
total := len(state.PrefixHashes)
215220
matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
216-
metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize)
221+
metrics.RecordPrefixCacheMatch(matchLen*p.config.HashBlockSize, total*p.config.HashBlockSize)
217222
}
218223

219224
// matchLongestPrefix returns a map of servers and length of prefix that each server caches.
220-
func (m *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map[ServerID]int {
225+
func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map[ServerID]int {
221226
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
222227
res := make(map[ServerID]int)
223228
// Use a greedy strategy to search from the longest prefix.
224229
// NOTE: It's possible to further optimize this with a binary search.
225230
for i := 0; i < len(hashes); i++ {
226231
hash := hashes[i]
227-
cachedServers := m.indexer.Get(hash)
232+
cachedServers := p.indexer.Get(hash)
228233
if len(cachedServers) == 0 {
229234
break
230235
} else {

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 64 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@ import (
2424
"strings"
2525
"testing"
2626

27+
"github.com/google/uuid"
2728
"github.com/stretchr/testify/assert"
2829
k8stypes "k8s.io/apimachinery/pkg/types"
2930

3031
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
32+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
3133
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
3234
)
3335

@@ -38,20 +40,20 @@ func TestPrefixPlugin(t *testing.T) {
3840
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
3941
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
4042
}
41-
plugin := New(config)
43+
plugin := New(context.Background(), config)
4244

4345
pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
4446
pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}
4547
pods := []types.Pod{pod1, pod2}
4648

4749
// First request.
4850
req1 := &types.LLMRequest{
51+
RequestId: uuid.NewString(),
4952
TargetModel: "test-model1",
5053
Prompt: "aaaaaa",
5154
}
52-
cycleState1 := types.NewCycleState()
53-
scores := plugin.Score(context.Background(), cycleState1, req1, pods)
54-
state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState1, PrefixCachePluginType)
55+
scores := plugin.Score(context.Background(), nil, req1, pods)
56+
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, PrefixCachePluginType)
5557
assert.NoError(t, err)
5658
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
5759
// Input size is 6, hash block size is 4, the last 2 characters are ignored.
@@ -62,17 +64,23 @@ func TestPrefixPlugin(t *testing.T) {
6264
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
6365

6466
// Simulate pod1 was picked.
65-
plugin.PostCycle(context.Background(), cycleState1, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
67+
schedulingResult := &types.SchedulingResult{
68+
PrimaryProfileName: "default",
69+
ProfileResults: map[string]*types.ProfileRunResult{
70+
"default": &types.ProfileRunResult{TargetPods: []types.Pod{pod1}},
71+
},
72+
}
73+
plugin.PreRequest(context.Background(), req1, schedulingResult, 0)
6674

6775
// Second request doesn't share any prefix with first one. It should be added to the cache but
6876
// the pod score should be 0.
6977
req2 := &types.LLMRequest{
78+
RequestId: uuid.NewString(),
7079
TargetModel: "test-model2",
7180
Prompt: "bbbbbb",
7281
}
73-
cycleState2 := types.NewCycleState()
74-
scores = plugin.Score(context.Background(), cycleState2, req2, pods)
75-
state, err = types.ReadCycleStateKey[*SchedulingContextState](cycleState2, PrefixCachePluginType)
82+
scores = plugin.Score(context.Background(), nil, req2, pods)
83+
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, PrefixCachePluginType)
7684
assert.NoError(t, err)
7785
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
7886
// Input size is 6, hash block size is 4, the last 2 characters are ignored.
@@ -83,16 +91,22 @@ func TestPrefixPlugin(t *testing.T) {
8391
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
8492

8593
// Simulate pod2 was picked.
86-
plugin.PostCycle(context.Background(), cycleState2, &types.ProfileRunResult{TargetPods: []types.Pod{pod2}})
94+
schedulingResult = &types.SchedulingResult{
95+
PrimaryProfileName: "default",
96+
ProfileResults: map[string]*types.ProfileRunResult{
97+
"default": &types.ProfileRunResult{TargetPods: []types.Pod{pod2}},
98+
},
99+
}
100+
plugin.PreRequest(context.Background(), req2, schedulingResult, 0)
87101

88102
// Third request shares partial prefix with first one.
89103
req3 := &types.LLMRequest{
104+
RequestId: uuid.NewString(),
90105
TargetModel: "test-model1",
91106
Prompt: "aaaabbbb",
92107
}
93-
cycleState3 := types.NewCycleState()
94-
scores = plugin.Score(context.Background(), cycleState3, req3, pods)
95-
state, err = types.ReadCycleStateKey[*SchedulingContextState](cycleState3, PrefixCachePluginType)
108+
scores = plugin.Score(context.Background(), nil, req3, pods)
109+
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, PrefixCachePluginType)
96110
assert.NoError(t, err)
97111
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
98112
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
@@ -102,16 +116,22 @@ func TestPrefixPlugin(t *testing.T) {
102116
assert.Equal(t, float64(2)/float64(3), scores[pod1], "score should be 2/3 - the model and the first prefix block match")
103117
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
104118

105-
plugin.PostCycle(context.Background(), cycleState3, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
119+
schedulingResult = &types.SchedulingResult{
120+
PrimaryProfileName: "default",
121+
ProfileResults: map[string]*types.ProfileRunResult{
122+
"default": &types.ProfileRunResult{TargetPods: []types.Pod{pod1}},
123+
},
124+
}
125+
plugin.PreRequest(context.Background(), req3, schedulingResult, 0)
106126

107127
// 4th request is same as req3 except the model is different, still no match.
108128
req4 := &types.LLMRequest{
129+
RequestId: uuid.NewString(),
109130
TargetModel: "test-model-new",
110131
Prompt: "aaaabbbb",
111132
}
112-
cycleState4 := types.NewCycleState()
113-
scores = plugin.Score(context.Background(), cycleState4, req4, pods)
114-
state, err = types.ReadCycleStateKey[*SchedulingContextState](cycleState4, PrefixCachePluginType)
133+
scores = plugin.Score(context.Background(), nil, req4, pods)
134+
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req4.RequestId, PrefixCachePluginType)
115135
assert.NoError(t, err)
116136
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
117137
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
@@ -121,16 +141,22 @@ func TestPrefixPlugin(t *testing.T) {
121141
assert.Equal(t, float64(0), scores[pod1], "score for pod1")
122142
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
123143

124-
plugin.PostCycle(context.Background(), cycleState4, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
144+
schedulingResult = &types.SchedulingResult{
145+
PrimaryProfileName: "default",
146+
ProfileResults: map[string]*types.ProfileRunResult{
147+
"default": &types.ProfileRunResult{TargetPods: []types.Pod{pod1}},
148+
},
149+
}
150+
plugin.PreRequest(context.Background(), req4, schedulingResult, 0)
125151

126152
// 5th request shares partial prefix with 3rd one.
127153
req5 := &types.LLMRequest{
154+
RequestId: uuid.NewString(),
128155
TargetModel: "test-model1",
129156
Prompt: "aaaabbbbcccc",
130157
}
131-
cycleState5 := types.NewCycleState()
132-
scores = plugin.Score(context.Background(), cycleState5, req5, pods)
133-
state, err = types.ReadCycleStateKey[*SchedulingContextState](cycleState5, PrefixCachePluginType)
158+
scores = plugin.Score(context.Background(), nil, req5, pods)
159+
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req5.RequestId, PrefixCachePluginType)
134160
assert.NoError(t, err)
135161
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
136162
// Input size is 12, hash block size is 4, so 3 hashes will be calculated.
@@ -140,7 +166,13 @@ func TestPrefixPlugin(t *testing.T) {
140166
assert.Equal(t, 0.75, scores[pod1], "score should be 0.75 - the model and the first 2 prefix blocks match")
141167
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
142168

143-
plugin.PostCycle(context.Background(), cycleState5, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
169+
schedulingResult = &types.SchedulingResult{
170+
PrimaryProfileName: "default",
171+
ProfileResults: map[string]*types.ProfileRunResult{
172+
"default": &types.ProfileRunResult{TargetPods: []types.Pod{pod1}},
173+
},
174+
}
175+
plugin.PreRequest(context.Background(), req5, schedulingResult, 0)
144176
}
145177

146178
// BenchmarkPrefixPluginStress is a stress test for the prefix scoring plugin, using prompts of increasing length.
@@ -153,7 +185,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
153185
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
154186
}
155187

156-
plugin := New(config)
188+
plugin := New(context.Background(), config)
157189
types.NewCycleState()
158190
var promptLen []int
159191
for i := 1; i <= 1024; i++ {
@@ -174,17 +206,23 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
174206

175207
pods := []types.Pod{pod}
176208
req := &types.LLMRequest{
209+
RequestId: uuid.NewString(),
177210
TargetModel: "model-stress",
178211
Prompt: prompt,
179212
}
180213

181214
// First cycle: simulate scheduling and insert prefix info into the cache
182-
cycleState := types.NewCycleState()
183-
plugin.Score(context.Background(), cycleState, req, pods)
184-
plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPods: []types.Pod{pod}})
215+
plugin.Score(context.Background(), nil, req, pods)
216+
schedulingResult := &types.SchedulingResult{
217+
PrimaryProfileName: "default",
218+
ProfileResults: map[string]*types.ProfileRunResult{
219+
"default": &types.ProfileRunResult{TargetPods: []types.Pod{pod}},
220+
},
221+
}
222+
plugin.PreRequest(context.Background(), req, schedulingResult, 0)
185223

186224
// Second cycle: validate internal state
187-
state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType)
225+
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, PrefixCachePluginType)
188226
assert.NoError(b, err)
189227
expectedHashes := int(math.Min(float64(maxPrefixBlocks+1), float64(len(req.Prompt)/blockSize+1))) // the extra one is for the model.
190228
assert.Equal(b, expectedHashes, len(state.PrefixHashes), "number of hashes is incorrect")

0 commit comments

Comments
 (0)