diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index ce5ea197b..e553830ab 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -27,7 +27,7 @@ import ( "github.com/cespare/xxhash/v2" k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/log" - + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins/approximateprefix" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" @@ -128,12 +128,28 @@ func (s ServerID) String() string { // compile-time type validation var _ plugins.StateData = &SchedulingContextState{} +// compile-time type assertion +var ( + _ framework.Scorer = &Plugin{} + _ requestcontrol.PreRequest = &Plugin{} + _ requestcontrol.ResponseReceived = &Plugin{} +) + // SchedulingContextState is the state of this plugin to be used during a scheduling cycle. type SchedulingContextState struct { // PrefixHashes is a list of prefix hashes of the request prompt broken into blocks. PrefixHashes []BlockHash // A map of server to its longest prefix cache match length. PrefixCacheServers map[ServerID]int + // SelectedServers is a list of servers that have been selected for the request. + // Populated at PreRequest. + SelectedServers []Server + // StatsBlockSize is the block size used for metrics. + // Populated at PreRequest. + StatsBlockSize int + // StatsMatchLen is the length of the longest prefix match for the target server. + // Populated at PreRequest. 
+ StatsMatchLen int } func (s *SchedulingContextState) Clone() plugins.StateData { @@ -143,19 +159,18 @@ func (s *SchedulingContextState) Clone() plugins.StateData { for key, value := range s.PrefixCacheServers { prefixCacheServers[key] = value } + selectedServers := make([]Server, len(s.SelectedServers)) + copy(selectedServers, s.SelectedServers) return &SchedulingContextState{ PrefixHashes: prefixHashes, PrefixCacheServers: prefixCacheServers, + SelectedServers: selectedServers, + StatsBlockSize: s.StatsBlockSize, + StatsMatchLen: s.StatsMatchLen, } } -// compile-time type assertion -var ( - _ framework.Scorer = &Plugin{} - _ requestcontrol.PreRequest = &Plugin{} -) - // PrefixCachePluginFactory defines the factory function for Prefix plugin. func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { parameters := DefaultConfig @@ -282,29 +297,40 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche } state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String())) - p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it if err != nil { log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId) return } + // Update the state with selected servers and stats, then write it back. 
+ state.SelectedServers = servers + state.StatsMatchLen = state.PrefixCacheServers[ServerID(targetEndpoint.GetMetadata().NamespacedName)] + state.StatsBlockSize = getBlockSize(primaryProfileResult.TargetEndpoints, p.config) + p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().String()), state) +} + +func (p *Plugin) ResponseReceived(ctx context.Context, request *types.LLMRequest, _ *requestcontrol.Response, targetPod *datalayer.EndpointMetadata) { + state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String())) + p.pluginState.Delete(request.RequestId) // delete the state explicitly once we are done using it + if err != nil { + // This is expected if the request was shed before PreRequest or if PreRequest failed. + return + } + // This function is just adding data, it does not need to block other operations. // TODO: look into making this entire function async, none of this needs to be done in-band // The PR that introduces this change is meant as a cherrypick, so it was minimally invasive. // WaitGroup is added to the Plugin struct to allow waiting in tests. 
p.wg.Add(1) go func() { - for _, s := range servers { + for _, s := range state.SelectedServers { p.indexer.Add(state.PrefixHashes, s) } p.wg.Done() }() total := len(state.PrefixHashes) - matchLen := state.PrefixCacheServers[ServerID(targetEndpoint.GetMetadata().NamespacedName)] - - blockSize := getBlockSize(primaryProfileResult.TargetEndpoints, p.config) - metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize) + metrics.RecordPrefixCacheMatch(state.StatsMatchLen*state.StatsBlockSize, total*state.StatsBlockSize) } func (p *Plugin) makeServer(targetEndpoint types.Endpoint) Server { diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go index a317f0c59..a01fae03c 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go @@ -81,6 +81,7 @@ func TestPrefixPluginCompletion(t *testing.T) { }, } plugin.PreRequest(context.Background(), req1, schedulingResult) + plugin.ResponseReceived(context.Background(), req1, nil, endpoint1.GetMetadata()) plugin.wg.Wait() // Second request doesn't share any prefix with first one. It should be added to the cache but @@ -113,6 +114,7 @@ func TestPrefixPluginCompletion(t *testing.T) { }, } plugin.PreRequest(context.Background(), req2, schedulingResult) + plugin.ResponseReceived(context.Background(), req2, nil, endpoint2.GetMetadata()) plugin.wg.Wait() // Third request shares partial prefix with first one. @@ -144,6 +146,7 @@ func TestPrefixPluginCompletion(t *testing.T) { }, } plugin.PreRequest(context.Background(), req3, schedulingResult) + plugin.ResponseReceived(context.Background(), req3, nil, endpoint1.GetMetadata()) plugin.wg.Wait() // 4th request is same as req3 except the model is different, still no match. 
@@ -174,6 +177,7 @@ func TestPrefixPluginCompletion(t *testing.T) { }, } plugin.PreRequest(context.Background(), req4, schedulingResult) + plugin.ResponseReceived(context.Background(), req4, nil, endpoint1.GetMetadata()) plugin.wg.Wait() // 5th request shares partial prefix with 3rd one. @@ -204,6 +208,7 @@ func TestPrefixPluginCompletion(t *testing.T) { }, } plugin.PreRequest(context.Background(), req5, schedulingResult) + plugin.ResponseReceived(context.Background(), req5, nil, endpoint1.GetMetadata()) plugin.wg.Wait() } @@ -284,6 +289,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) { }, } plugin.PreRequest(context.Background(), req1, schedulingResult) + plugin.ResponseReceived(context.Background(), req1, nil, endpoint1.GetMetadata()) plugin.wg.Wait() // Second request adds assistant response and new user message (conversation grows) @@ -317,6 +323,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) { // Simulate pod1 was picked again plugin.PreRequest(context.Background(), req2, schedulingResult) + plugin.ResponseReceived(context.Background(), req2, nil, endpoint1.GetMetadata()) plugin.wg.Wait() // Third request continues the conversation even further @@ -528,6 +535,7 @@ func TestPrefixPluginAutoTune(t *testing.T) { }, } plugin.PreRequest(context.Background(), req, schedulingResult) + plugin.ResponseReceived(context.Background(), req, nil, endpoint.GetMetadata()) plugin.wg.Wait() // Check indexer state @@ -563,6 +571,7 @@ func TestPrefixPluginAutoTune(t *testing.T) { }, } plugin.PreRequest(context.Background(), req, schedulingResult) + plugin.ResponseReceived(context.Background(), req, nil, endpoint.GetMetadata()) plugin.wg.Wait() assert.Contains(t, plugin.indexer.Pods(), ServerID(endpoint.GetMetadata().NamespacedName)) @@ -609,6 +618,7 @@ func TestPrepareRequestData(t *testing.T) { }, } plugin.PreRequest(context.Background(), req1, schedulingResult) + plugin.ResponseReceived(context.Background(), req1, nil, endpoint1.GetMetadata()) 
plugin.wg.Wait() // Second request that shares a prefix.