Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 39 additions & 13 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
"github.com/cespare/xxhash/v2"
k8stypes "k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/log"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins/approximateprefix"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
Expand Down Expand Up @@ -128,12 +128,28 @@ func (s ServerID) String() string {
// compile-time type validation
var _ plugins.StateData = &SchedulingContextState{}

// compile-time type assertion
var (
_ framework.Scorer = &Plugin{}
_ requestcontrol.PreRequest = &Plugin{}
_ requestcontrol.ResponseReceived = &Plugin{}
Copy link
Contributor

@nirrozenbaum nirrozenbaum Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just as a clarification for any reviewer - the extension point ResponseReceived is called when the response HEADERS are received, it is NOT waiting for the body to arrive.
(naming selection here is bad, we have an open issue to update the extension point name).

// ResponseReceived is called by the director after the response headers are successfully received
// which indicates the beginning of the response handling by the model server.
// The given pod argument is the pod that served the request.
type ResponseReceived interface {
plugins.Plugin
ResponseReceived(ctx context.Context, request *types.LLMRequest, response *Response, targetEndpoint *datalayer.EndpointMetadata)
}

)

// SchedulingContextState is the per-request state of this plugin. It is stored in the
// plugin state under the request ID: PreRequest reads it, fills in the selection and
// stats fields below and writes it back, and ResponseReceived reads it one last time
// and deletes it.
type SchedulingContextState struct {
// PrefixHashes is a list of prefix hashes of the request prompt broken into blocks.
// ResponseReceived adds these hashes to the indexer for every selected server.
PrefixHashes []BlockHash
// PrefixCacheServers is a map of server to its longest prefix cache match length.
PrefixCacheServers map[ServerID]int
// SelectedServers is a list of servers that have been selected for the request.
// Populated at PreRequest and consumed at ResponseReceived, where each server is
// recorded in the indexer together with PrefixHashes.
SelectedServers []Server
// StatsBlockSize is the block size used for metrics; ResponseReceived multiplies both
// the match length and the total number of prefix hashes by it before recording.
// Populated at PreRequest.
StatsBlockSize int
// StatsMatchLen is the length of the longest prefix match for the target server,
// i.e. the PrefixCacheServers entry of the endpoint picked for the request.
// Populated at PreRequest.
StatsMatchLen int
}

func (s *SchedulingContextState) Clone() plugins.StateData {
Expand All @@ -143,19 +159,18 @@ func (s *SchedulingContextState) Clone() plugins.StateData {
for key, value := range s.PrefixCacheServers {
prefixCacheServers[key] = value
}
selectedServers := make([]Server, len(s.SelectedServers))
copy(selectedServers, s.SelectedServers)

return &SchedulingContextState{
PrefixHashes: prefixHashes,
PrefixCacheServers: prefixCacheServers,
SelectedServers: selectedServers,
StatsBlockSize: s.StatsBlockSize,
StatsMatchLen: s.StatsMatchLen,
}
}

// compile-time type assertion
var (
_ framework.Scorer = &Plugin{}
_ requestcontrol.PreRequest = &Plugin{}
)

// PrefixCachePluginFactory defines the factory function for Prefix plugin.
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
parameters := DefaultConfig
Expand Down Expand Up @@ -282,29 +297,40 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
}

state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it
if err != nil {
log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
return
}

// Update the state with selected servers and stats, then write it back.
state.SelectedServers = servers
state.StatsMatchLen = state.PrefixCacheServers[ServerID(targetEndpoint.GetMetadata().NamespacedName)]
state.StatsBlockSize = getBlockSize(primaryProfileResult.TargetEndpoints, p.config)
p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().String()), state)
}

func (p *Plugin) ResponseReceived(ctx context.Context, request *types.LLMRequest, _ *requestcontrol.Response, targetPod *datalayer.EndpointMetadata) {
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it
if err != nil {
// This is expected if the request was shed before PreRequest or if PreRequest failed.
return
}

// This function is just adding data, it does not need to block other operations.
// TODO: look into making this entire function async, none of this needs to be done in-band
// The PR that introduces this change is meant as a cherrypick, so it was minimally invasive.
// WaitGroup is added to the Plugin struct to allow waiting in tests.
p.wg.Add(1)
go func() {
for _, s := range servers {
for _, s := range state.SelectedServers {
p.indexer.Add(state.PrefixHashes, s)
}
p.wg.Done()
}()

total := len(state.PrefixHashes)
matchLen := state.PrefixCacheServers[ServerID(targetEndpoint.GetMetadata().NamespacedName)]

blockSize := getBlockSize(primaryProfileResult.TargetEndpoints, p.config)
metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize)
metrics.RecordPrefixCacheMatch(state.StatsMatchLen*state.StatsBlockSize, total*state.StatsBlockSize)
}

func (p *Plugin) makeServer(targetEndpoint types.Endpoint) Server {
Expand Down
10 changes: 10 additions & 0 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
},
}
plugin.PreRequest(context.Background(), req1, schedulingResult)
plugin.ResponseReceived(context.Background(), req1, nil, endpoint1.GetMetadata())
plugin.wg.Wait()

// Second request doesn't share any prefix with first one. It should be added to the cache but
Expand Down Expand Up @@ -113,6 +114,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
},
}
plugin.PreRequest(context.Background(), req2, schedulingResult)
plugin.ResponseReceived(context.Background(), req2, nil, endpoint2.GetMetadata())
plugin.wg.Wait()

// Third request shares partial prefix with first one.
Expand Down Expand Up @@ -144,6 +146,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
},
}
plugin.PreRequest(context.Background(), req3, schedulingResult)
plugin.ResponseReceived(context.Background(), req3, nil, endpoint1.GetMetadata())
plugin.wg.Wait()

// 4th request is same as req3 except the model is different, still no match.
Expand Down Expand Up @@ -174,6 +177,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
},
}
plugin.PreRequest(context.Background(), req4, schedulingResult)
plugin.ResponseReceived(context.Background(), req4, nil, endpoint1.GetMetadata())
plugin.wg.Wait()

// 5th request shares partial prefix with 3rd one.
Expand Down Expand Up @@ -204,6 +208,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
},
}
plugin.PreRequest(context.Background(), req5, schedulingResult)
plugin.ResponseReceived(context.Background(), req5, nil, endpoint1.GetMetadata())
plugin.wg.Wait()
}

Expand Down Expand Up @@ -284,6 +289,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
},
}
plugin.PreRequest(context.Background(), req1, schedulingResult)
plugin.ResponseReceived(context.Background(), req1, nil, endpoint1.GetMetadata())
plugin.wg.Wait()

// Second request adds assistant response and new user message (conversation grows)
Expand Down Expand Up @@ -317,6 +323,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {

// Simulate pod1 was picked again
plugin.PreRequest(context.Background(), req2, schedulingResult)
plugin.ResponseReceived(context.Background(), req2, nil, endpoint1.GetMetadata())
plugin.wg.Wait()

// Third request continues the conversation even further
Expand Down Expand Up @@ -528,6 +535,7 @@ func TestPrefixPluginAutoTune(t *testing.T) {
},
}
plugin.PreRequest(context.Background(), req, schedulingResult)
plugin.ResponseReceived(context.Background(), req, nil, endpoint.GetMetadata())
plugin.wg.Wait()

// Check indexer state
Expand Down Expand Up @@ -563,6 +571,7 @@ func TestPrefixPluginAutoTune(t *testing.T) {
},
}
plugin.PreRequest(context.Background(), req, schedulingResult)
plugin.ResponseReceived(context.Background(), req, nil, endpoint.GetMetadata())
plugin.wg.Wait()

assert.Contains(t, plugin.indexer.Pods(), ServerID(endpoint.GetMetadata().NamespacedName))
Expand Down Expand Up @@ -609,6 +618,7 @@ func TestPrepareRequestData(t *testing.T) {
},
}
plugin.PreRequest(context.Background(), req1, schedulingResult)
plugin.ResponseReceived(context.Background(), req1, nil, endpoint1.GetMetadata())
plugin.wg.Wait()

// Second request that shares a prefix.
Expand Down