From 3232a17f08368e432adcbfd0d634e680d32bcef7 Mon Sep 17 00:00:00 2001 From: Nir Rozenbaum Date: Sun, 31 Aug 2025 14:41:35 +0300 Subject: [PATCH] prefix state Signed-off-by: Nir Rozenbaum --- .../framework/plugins/multi/prefix/plugin.go | 7 +++--- .../plugins/multi/prefix/plugin_test.go | 24 +++++++++---------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index ef33521b4..4e0416720 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -174,7 +174,7 @@ func (p *Plugin) WithName(name string) *Plugin { } // Score returns the scoring result for the given list of pods based on context. -func (p *Plugin) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { +func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { loggerTrace := log.FromContext(ctx).V(logutil.TRACE) // pre score step, hashing prompt and find longest prefix match. hashes := hashPrompt(ctx, request, p.config.HashBlockSize, p.config.MaxPrefixBlocksToMatch) @@ -183,7 +183,8 @@ func (p *Plugin) Score(ctx context.Context, _ *types.CycleState, request *types. PrefixCacheServers: p.matchLongestPrefix(ctx, hashes), } - p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().Type), state) + cycleState.Write(plugins.StateKey(p.TypedName().String()), state) + p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().String()), state) loggerTrace.Info(fmt.Sprintf("cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes) // calculate the scores of pods scores := make(map[types.Pod]float64, len(pods)) @@ -208,7 +209,7 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName] targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile - state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, PrefixCachePluginType) + state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String())) p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it if err != nil { log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId) diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go index eee8f63f4..d6ec43cbb 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go @@ -52,8 +52,8 @@ func TestPrefixPlugin(t *testing.T) { TargetModel: "test-model1", Prompt: "aaaaaa", } - scores := plugin.Score(context.Background(), nil, req1, pods) - state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, PrefixCachePluginType) + scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods) + state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String())) assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 6, hash block size is 4, the last 2 characters are ignored. @@ -79,8 +79,8 @@ func TestPrefixPlugin(t *testing.T) { TargetModel: "test-model2", Prompt: "bbbbbb", } - scores = plugin.Score(context.Background(), nil, req2, pods) - state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, PrefixCachePluginType) + scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods) + state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String())) assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 6, hash block size is 4, the last 2 characters are ignored. @@ -105,8 +105,8 @@ func TestPrefixPlugin(t *testing.T) { TargetModel: "test-model1", Prompt: "aaaabbbb", } - scores = plugin.Score(context.Background(), nil, req3, pods) - state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, PrefixCachePluginType) + scores = plugin.Score(context.Background(), types.NewCycleState(), req3, pods) + state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, plugins.StateKey(plugin.TypedName().String())) assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 8, hash block size is 4, so 2 hashes will be calculated. @@ -130,8 +130,8 @@ func TestPrefixPlugin(t *testing.T) { TargetModel: "test-model-new", Prompt: "aaaabbbb", } - scores = plugin.Score(context.Background(), nil, req4, pods) - state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req4.RequestId, PrefixCachePluginType) + scores = plugin.Score(context.Background(), types.NewCycleState(), req4, pods) + state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req4.RequestId, plugins.StateKey(plugin.TypedName().String())) assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 8, hash block size is 4, so 2 hashes will be calculated. @@ -155,8 +155,8 @@ func TestPrefixPlugin(t *testing.T) { TargetModel: "test-model1", Prompt: "aaaabbbbcccc", } - scores = plugin.Score(context.Background(), nil, req5, pods) - state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req5.RequestId, PrefixCachePluginType) + scores = plugin.Score(context.Background(), types.NewCycleState(), req5, pods) + state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req5.RequestId, plugins.StateKey(plugin.TypedName().String())) assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 12, hash block size is 4, so 3 hashes will be calculated. @@ -212,7 +212,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) { } // First cycle: simulate scheduling and insert prefix info into the cache - plugin.Score(context.Background(), nil, req, pods) + plugin.Score(context.Background(), types.NewCycleState(), req, pods) schedulingResult := &types.SchedulingResult{ PrimaryProfileName: "default", ProfileResults: map[string]*types.ProfileRunResult{ @@ -222,7 +222,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) { plugin.PreRequest(context.Background(), req, schedulingResult, 0) // Second cycle: validate internal state - state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, PrefixCachePluginType) + state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, plugins.StateKey(plugin.TypedName().String())) assert.NoError(b, err) expectedHashes := int(math.Min(float64(maxPrefixBlocks), float64(len(req.Prompt)/blockSize))) assert.Equal(b, expectedHashes, len(state.PrefixHashes), "number of hashes is incorrect")