diff --git a/pkg/epp/config/loader/configloader_test.go b/pkg/epp/config/loader/configloader_test.go index 0701f9642..15a97f262 100644 --- a/pkg/epp/config/loader/configloader_test.go +++ b/pkg/epp/config/loader/configloader_test.go @@ -663,7 +663,6 @@ func (f *test1) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMReque // compile-time type validation var _ framework.Scorer = &test2{} -var _ framework.PostCycle = &test2{} type test2 struct { typedName plugins.TypedName @@ -683,8 +682,6 @@ func (m *test2) Score(_ context.Context, _ *types.CycleState, _ *types.LLMReques return map[types.Pod]float64{} } -func (m *test2) PostCycle(_ context.Context, _ *types.CycleState, _ *types.ProfileRunResult) {} - // compile-time type validation var _ framework.Picker = &testPicker{} diff --git a/pkg/epp/config/loader/defaults.go b/pkg/epp/config/loader/defaults.go index 7be0d980f..2a8236ea5 100644 --- a/pkg/epp/config/loader/defaults.go +++ b/pkg/epp/config/loader/defaults.go @@ -69,7 +69,7 @@ func setDefaultsPhaseTwo(cfg *configapi.EndpointPickerConfig, handle plugins.Han thePlugins := []configapi.SchedulingPlugin{} for pluginName, plugin := range allPlugins { switch plugin.(type) { - case framework.Filter, framework.Picker, framework.PostCycle, framework.Scorer: + case framework.Filter, framework.Scorer, framework.Picker: thePlugins = append(thePlugins, configapi.SchedulingPlugin{PluginRef: pluginName}) } } diff --git a/pkg/epp/plugins/plugin_state_test.go b/pkg/epp/plugins/plugin_state_test.go index 46c0de87e..0de65c23a 100644 --- a/pkg/epp/plugins/plugin_state_test.go +++ b/pkg/epp/plugins/plugin_state_test.go @@ -66,8 +66,14 @@ func TestPluginState_ReadWrite(t *testing.T) { assert.True(t, ok, "should be able to cast to pluginTestData") assert.Equal(t, data1, td.value) - // Delete the req2 data and verify it's removed + // Delete the req2 data and verify content that was read before is still valid + readData, err = state.Read(req2, key) + assert.NoError(t, err) 
state.Delete(req2) + td, ok = readData.(*pluginTestData) + assert.True(t, ok, "should be able to cast to pluginTestData") + assert.Equal(t, data2, td.value) + // try to read again after deletion, verify error readData, err = state.Read(req2, key) assert.Equal(t, ErrNotFound, err) assert.Nil(t, readData, "expected no data after delete") diff --git a/pkg/epp/scheduling/framework/plugins.go b/pkg/epp/scheduling/framework/plugins.go index 2c2284366..99397a4b3 100644 --- a/pkg/epp/scheduling/framework/plugins.go +++ b/pkg/epp/scheduling/framework/plugins.go @@ -28,7 +28,6 @@ const ( FilterExtensionPoint = "Filter" ScorerExtensionPoint = "Scorer" PickerExtensionPoint = "Picker" - PostCycleExtensionPoint = "PostCycle" ProcessProfilesResultsExtensionPoint = "ProcessProfilesResults" ) @@ -69,10 +68,3 @@ type Picker interface { plugins.Plugin Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult } - -// PostCycle is called by the scheduler after it selects a targetPod for the request in the SchedulerProfile cycle. -// DEPRECATED - do not use PostCycle. this is in the process of deprecation. 
-type PostCycle interface { - plugins.Plugin - PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult) -} diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index 9c1a8d844..6cb7a3a6c 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -28,6 +28,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" @@ -73,9 +74,10 @@ type Config struct { } type Plugin struct { - Config - typedName plugins.TypedName - indexer Indexer + typedName plugins.TypedName + config Config + pluginState *plugins.PluginState + indexer Indexer } // podSet holds an pods servers that may have a specific prefix hash. @@ -122,10 +124,10 @@ func (s *SchedulingContextState) Clone() plugins.StateData { // compile-time type assertion var _ framework.Scorer = &Plugin{} -var _ framework.PostCycle = &Plugin{} +var _ requestcontrol.PreRequest = &Plugin{} // PrefixCachePluginFactory defines the factory function for Prefix plugin. 
-func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { parameters := Config{ HashBlockSize: DefaultHashBlockSize, MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks, @@ -138,11 +140,11 @@ func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, _ plug } } - return New(parameters).WithName(name), nil + return New(handle.Context(), parameters).WithName(name), nil } // New initializes a new prefix Plugin and returns its pointer. -func New(config Config) *Plugin { +func New(ctx context.Context, config Config) *Plugin { capacity := config.LRUCapacityPerServer if capacity <= 0 { capacity = DefaultLRUCapacityPerServer @@ -153,34 +155,35 @@ func New(config Config) *Plugin { } return &Plugin{ - typedName: plugins.TypedName{Type: PrefixCachePluginType, Name: PrefixCachePluginType}, - Config: config, - indexer: newIndexer(capacity), + typedName: plugins.TypedName{Type: PrefixCachePluginType, Name: PrefixCachePluginType}, + config: config, + pluginState: plugins.NewPluginState(ctx), + indexer: newIndexer(capacity), } } // TypedName returns the type and name tuple of this plugin instance. -func (m *Plugin) TypedName() plugins.TypedName { - return m.typedName +func (p *Plugin) TypedName() plugins.TypedName { + return p.typedName } // WithName sets the name of the plugin. -func (m *Plugin) WithName(name string) *Plugin { - m.typedName.Name = name - return m +func (p *Plugin) WithName(name string) *Plugin { + p.typedName.Name = name + return p } // Score returns the scoring result for the given list of pods based on context. 
-func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { +func (p *Plugin) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { loggerTrace := log.FromContext(ctx).V(logutil.TRACE) // pre score step, hashing prompt and find longest prefix match. - hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch) + hashes := hashPrompt(ctx, request, p.config.HashBlockSize, p.config.MaxPrefixBlocksToMatch) state := &SchedulingContextState{ PrefixHashes: hashes, - PrefixCacheServers: m.matchLongestPrefix(ctx, hashes), + PrefixCacheServers: p.matchLongestPrefix(ctx, hashes), } - cycleState.Write(plugins.StateKey(m.TypedName().Type), state) + p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().Type), state) loggerTrace.Info(fmt.Sprintf("cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes) // calculate the scores of pods scores := make(map[types.Pod]float64, len(pods)) @@ -200,31 +203,34 @@ func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques return scores } -// PostCycle records in the plugin cache the result of the scheduling selection. -func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult) { - targetPod := res.TargetPods[0].GetPod() - state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType) +// PreRequest records in the plugin cache the result of the scheduling selection. 
+func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult, _ int) { + primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName] + targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile + + state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, PrefixCachePluginType) + p.pluginState.Delete(request.RequestId) // delete the state explicitly now that it has been read if err != nil { - log.FromContext(ctx).Error(err, "failed to read prefix plugin cycle state") + log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId) return } - m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName)) + p.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName)) total := len(state.PrefixHashes) matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)] - metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize) + metrics.RecordPrefixCacheMatch(matchLen*p.config.HashBlockSize, total*p.config.HashBlockSize) } // matchLongestPrefix returns a map of servers and length of prefix that each server caches. -func (m *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map[ServerID]int { +func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map[ServerID]int { loggerTrace := log.FromContext(ctx).V(logutil.TRACE) res := make(map[ServerID]int) // Use a greedy strategy to search from the longest prefix. // NOTE: It's possible to further optimize this with a binary search. 
for i := 0; i < len(hashes); i++ { hash := hashes[i] - cachedServers := m.indexer.Get(hash) + cachedServers := p.indexer.Get(hash) if len(cachedServers) == 0 { break } else { diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go index aaf68f0d8..666d05346 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go @@ -24,10 +24,12 @@ import ( "strings" "testing" + "github.com/google/uuid" "github.com/stretchr/testify/assert" k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -38,7 +40,7 @@ func TestPrefixPlugin(t *testing.T) { MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks, LRUCapacityPerServer: DefaultLRUCapacityPerServer, } - plugin := New(config) + plugin := New(context.Background(), config) pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}} pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}} @@ -46,12 +48,12 @@ func TestPrefixPlugin(t *testing.T) { // First request. 
req1 := &types.LLMRequest{ + RequestId: uuid.NewString(), TargetModel: "test-model1", Prompt: "aaaaaa", } - cycleState1 := types.NewCycleState() - scores := plugin.Score(context.Background(), cycleState1, req1, pods) - state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState1, PrefixCachePluginType) + scores := plugin.Score(context.Background(), nil, req1, pods) + state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, PrefixCachePluginType) assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 6, hash block size is 4, the last 2 characters are ignored. @@ -62,17 +64,23 @@ func TestPrefixPlugin(t *testing.T) { assert.Equal(t, float64(0), scores[pod2], "score for pod2") // Simulate pod1 was picked. - plugin.PostCycle(context.Background(), cycleState1, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}}) + schedulingResult := &types.SchedulingResult{ + PrimaryProfileName: "default", + ProfileResults: map[string]*types.ProfileRunResult{ + "default": {TargetPods: []types.Pod{pod1}}, + }, + } + plugin.PreRequest(context.Background(), req1, schedulingResult, 0) // Second request doesn't share any prefix with first one. It should be added to the cache but // the pod score should be 0. 
req2 := &types.LLMRequest{ + RequestId: uuid.NewString(), TargetModel: "test-model2", Prompt: "bbbbbb", } - cycleState2 := types.NewCycleState() - scores = plugin.Score(context.Background(), cycleState2, req2, pods) - state, err = types.ReadCycleStateKey[*SchedulingContextState](cycleState2, PrefixCachePluginType) + scores = plugin.Score(context.Background(), nil, req2, pods) + state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, PrefixCachePluginType) assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 6, hash block size is 4, the last 2 characters are ignored. @@ -83,16 +91,22 @@ func TestPrefixPlugin(t *testing.T) { assert.Equal(t, float64(0), scores[pod2], "score for pod2") // Simulate pod2 was picked. - plugin.PostCycle(context.Background(), cycleState2, &types.ProfileRunResult{TargetPods: []types.Pod{pod2}}) + schedulingResult = &types.SchedulingResult{ + PrimaryProfileName: "default", + ProfileResults: map[string]*types.ProfileRunResult{ + "default": {TargetPods: []types.Pod{pod2}}, + }, + } + plugin.PreRequest(context.Background(), req2, schedulingResult, 0) // Third request shares partial prefix with first one. req3 := &types.LLMRequest{ + RequestId: uuid.NewString(), TargetModel: "test-model1", Prompt: "aaaabbbb", } - cycleState3 := types.NewCycleState() - scores = plugin.Score(context.Background(), cycleState3, req3, pods) - state, err = types.ReadCycleStateKey[*SchedulingContextState](cycleState3, PrefixCachePluginType) + scores = plugin.Score(context.Background(), nil, req3, pods) + state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, PrefixCachePluginType) assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 8, hash block size is 4, so 2 hashes will be calculated. 
@@ -102,16 +116,22 @@ func TestPrefixPlugin(t *testing.T) { assert.Equal(t, float64(2)/float64(3), scores[pod1], "score should be 2/3 - the model and the first prefix block match") assert.Equal(t, float64(0), scores[pod2], "score for pod2") - plugin.PostCycle(context.Background(), cycleState3, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}}) + schedulingResult = &types.SchedulingResult{ + PrimaryProfileName: "default", + ProfileResults: map[string]*types.ProfileRunResult{ + "default": {TargetPods: []types.Pod{pod1}}, + }, + } + plugin.PreRequest(context.Background(), req3, schedulingResult, 0) // 4th request is same as req3 except the model is different, still no match. req4 := &types.LLMRequest{ + RequestId: uuid.NewString(), TargetModel: "test-model-new", Prompt: "aaaabbbb", } - cycleState4 := types.NewCycleState() - scores = plugin.Score(context.Background(), cycleState4, req4, pods) - state, err = types.ReadCycleStateKey[*SchedulingContextState](cycleState4, PrefixCachePluginType) + scores = plugin.Score(context.Background(), nil, req4, pods) + state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req4.RequestId, PrefixCachePluginType) assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 8, hash block size is 4, so 2 hashes will be calculated. @@ -121,16 +141,22 @@ func TestPrefixPlugin(t *testing.T) { assert.Equal(t, float64(0), scores[pod1], "score for pod1") assert.Equal(t, float64(0), scores[pod2], "score for pod2") - plugin.PostCycle(context.Background(), cycleState4, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}}) + schedulingResult = &types.SchedulingResult{ + PrimaryProfileName: "default", + ProfileResults: map[string]*types.ProfileRunResult{ + "default": {TargetPods: []types.Pod{pod1}}, + }, + } + plugin.PreRequest(context.Background(), req4, schedulingResult, 0) // 5th request shares partial prefix with 3rd one. 
req5 := &types.LLMRequest{ + RequestId: uuid.NewString(), TargetModel: "test-model1", Prompt: "aaaabbbbcccc", } - cycleState5 := types.NewCycleState() - scores = plugin.Score(context.Background(), cycleState5, req5, pods) - state, err = types.ReadCycleStateKey[*SchedulingContextState](cycleState5, PrefixCachePluginType) + scores = plugin.Score(context.Background(), nil, req5, pods) + state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req5.RequestId, PrefixCachePluginType) assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 12, hash block size is 4, so 3 hashes will be calculated. @@ -140,7 +166,13 @@ func TestPrefixPlugin(t *testing.T) { assert.Equal(t, 0.75, scores[pod1], "score should be 0.75 - the model and the first 2 prefix blocks match") assert.Equal(t, float64(0), scores[pod2], "score for pod2") - plugin.PostCycle(context.Background(), cycleState5, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}}) + schedulingResult = &types.SchedulingResult{ + PrimaryProfileName: "default", + ProfileResults: map[string]*types.ProfileRunResult{ + "default": {TargetPods: []types.Pod{pod1}}, + }, + } + plugin.PreRequest(context.Background(), req5, schedulingResult, 0) } // TestPrefixPluginStress is a stress test for the prefix scoring plugin, using prompts of increasing length. 
@@ -153,7 +185,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) { LRUCapacityPerServer: DefaultLRUCapacityPerServer, } - plugin := New(config) + plugin := New(context.Background(), config) types.NewCycleState() var promptLen []int for i := 1; i <= 1024; i++ { @@ -174,17 +206,23 @@ func BenchmarkPrefixPluginStress(b *testing.B) { pods := []types.Pod{pod} req := &types.LLMRequest{ + RequestId: uuid.NewString(), TargetModel: "model-stress", Prompt: prompt, } // First cycle: simulate scheduling and insert prefix info into the cache - cycleState := types.NewCycleState() - plugin.Score(context.Background(), cycleState, req, pods) - plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPods: []types.Pod{pod}}) + plugin.Score(context.Background(), nil, req, pods) + schedulingResult := &types.SchedulingResult{ + PrimaryProfileName: "default", + ProfileResults: map[string]*types.ProfileRunResult{ + "default": {TargetPods: []types.Pod{pod}}, + }, + } + plugin.PreRequest(context.Background(), req, schedulingResult, 0) // Second cycle: validate internal state - state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType) + state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, PrefixCachePluginType) assert.NoError(b, err) expectedHashes := int(math.Min(float64(maxPrefixBlocks+1), float64(len(req.Prompt)/blockSize+1))) // the extra one is for the model. assert.Equal(b, expectedHashes, len(state.PrefixHashes), "number of hashes is incorrect") diff --git a/pkg/epp/scheduling/framework/scheduler_profile.go b/pkg/epp/scheduling/framework/scheduler_profile.go index b3b24d5a2..6037e3541 100644 --- a/pkg/epp/scheduling/framework/scheduler_profile.go +++ b/pkg/epp/scheduling/framework/scheduler_profile.go @@ -34,19 +34,17 @@ import ( // NewSchedulerProfile creates a new SchedulerProfile object and returns its pointer. 
func NewSchedulerProfile() *SchedulerProfile { return &SchedulerProfile{ - filters: []Filter{}, - scorers: []*WeightedScorer{}, - postCyclePlugins: []PostCycle{}, + filters: []Filter{}, + scorers: []*WeightedScorer{}, // picker remains nil since profile doesn't support multiple pickers } } // SchedulerProfile provides a profile configuration for the scheduler which influence routing decisions. type SchedulerProfile struct { - filters []Filter - scorers []*WeightedScorer - picker Picker - postCyclePlugins []PostCycle + filters []Filter + scorers []*WeightedScorer + picker Picker } // WithFilters sets the given filter plugins as the Filter plugins. @@ -70,13 +68,6 @@ func (p *SchedulerProfile) WithPicker(picker Picker) *SchedulerProfile { return p } -// WithPostCyclePlugins sets the given plugins as the PostCycle plugins. -// If the SchedulerProfile has PostCycle plugins, this call replaces the existing plugins with the given ones. -func (p *SchedulerProfile) WithPostCyclePlugins(plugins ...PostCycle) *SchedulerProfile { - p.postCyclePlugins = plugins - return p -} - // AddPlugins adds the given plugins to all scheduler plugins according to the interfaces each plugin implements. // A plugin may implement more than one scheduler plugin interface. // Special Case: In order to add a scorer, one must use the scorer.NewWeightedScorer function in order to provide a weight. 
@@ -99,9 +90,6 @@ func (p *SchedulerProfile) AddPlugins(pluginObjects ...plugins.Plugin) error { } p.picker = picker } - if postCyclePlugin, ok := plugin.(PostCycle); ok { - p.postCyclePlugins = append(p.postCyclePlugins, postCyclePlugin) - } } return nil } @@ -115,21 +103,17 @@ func (p *SchedulerProfile) String() string { for i, scorer := range p.scorers { scorerNames[i] = fmt.Sprintf("%s: %d", scorer.TypedName(), scorer.Weight()) } - postCyclePluginNames := make([]string, len(p.postCyclePlugins)) - for i, postCyclePlugin := range p.postCyclePlugins { - postCyclePluginNames[i] = postCyclePlugin.TypedName().String() - } + return fmt.Sprintf( - "{Filters: [%s], Scorers: [%s], Picker: %s, PostCyclePlugins: [%s]}", + "{Filters: [%s], Scorers: [%s], Picker: %s}", strings.Join(filterNames, ", "), strings.Join(scorerNames, ", "), p.picker.TypedName(), - strings.Join(postCyclePluginNames, ", "), ) } -// RunCycle runs a SchedulerProfile cycle. In other words, it invokes all the SchedulerProfile plugins in this -// order - Filters, Scorers, Picker, PostCyclePlugins. After completing all, it returns the result. +// Run runs a SchedulerProfile. It invokes all the SchedulerProfile plugins for the given request in this +// order - Filters, Scorers, Picker. After completing all, it returns the result. 
func (p *SchedulerProfile) Run(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, candidatePods []types.Pod) (*types.ProfileRunResult, error) { pods := p.runFilterPlugins(ctx, request, cycleState, candidatePods) if len(pods) == 0 { @@ -140,8 +124,6 @@ func (p *SchedulerProfile) Run(ctx context.Context, request *types.LLMRequest, c result := p.runPickerPlugin(ctx, cycleState, weightedScorePerPod) - p.runPostCyclePlugins(ctx, cycleState, result) - return result, nil } @@ -207,17 +189,6 @@ func (p *SchedulerProfile) runPickerPlugin(ctx context.Context, cycleState *type return result } -func (p *SchedulerProfile) runPostCyclePlugins(ctx context.Context, cycleState *types.CycleState, result *types.ProfileRunResult) { - loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) - for _, plugin := range p.postCyclePlugins { - loggerDebug.Info("Running post-cycle plugin", "plugin", plugin.TypedName()) - before := time.Now() - plugin.PostCycle(ctx, cycleState, result) - metrics.RecordPluginProcessingLatency(PostCycleExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before)) - loggerDebug.Info("Completed running post-cycle plugin successfully", "plugin", plugin.TypedName()) - } -} - func enforceScoreRange(score float64) float64 { if score < 0 { return 0 diff --git a/pkg/epp/scheduling/framework/scheduler_profile_test.go b/pkg/epp/scheduling/framework/scheduler_profile_test.go index 020223c00..f79b48de7 100644 --- a/pkg/epp/scheduling/framework/scheduler_profile_test.go +++ b/pkg/epp/scheduling/framework/scheduler_profile_test.go @@ -64,8 +64,7 @@ func TestSchedulePlugins(t *testing.T) { profile: NewSchedulerProfile(). WithFilters(tp1, tp2). WithScorers(NewWeightedScorer(tp1, 1), NewWeightedScorer(tp2, 1)). - WithPicker(pickerPlugin). 
- WithPostCyclePlugins(tp1, tp2), + WithPicker(pickerPlugin), input: []types.Pod{ &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, @@ -81,8 +80,7 @@ func TestSchedulePlugins(t *testing.T) { profile: NewSchedulerProfile(). WithFilters(tp1, tp2). WithScorers(NewWeightedScorer(tp1, 60), NewWeightedScorer(tp2, 40)). - WithPicker(pickerPlugin). - WithPostCyclePlugins(tp1, tp2), + WithPicker(pickerPlugin), input: []types.Pod{ &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, @@ -98,8 +96,7 @@ func TestSchedulePlugins(t *testing.T) { profile: NewSchedulerProfile(). WithFilters(tp1, tp_filterAll). WithScorers(NewWeightedScorer(tp1, 1), NewWeightedScorer(tp2, 1)). - WithPicker(pickerPlugin). - WithPostCyclePlugins(tp1, tp2), + WithPicker(pickerPlugin), input: []types.Pod{ &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, @@ -120,9 +117,6 @@ func TestSchedulePlugins(t *testing.T) { plugin.Scorer.(*testPlugin).reset() } test.profile.picker.(*testPlugin).reset() - for _, plugin := range test.profile.postCyclePlugins { - plugin.(*testPlugin).reset() - } // Initialize the scheduling context request := &types.LLMRequest{ @@ -179,12 +173,6 @@ func TestSchedulePlugins(t *testing.T) { if tp.WinnerPodScore != test.targetPodScore { t.Errorf("winner pod score %v, expected %v", tp.WinnerPodScore, test.targetPodScore) } - for _, plugin := range test.profile.postCyclePlugins { - tp, _ := plugin.(*testPlugin) - if tp.PostCycleCallCount != 1 { - t.Errorf("Plugin '%s' PostCycle() called %d times, expected 1", plugin.TypedName(), tp.PostCycleCallCount) - } - } }) } } @@ -193,7 +181,6 @@ func 
TestSchedulePlugins(t *testing.T) { var _ Filter = &testPlugin{} var _ Scorer = &testPlugin{} var _ Picker = &testPlugin{} -var _ PostCycle = &testPlugin{} // testPlugin is an implementation useful in unit tests. type testPlugin struct { @@ -204,7 +191,6 @@ type testPlugin struct { ScoreRes float64 FilterCallCount int FilterRes []k8stypes.NamespacedName - PostCycleCallCount int PickCallCount int NumOfPickerCandidates int PickRes k8stypes.NamespacedName @@ -246,15 +232,10 @@ func (tp *testPlugin) Pick(_ context.Context, _ *types.CycleState, scoredPods [] return &types.ProfileRunResult{TargetPods: winnerPods} } -func (tp *testPlugin) PostCycle(_ context.Context, _ *types.CycleState, res *types.ProfileRunResult) { - tp.PostCycleCallCount++ -} - func (tp *testPlugin) reset() { tp.FilterCallCount = 0 tp.ScoreCallCount = 0 tp.NumOfScoredPods = 0 - tp.PostCycleCallCount = 0 tp.PickCallCount = 0 tp.NumOfPickerCandidates = 0 } diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index aea9c5756..c197096ba 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -38,7 +38,7 @@ import ( func TestSchedule(t *testing.T) { kvCacheUtilizationScorer := scorer.NewKVCacheUtilizationScorer() queueingScorer := scorer.NewQueueScorer() - prefixCacheScorer := prefix.New(prefix.DefaultConfig) + prefixCacheScorer := prefix.New(context.Background(), prefix.DefaultConfig) loraAffinityScorer := scorer.NewLoraAffinityScorer() defaultProfile := framework.NewSchedulerProfile(). 
diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go index 48cf889c1..b421f7f6f 100644 --- a/test/integration/epp/hermetic_test.go +++ b/test/integration/epp/hermetic_test.go @@ -1104,7 +1104,7 @@ func BeforeSuite() func() { kvCacheUtilizationScorer := scorer.NewKVCacheUtilizationScorer() queueingScorer := scorer.NewQueueScorer() - prefixCacheScorer := prefix.New(prefix.DefaultConfig) + prefixCacheScorer := prefix.New(context.Background(), prefix.DefaultConfig) loraAffinityScorer := scorer.NewLoraAffinityScorer() defaultProfile := framework.NewSchedulerProfile().