Skip to content

Commit f36111c

Browse files
authored
deprecate post cycle from scheduling framework (#1392)
* deprecate post cycle from scheduling framework Signed-off-by: Nir Rozenbaum <[email protected]> * linter Signed-off-by: Nir Rozenbaum <[email protected]> * make linter happy Signed-off-by: Nir Rozenbaum <[email protected]> * addressed code review Signed-off-by: Nir Rozenbaum <[email protected]> --------- Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent 1738c2f commit f36111c

File tree

10 files changed

+120
-129
lines changed

10 files changed

+120
-129
lines changed

pkg/epp/config/loader/configloader_test.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,6 @@ func (f *test1) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMReque
663663

664664
// compile-time type validation
665665
var _ framework.Scorer = &test2{}
666-
var _ framework.PostCycle = &test2{}
667666

668667
type test2 struct {
669668
typedName plugins.TypedName
@@ -683,8 +682,6 @@ func (m *test2) Score(_ context.Context, _ *types.CycleState, _ *types.LLMReques
683682
return map[types.Pod]float64{}
684683
}
685684

686-
func (m *test2) PostCycle(_ context.Context, _ *types.CycleState, _ *types.ProfileRunResult) {}
687-
688685
// compile-time type validation
689686
var _ framework.Picker = &testPicker{}
690687

pkg/epp/config/loader/defaults.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func setDefaultsPhaseTwo(cfg *configapi.EndpointPickerConfig, handle plugins.Han
6969
thePlugins := []configapi.SchedulingPlugin{}
7070
for pluginName, plugin := range allPlugins {
7171
switch plugin.(type) {
72-
case framework.Filter, framework.Picker, framework.PostCycle, framework.Scorer:
72+
case framework.Filter, framework.Scorer, framework.Picker:
7373
thePlugins = append(thePlugins, configapi.SchedulingPlugin{PluginRef: pluginName})
7474
}
7575
}

pkg/epp/plugins/plugin_state_test.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,14 @@ func TestPluginState_ReadWrite(t *testing.T) {
6666
assert.True(t, ok, "should be able to cast to pluginTestData")
6767
assert.Equal(t, data1, td.value)
6868

69-
// Delete the req2 data and verify it's removed
69+
// Delete the req2 data and verify content that was read before is still valid
70+
readData, err = state.Read(req2, key)
71+
assert.NoError(t, err)
7072
state.Delete(req2)
73+
td, ok = readData.(*pluginTestData)
74+
assert.True(t, ok, "should be able to cast to pluginTestData")
75+
assert.Equal(t, data2, td.value)
76+
// try to read again after deletion, verify error
7177
readData, err = state.Read(req2, key)
7278
assert.Equal(t, ErrNotFound, err)
7379
assert.Nil(t, readData, "expected no data after delete")

pkg/epp/scheduling/framework/plugins.go

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ const (
2828
FilterExtensionPoint = "Filter"
2929
ScorerExtensionPoint = "Scorer"
3030
PickerExtensionPoint = "Picker"
31-
PostCycleExtensionPoint = "PostCycle"
3231
ProcessProfilesResultsExtensionPoint = "ProcessProfilesResults"
3332
)
3433

@@ -69,10 +68,3 @@ type Picker interface {
6968
plugins.Plugin
7069
Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult
7170
}
72-
73-
// PostCycle is called by the scheduler after it selects a targetPod for the request in the SchedulerProfile cycle.
74-
// DEPRECATED - do not use PostCycle. this is in the process of deprecation.
75-
type PostCycle interface {
76-
plugins.Plugin
77-
PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult)
78-
}

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828

2929
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
3030
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
31+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
3132
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
3233
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
3334
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
@@ -73,9 +74,10 @@ type Config struct {
7374
}
7475

7576
type Plugin struct {
76-
Config
77-
typedName plugins.TypedName
78-
indexer Indexer
77+
typedName plugins.TypedName
78+
config Config
79+
pluginState *plugins.PluginState
80+
indexer Indexer
7981
}
8082

8183
// podSet holds the pod servers that may have a specific prefix hash.
@@ -122,10 +124,10 @@ func (s *SchedulingContextState) Clone() plugins.StateData {
122124

123125
// compile-time type assertion
124126
var _ framework.Scorer = &Plugin{}
125-
var _ framework.PostCycle = &Plugin{}
127+
var _ requestcontrol.PreRequest = &Plugin{}
126128

127129
// PrefixCachePluginFactory defines the factory function for Prefix plugin.
128-
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
130+
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
129131
parameters := Config{
130132
HashBlockSize: DefaultHashBlockSize,
131133
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
@@ -138,11 +140,11 @@ func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, _ plug
138140
}
139141
}
140142

141-
return New(parameters).WithName(name), nil
143+
return New(handle.Context(), parameters).WithName(name), nil
142144
}
143145

144146
// New initializes a new prefix Plugin and returns its pointer.
145-
func New(config Config) *Plugin {
147+
func New(ctx context.Context, config Config) *Plugin {
146148
capacity := config.LRUCapacityPerServer
147149
if capacity <= 0 {
148150
capacity = DefaultLRUCapacityPerServer
@@ -153,34 +155,35 @@ func New(config Config) *Plugin {
153155
}
154156

155157
return &Plugin{
156-
typedName: plugins.TypedName{Type: PrefixCachePluginType, Name: PrefixCachePluginType},
157-
Config: config,
158-
indexer: newIndexer(capacity),
158+
typedName: plugins.TypedName{Type: PrefixCachePluginType, Name: PrefixCachePluginType},
159+
config: config,
160+
pluginState: plugins.NewPluginState(ctx),
161+
indexer: newIndexer(capacity),
159162
}
160163
}
161164

162165
// TypedName returns the type and name tuple of this plugin instance.
163-
func (m *Plugin) TypedName() plugins.TypedName {
164-
return m.typedName
166+
func (p *Plugin) TypedName() plugins.TypedName {
167+
return p.typedName
165168
}
166169

167170
// WithName sets the name of the plugin.
168-
func (m *Plugin) WithName(name string) *Plugin {
169-
m.typedName.Name = name
170-
return m
171+
func (p *Plugin) WithName(name string) *Plugin {
172+
p.typedName.Name = name
173+
return p
171174
}
172175

173176
// Score returns the scoring result for the given list of pods based on context.
174-
func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
177+
func (p *Plugin) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
175178
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
176179
// pre score step, hashing prompt and find longest prefix match.
177-
hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch)
180+
hashes := hashPrompt(ctx, request, p.config.HashBlockSize, p.config.MaxPrefixBlocksToMatch)
178181
state := &SchedulingContextState{
179182
PrefixHashes: hashes,
180-
PrefixCacheServers: m.matchLongestPrefix(ctx, hashes),
183+
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
181184
}
182185

183-
cycleState.Write(plugins.StateKey(m.TypedName().Type), state)
186+
p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().Type), state)
184187
loggerTrace.Info(fmt.Sprintf("cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes)
185188
// calculate the scores of pods
186189
scores := make(map[types.Pod]float64, len(pods))
@@ -200,31 +203,34 @@ func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
200203
return scores
201204
}
202205

203-
// PostCycle records in the plugin cache the result of the scheduling selection.
204-
func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult) {
205-
targetPod := res.TargetPods[0].GetPod()
206-
state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType)
206+
// PreRequest records in the plugin cache the result of the scheduling selection.
207+
func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult, _ int) {
208+
primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
209+
targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile
210+
211+
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, PrefixCachePluginType)
212+
p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it
207213
if err != nil {
208-
log.FromContext(ctx).Error(err, "failed to read prefix plugin cycle state")
214+
log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
209215
return
210216
}
211217

212-
m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName))
218+
p.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName))
213219

214220
total := len(state.PrefixHashes)
215221
matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
216-
metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize)
222+
metrics.RecordPrefixCacheMatch(matchLen*p.config.HashBlockSize, total*p.config.HashBlockSize)
217223
}
218224

219225
// matchLongestPrefix returns a map of servers and length of prefix that each server caches.
220-
func (m *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map[ServerID]int {
226+
func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map[ServerID]int {
221227
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
222228
res := make(map[ServerID]int)
223229
// Use a greedy strategy to search from the longest prefix.
224230
// NOTE: It's possible to further optimize this with a binary search.
225231
for i := 0; i < len(hashes); i++ {
226232
hash := hashes[i]
227-
cachedServers := m.indexer.Get(hash)
233+
cachedServers := p.indexer.Get(hash)
228234
if len(cachedServers) == 0 {
229235
break
230236
} else {

0 commit comments

Comments
 (0)