Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions pkg/epp/config/loader/configloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,6 @@ func (f *test1) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMReque

// compile-time type validation
var _ framework.Scorer = &test2{}
var _ framework.PostCycle = &test2{}

type test2 struct {
typedName plugins.TypedName
Expand All @@ -683,8 +682,6 @@ func (m *test2) Score(_ context.Context, _ *types.CycleState, _ *types.LLMReques
return map[types.Pod]float64{}
}

func (m *test2) PostCycle(_ context.Context, _ *types.CycleState, _ *types.ProfileRunResult) {}

// compile-time type validation
var _ framework.Picker = &testPicker{}

Expand Down
2 changes: 1 addition & 1 deletion pkg/epp/config/loader/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func setDefaultsPhaseTwo(cfg *configapi.EndpointPickerConfig, handle plugins.Han
thePlugins := []configapi.SchedulingPlugin{}
for pluginName, plugin := range allPlugins {
switch plugin.(type) {
case framework.Filter, framework.Picker, framework.PostCycle, framework.Scorer:
case framework.Filter, framework.Scorer, framework.Picker:
thePlugins = append(thePlugins, configapi.SchedulingPlugin{PluginRef: pluginName})
}
}
Expand Down
8 changes: 7 additions & 1 deletion pkg/epp/plugins/plugin_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,14 @@ func TestPluginState_ReadWrite(t *testing.T) {
assert.True(t, ok, "should be able to cast to pluginTestData")
assert.Equal(t, data1, td.value)

// Delete the req2 data and verify it's removed
// Delete the req2 data and verify content that was read before is still valid
readData, err = state.Read(req2, key)
assert.NoError(t, err)
state.Delete(req2)
td, ok = readData.(*pluginTestData)
assert.True(t, ok, "should be able to cast to pluginTestData")
assert.Equal(t, data2, td.value)
// try to read again after deletion, verify error
readData, err = state.Read(req2, key)
assert.Equal(t, ErrNotFound, err)
assert.Nil(t, readData, "expected no data after delete")
Expand Down
8 changes: 0 additions & 8 deletions pkg/epp/scheduling/framework/plugins.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ const (
FilterExtensionPoint = "Filter"
ScorerExtensionPoint = "Scorer"
PickerExtensionPoint = "Picker"
PostCycleExtensionPoint = "PostCycle"
ProcessProfilesResultsExtensionPoint = "ProcessProfilesResults"
)

Expand Down Expand Up @@ -69,10 +68,3 @@ type Picker interface {
plugins.Plugin
Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult
}

// PostCycle is called by the scheduler after it selects a targetPod for the request in the SchedulerProfile cycle.
// DEPRECATED - do not use PostCycle. this is in the process of deprecation.
type PostCycle interface {
plugins.Plugin
PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult)
}
62 changes: 34 additions & 28 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
Expand Down Expand Up @@ -73,9 +74,10 @@ type Config struct {
}

type Plugin struct {
Config
typedName plugins.TypedName
indexer Indexer
typedName plugins.TypedName
config Config
pluginState *plugins.PluginState
indexer Indexer
}

// podSet holds the pod servers that may have a specific prefix hash.
Expand Down Expand Up @@ -122,10 +124,10 @@ func (s *SchedulingContextState) Clone() plugins.StateData {

// compile-time type assertion
var _ framework.Scorer = &Plugin{}
var _ framework.PostCycle = &Plugin{}
var _ requestcontrol.PreRequest = &Plugin{}

// PrefixCachePluginFactory defines the factory function for Prefix plugin.
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
parameters := Config{
HashBlockSize: DefaultHashBlockSize,
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
Expand All @@ -138,11 +140,11 @@ func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, _ plug
}
}

return New(parameters).WithName(name), nil
return New(handle.Context(), parameters).WithName(name), nil
}

// New initializes a new prefix Plugin and returns its pointer.
func New(config Config) *Plugin {
func New(ctx context.Context, config Config) *Plugin {
capacity := config.LRUCapacityPerServer
if capacity <= 0 {
capacity = DefaultLRUCapacityPerServer
Expand All @@ -153,34 +155,35 @@ func New(config Config) *Plugin {
}

return &Plugin{
typedName: plugins.TypedName{Type: PrefixCachePluginType, Name: PrefixCachePluginType},
Config: config,
indexer: newIndexer(capacity),
typedName: plugins.TypedName{Type: PrefixCachePluginType, Name: PrefixCachePluginType},
config: config,
pluginState: plugins.NewPluginState(ctx),
indexer: newIndexer(capacity),
}
}

// TypedName returns the type and name tuple of this plugin instance.
func (m *Plugin) TypedName() plugins.TypedName {
return m.typedName
func (p *Plugin) TypedName() plugins.TypedName {
return p.typedName
}

// WithName sets the name of the plugin.
func (m *Plugin) WithName(name string) *Plugin {
m.typedName.Name = name
return m
func (p *Plugin) WithName(name string) *Plugin {
p.typedName.Name = name
return p
}

// Score returns the scoring result for the given list of pods based on context.
func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
func (p *Plugin) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
// pre score step, hashing prompt and find longest prefix match.
hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch)
hashes := hashPrompt(ctx, request, p.config.HashBlockSize, p.config.MaxPrefixBlocksToMatch)
state := &SchedulingContextState{
PrefixHashes: hashes,
PrefixCacheServers: m.matchLongestPrefix(ctx, hashes),
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
}

cycleState.Write(plugins.StateKey(m.TypedName().Type), state)
p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().Type), state)
loggerTrace.Info(fmt.Sprintf("cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes)
// calculate the scores of pods
scores := make(map[types.Pod]float64, len(pods))
Expand All @@ -200,31 +203,34 @@ func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
return scores
}

// PostCycle records in the plugin cache the result of the scheduling selection.
func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult) {
targetPod := res.TargetPods[0].GetPod()
state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType)
// PreRequest records in the plugin cache the result of the scheduling selection.
func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult, _ int) {
primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile

state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, PrefixCachePluginType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also clean up the per-request state after reading, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, fixed it. good catch!

my original intention was to do explicitly delete on successful request handling and let the cleanup go routine clean entries only of requests that failed in the middle.

I think it would be useful to add a generic ReadAndDeletePluginStateKey function to avoid this confusion, but will leave that for a follow up to keep this PR tightly scoped on the PostCycle deprecation.

p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it
if err != nil {
log.FromContext(ctx).Error(err, "failed to read prefix plugin cycle state")
log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
return
}

m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName))
p.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName))

total := len(state.PrefixHashes)
matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize)
metrics.RecordPrefixCacheMatch(matchLen*p.config.HashBlockSize, total*p.config.HashBlockSize)
}

// matchLongestPrefix returns a map of servers and length of prefix that each server caches.
func (m *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map[ServerID]int {
func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map[ServerID]int {
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
res := make(map[ServerID]int)
// Use a greedy strategy to search from the longest prefix.
// NOTE: It's possible to further optimize this with a binary search.
for i := 0; i < len(hashes); i++ {
hash := hashes[i]
cachedServers := m.indexer.Get(hash)
cachedServers := p.indexer.Get(hash)
if len(cachedServers) == 0 {
break
} else {
Expand Down
Loading