Skip to content

Commit 7da5041

Browse files
committed
Move prefix cache update from PreRequest to ResponseReceived.
1 parent d5ec68a commit 7da5041

File tree

2 files changed

+49
-13
lines changed

2 files changed

+49
-13
lines changed

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import (
2727
"github.com/cespare/xxhash/v2"
2828
k8stypes "k8s.io/apimachinery/pkg/types"
2929
"sigs.k8s.io/controller-runtime/pkg/log"
30-
30+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
3131
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins/approximateprefix"
3232
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
3333
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
@@ -128,12 +128,28 @@ func (s ServerID) String() string {
128128
// compile-time type validation
129129
var _ plugins.StateData = &SchedulingContextState{}
130130

131+
// compile-time type assertion
132+
var (
133+
_ framework.Scorer = &Plugin{}
134+
_ requestcontrol.PreRequest = &Plugin{}
135+
_ requestcontrol.ResponseReceived = &Plugin{}
136+
)
137+
131138
// SchedulingContextState is the state of this plugin to be used during a scheduling cycle.
132139
type SchedulingContextState struct {
133140
// PrefixHashes is a list of prefix hashes of the request prompt broken into blocks.
134141
PrefixHashes []BlockHash
135142
// A map of server to its longest prefix cache match length.
136143
PrefixCacheServers map[ServerID]int
144+
// SelectedServers is a list of servers that have been selected for the request.
145+
// Populated at PreRequest.
146+
SelectedServers []Server
147+
// StatsBlockSize is the block size used for metrics.
148+
// Populated at PreRequest.
149+
StatsBlockSize int
150+
// StatsMatchLen is the length of the longest prefix match for the target server.
151+
// Populated at PreRequest.
152+
StatsMatchLen int
137153
}
138154

139155
func (s *SchedulingContextState) Clone() plugins.StateData {
@@ -143,19 +159,18 @@ func (s *SchedulingContextState) Clone() plugins.StateData {
143159
for key, value := range s.PrefixCacheServers {
144160
prefixCacheServers[key] = value
145161
}
162+
selectedServers := make([]Server, len(s.SelectedServers))
163+
copy(selectedServers, s.SelectedServers)
146164

147165
return &SchedulingContextState{
148166
PrefixHashes: prefixHashes,
149167
PrefixCacheServers: prefixCacheServers,
168+
SelectedServers: selectedServers,
169+
StatsBlockSize: s.StatsBlockSize,
170+
StatsMatchLen: s.StatsMatchLen,
150171
}
151172
}
152173

153-
// compile-time type assertion
154-
var (
155-
_ framework.Scorer = &Plugin{}
156-
_ requestcontrol.PreRequest = &Plugin{}
157-
)
158-
159174
// PrefixCachePluginFactory defines the factory function for Prefix plugin.
160175
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
161176
parameters := DefaultConfig
@@ -282,29 +297,40 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
282297
}
283298

284299
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
285-
p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it
286300
if err != nil {
287301
log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
288302
return
289303
}
290304

305+
// Update the state with selected servers and stats, then write it back.
306+
state.SelectedServers = servers
307+
state.StatsMatchLen = state.PrefixCacheServers[ServerID(targetEndpoint.GetMetadata().NamespacedName)]
308+
state.StatsBlockSize = getBlockSize(primaryProfileResult.TargetEndpoints, p.config)
309+
p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().String()), state)
310+
}
311+
312+
func (p *Plugin) ResponseReceived(ctx context.Context, request *types.LLMRequest, _ *requestcontrol.Response, targetPod *datalayer.EndpointMetadata) {
313+
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
314+
p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it
315+
if err != nil {
316+
// This is expected if the request was shed before PreRequest or if PreRequest failed.
317+
return
318+
}
319+
291320
// This function is just adding data, it does not need to block other operations.
292321
// TODO: look into making this entire function async, none of this needs to be done in-band
293322
// The PR that introduces this change is meant as a cherrypick, so it was minimally invasive.
294323
// WaitGroup is added to the Plugin struct to allow waiting in tests.
295324
p.wg.Add(1)
296325
go func() {
297-
for _, s := range servers {
326+
for _, s := range state.SelectedServers {
298327
p.indexer.Add(state.PrefixHashes, s)
299328
}
300329
p.wg.Done()
301330
}()
302331

303332
total := len(state.PrefixHashes)
304-
matchLen := state.PrefixCacheServers[ServerID(targetEndpoint.GetMetadata().NamespacedName)]
305-
306-
blockSize := getBlockSize(primaryProfileResult.TargetEndpoints, p.config)
307-
metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize)
333+
metrics.RecordPrefixCacheMatch(state.StatsMatchLen*state.StatsBlockSize, total*state.StatsBlockSize)
308334
}
309335

310336
func (p *Plugin) makeServer(targetEndpoint types.Endpoint) Server {

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
8181
},
8282
}
8383
plugin.PreRequest(context.Background(), req1, schedulingResult)
84+
plugin.ResponseReceived(context.Background(), req1, nil, endpoint1.GetMetadata())
8485
plugin.wg.Wait()
8586

8687
// Second request doesn't share any prefix with first one. It should be added to the cache but
@@ -113,6 +114,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
113114
},
114115
}
115116
plugin.PreRequest(context.Background(), req2, schedulingResult)
117+
plugin.ResponseReceived(context.Background(), req2, nil, endpoint2.GetMetadata())
116118
plugin.wg.Wait()
117119

118120
// Third request shares partial prefix with first one.
@@ -144,6 +146,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
144146
},
145147
}
146148
plugin.PreRequest(context.Background(), req3, schedulingResult)
149+
plugin.ResponseReceived(context.Background(), req3, nil, endpoint1.GetMetadata())
147150
plugin.wg.Wait()
148151

149152
// 4th request is same as req3 except the model is different, still no match.
@@ -174,6 +177,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
174177
},
175178
}
176179
plugin.PreRequest(context.Background(), req4, schedulingResult)
180+
plugin.ResponseReceived(context.Background(), req4, nil, endpoint1.GetMetadata())
177181
plugin.wg.Wait()
178182

179183
// 5th request shares partial prefix with 3rd one.
@@ -204,6 +208,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
204208
},
205209
}
206210
plugin.PreRequest(context.Background(), req5, schedulingResult)
211+
plugin.ResponseReceived(context.Background(), req5, nil, endpoint1.GetMetadata())
207212
plugin.wg.Wait()
208213
}
209214

@@ -284,6 +289,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
284289
},
285290
}
286291
plugin.PreRequest(context.Background(), req1, schedulingResult)
292+
plugin.ResponseReceived(context.Background(), req1, nil, endpoint1.GetMetadata())
287293
plugin.wg.Wait()
288294

289295
// Second request adds assistant response and new user message (conversation grows)
@@ -317,6 +323,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
317323

318324
// Simulate pod1 was picked again
319325
plugin.PreRequest(context.Background(), req2, schedulingResult)
326+
plugin.ResponseReceived(context.Background(), req2, nil, endpoint1.GetMetadata())
320327
plugin.wg.Wait()
321328

322329
// Third request continues the conversation even further
@@ -528,6 +535,7 @@ func TestPrefixPluginAutoTune(t *testing.T) {
528535
},
529536
}
530537
plugin.PreRequest(context.Background(), req, schedulingResult)
538+
plugin.ResponseReceived(context.Background(), req, nil, endpoint.GetMetadata())
531539
plugin.wg.Wait()
532540

533541
// Check indexer state
@@ -563,6 +571,7 @@ func TestPrefixPluginAutoTune(t *testing.T) {
563571
},
564572
}
565573
plugin.PreRequest(context.Background(), req, schedulingResult)
574+
plugin.ResponseReceived(context.Background(), req, nil, endpoint.GetMetadata())
566575
plugin.wg.Wait()
567576

568577
assert.Contains(t, plugin.indexer.Pods(), ServerID(endpoint.GetMetadata().NamespacedName))
@@ -609,6 +618,7 @@ func TestPrepareRequestData(t *testing.T) {
609618
},
610619
}
611620
plugin.PreRequest(context.Background(), req1, schedulingResult)
621+
plugin.ResponseReceived(context.Background(), req1, nil, endpoint1.GetMetadata())
612622
plugin.wg.Wait()
613623

614624
// Second request that shares a prefix.

0 commit comments

Comments
 (0)