@@ -28,6 +28,7 @@ import (
2828
2929 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
3030 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
31+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
3132 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
3233 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
3334 logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
@@ -73,9 +74,10 @@ type Config struct {
7374}
7475
7576type Plugin struct {
76- Config
77- typedName plugins.TypedName
78- indexer Indexer
77+ typedName plugins.TypedName
78+ config Config
79+ pluginState * plugins.PluginState
80+ indexer Indexer
7981}
8082
8183// podSet holds the set of pod servers that may have a specific prefix hash.
@@ -122,10 +124,10 @@ func (s *SchedulingContextState) Clone() plugins.StateData {
122124
123125// compile-time type assertion
124126var _ framework.Scorer = & Plugin {}
125- var _ framework. PostCycle = & Plugin {}
127+ var _ requestcontrol. PreRequest = & Plugin {}
126128
127129// PrefixCachePluginFactory defines the factory function for Prefix plugin.
128- func PrefixCachePluginFactory (name string , rawParameters json.RawMessage , _ plugins.Handle ) (plugins.Plugin , error ) {
130+ func PrefixCachePluginFactory (name string , rawParameters json.RawMessage , handle plugins.Handle ) (plugins.Plugin , error ) {
129131 parameters := Config {
130132 HashBlockSize : DefaultHashBlockSize ,
131133 MaxPrefixBlocksToMatch : DefaultMaxPrefixBlocks ,
@@ -138,11 +140,11 @@ func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, _ plug
138140 }
139141 }
140142
141- return New (parameters ).WithName (name ), nil
143+ return New (handle . Context (), parameters ).WithName (name ), nil
142144}
143145
144146// New initializes a new prefix Plugin and returns its pointer.
145- func New (config Config ) * Plugin {
147+ func New (ctx context. Context , config Config ) * Plugin {
146148 capacity := config .LRUCapacityPerServer
147149 if capacity <= 0 {
148150 capacity = DefaultLRUCapacityPerServer
@@ -153,34 +155,35 @@ func New(config Config) *Plugin {
153155 }
154156
155157 return & Plugin {
156- typedName : plugins.TypedName {Type : PrefixCachePluginType , Name : PrefixCachePluginType },
157- Config : config ,
158- indexer : newIndexer (capacity ),
158+ typedName : plugins.TypedName {Type : PrefixCachePluginType , Name : PrefixCachePluginType },
159+ config : config ,
160+ pluginState : plugins .NewPluginState (ctx ),
161+ indexer : newIndexer (capacity ),
159162 }
160163}
161164
162165// TypedName returns the type and name tuple of this plugin instance.
163- func (m * Plugin ) TypedName () plugins.TypedName {
164- return m .typedName
166+ func (p * Plugin ) TypedName () plugins.TypedName {
167+ return p .typedName
165168}
166169
167170// WithName sets the name of the plugin.
168- func (m * Plugin ) WithName (name string ) * Plugin {
169- m .typedName .Name = name
170- return m
171+ func (p * Plugin ) WithName (name string ) * Plugin {
172+ p .typedName .Name = name
173+ return p
171174}
172175
173176// Score returns the scoring result for the given list of pods based on context.
174- func (m * Plugin ) Score (ctx context.Context , cycleState * types.CycleState , request * types.LLMRequest , pods []types.Pod ) map [types.Pod ]float64 {
177+ func (p * Plugin ) Score (ctx context.Context , _ * types.CycleState , request * types.LLMRequest , pods []types.Pod ) map [types.Pod ]float64 {
175178 loggerTrace := log .FromContext (ctx ).V (logutil .TRACE )
176179 // pre score step, hashing prompt and find longest prefix match.
177- hashes := hashPrompt (ctx , request , m . HashBlockSize , m .MaxPrefixBlocksToMatch )
180+ hashes := hashPrompt (ctx , request , p . config . HashBlockSize , p . config .MaxPrefixBlocksToMatch )
178181 state := & SchedulingContextState {
179182 PrefixHashes : hashes ,
180- PrefixCacheServers : m .matchLongestPrefix (ctx , hashes ),
183+ PrefixCacheServers : p .matchLongestPrefix (ctx , hashes ),
181184 }
182185
183- cycleState . Write (plugins .StateKey (m .TypedName ().Type ), state )
186+ p . pluginState . Write (request . RequestId , plugins .StateKey (p .TypedName ().Type ), state )
184187 loggerTrace .Info (fmt .Sprintf ("cached servers: %+v" , state .PrefixCacheServers ), "hashes" , state .PrefixHashes )
185188 // calculate the scores of pods
186189 scores := make (map [types.Pod ]float64 , len (pods ))
@@ -200,31 +203,34 @@ func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
200203 return scores
201204}
202205
203- // PostCycle records in the plugin cache the result of the scheduling selection.
204- func (m * Plugin ) PostCycle (ctx context.Context , cycleState * types.CycleState , res * types.ProfileRunResult ) {
205- targetPod := res .TargetPods [0 ].GetPod ()
206- state , err := types .ReadCycleStateKey [* SchedulingContextState ](cycleState , PrefixCachePluginType )
206+ // PreRequest records in the plugin cache the result of the scheduling selection.
207+ func (p * Plugin ) PreRequest (ctx context.Context , request * types.LLMRequest , schedulingResult * types.SchedulingResult , _ int ) {
208+ primaryProfileResult := schedulingResult .ProfileResults [schedulingResult .PrimaryProfileName ]
209+ targetPod := primaryProfileResult .TargetPods [0 ].GetPod () // get the first pod of the primary profile
210+
211+ state , err := plugins .ReadPluginStateKey [* SchedulingContextState ](p .pluginState , request .RequestId , PrefixCachePluginType )
212+ p .pluginState .Delete (request .RequestId ) // delete the state explicitly after completing using it
207213 if err != nil {
208- log .FromContext (ctx ).Error (err , "failed to read prefix plugin cycle state" )
214+ log .FromContext (ctx ).Error (err , "failed to read prefix plugin state" , "requestID" , request . RequestId )
209215 return
210216 }
211217
212- m .indexer .Add (state .PrefixHashes , ServerID (targetPod .NamespacedName ))
218+ p .indexer .Add (state .PrefixHashes , ServerID (targetPod .NamespacedName ))
213219
214220 total := len (state .PrefixHashes )
215221 matchLen := state .PrefixCacheServers [ServerID (targetPod .NamespacedName )]
216- metrics .RecordPrefixCacheMatch (matchLen * m . HashBlockSize , total * m .HashBlockSize )
222+ metrics .RecordPrefixCacheMatch (matchLen * p . config . HashBlockSize , total * p . config .HashBlockSize )
217223}
218224
219225// matchLongestPrefix returns a map of servers and length of prefix that each server caches.
220- func (m * Plugin ) matchLongestPrefix (ctx context.Context , hashes []BlockHash ) map [ServerID ]int {
226+ func (p * Plugin ) matchLongestPrefix (ctx context.Context , hashes []BlockHash ) map [ServerID ]int {
221227 loggerTrace := log .FromContext (ctx ).V (logutil .TRACE )
222228 res := make (map [ServerID ]int )
223229 // Use a greedy strategy to search from the longest prefix.
224230 // NOTE: It's possible to further optimize this with a binary search.
225231 for i := 0 ; i < len (hashes ); i ++ {
226232 hash := hashes [i ]
227- cachedServers := m .indexer .Get (hash )
233+ cachedServers := p .indexer .Get (hash )
228234 if len (cachedServers ) == 0 {
229235 break
230236 } else {
0 commit comments