@@ -27,7 +27,7 @@ import (
2727 "github.com/cespare/xxhash/v2"
2828 k8stypes "k8s.io/apimachinery/pkg/types"
2929 "sigs.k8s.io/controller-runtime/pkg/log"
30-
30+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
3131 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins/approximateprefix"
3232 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
3333 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
@@ -128,12 +128,28 @@ func (s ServerID) String() string {
128128// compile-time type validation
129129var _ plugins.StateData = & SchedulingContextState {}
130130
131+ // compile-time type assertion
132+ var (
133+ _ framework.Scorer = & Plugin {}
134+ _ requestcontrol.PreRequest = & Plugin {}
135+ _ requestcontrol.ResponseReceived = & Plugin {}
136+ )
137+
131138// SchedulingContextState is the state of this plugin to be used during a scheduling cycle.
132139type SchedulingContextState struct {
133140 // PrefixHashes is a list of prefix hashes of the request prompt broken into blocks.
134141 PrefixHashes []BlockHash
135142 // A map of server to its longest prefix cache match length.
136143 PrefixCacheServers map [ServerID ]int
144+ // SelectedServers is a list of servers that have been selected for the request.
145+ // Populated at PreRequest.
146+ SelectedServers []Server
147+ // StatsBlockSize is the block size used for metrics.
148+ // Populated at PreRequest.
149+ StatsBlockSize int
150+ // StatsMatchLen is the length of the longest prefix match for the target server.
151+ // Populated at PreRequest.
152+ StatsMatchLen int
137153}
138154
139155func (s * SchedulingContextState ) Clone () plugins.StateData {
@@ -143,19 +159,18 @@ func (s *SchedulingContextState) Clone() plugins.StateData {
143159 for key , value := range s .PrefixCacheServers {
144160 prefixCacheServers [key ] = value
145161 }
162+ selectedServers := make ([]Server , len (s .SelectedServers ))
163+ copy (selectedServers , s .SelectedServers )
146164
147165 return & SchedulingContextState {
148166 PrefixHashes : prefixHashes ,
149167 PrefixCacheServers : prefixCacheServers ,
168+ SelectedServers : selectedServers ,
169+ StatsBlockSize : s .StatsBlockSize ,
170+ StatsMatchLen : s .StatsMatchLen ,
150171 }
151172}
152173
153- // compile-time type assertion
154- var (
155- _ framework.Scorer = & Plugin {}
156- _ requestcontrol.PreRequest = & Plugin {}
157- )
158-
159174// PrefixCachePluginFactory defines the factory function for Prefix plugin.
160175func PrefixCachePluginFactory (name string , rawParameters json.RawMessage , handle plugins.Handle ) (plugins.Plugin , error ) {
161176 parameters := DefaultConfig
@@ -282,29 +297,40 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
282297 }
283298
284299 state , err := plugins .ReadPluginStateKey [* SchedulingContextState ](p .pluginState , request .RequestId , plugins .StateKey (p .TypedName ().String ()))
285- p .pluginState .Delete (request .RequestId ) // delete the state explicitly after completing using it
286300 if err != nil {
287301 log .FromContext (ctx ).Error (err , "failed to read prefix plugin state" , "requestID" , request .RequestId )
288302 return
289303 }
290304
305+ // Update the state with selected servers and stats, then write it back.
306+ state .SelectedServers = servers
307+ state .StatsMatchLen = state .PrefixCacheServers [ServerID (targetEndpoint .GetMetadata ().NamespacedName )]
308+ state .StatsBlockSize = getBlockSize (primaryProfileResult .TargetEndpoints , p .config )
309+ p .pluginState .Write (request .RequestId , plugins .StateKey (p .TypedName ().String ()), state )
310+ }
311+
312+ func (p * Plugin ) ResponseReceived (ctx context.Context , request * types.LLMRequest , _ * requestcontrol.Response , targetPod * datalayer.EndpointMetadata ) {
313+ state , err := plugins .ReadPluginStateKey [* SchedulingContextState ](p .pluginState , request .RequestId , plugins .StateKey (p .TypedName ().String ()))
314+ p .pluginState .Delete (request .RequestId ) // delete the state explicitly after completing using it
315+ if err != nil {
316+ // This is expected if the request was shedded before PreRequest or if PreRequest failed.
317+ return
318+ }
319+
291320 // This function is just adding data, it does not need to block other operations.
292321 // TODO: look into making this entire function async, none of this needs to be done in-band
293322 // The PR that introduces this change is meant as a cherrypick, so it was minimally invasive.
294323 // WaitGroup is added to the Plugin struct to allow waiting in tests.
295324 p .wg .Add (1 )
296325 go func () {
297- for _ , s := range servers {
326+ for _ , s := range state . SelectedServers {
298327 p .indexer .Add (state .PrefixHashes , s )
299328 }
300329 p .wg .Done ()
301330 }()
302331
303332 total := len (state .PrefixHashes )
304- matchLen := state .PrefixCacheServers [ServerID (targetEndpoint .GetMetadata ().NamespacedName )]
305-
306- blockSize := getBlockSize (primaryProfileResult .TargetEndpoints , p .config )
307- metrics .RecordPrefixCacheMatch (matchLen * blockSize , total * blockSize )
333+ metrics .RecordPrefixCacheMatch (state .StatsMatchLen * state .StatsBlockSize , total * state .StatsBlockSize )
308334}
309335
310336func (p * Plugin ) makeServer (targetEndpoint types.Endpoint ) Server {
0 commit comments