3 changes: 0 additions & 3 deletions pkg/epp/config/loader/configloader_test.go
@@ -663,7 +663,6 @@ func (f *test1) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMReque

// compile-time type validation
var _ framework.Scorer = &test2{}
var _ framework.PostCycle = &test2{}

type test2 struct {
typedName plugins.TypedName
Expand All @@ -683,8 +682,6 @@ func (m *test2) Score(_ context.Context, _ *types.CycleState, _ *types.LLMReques
return map[types.Pod]float64{}
}

func (m *test2) PostCycle(_ context.Context, _ *types.CycleState, _ *types.ProfileRunResult) {}

// compile-time type validation
var _ framework.Picker = &testPicker{}

2 changes: 1 addition & 1 deletion pkg/epp/config/loader/defaults.go
@@ -69,7 +69,7 @@ func setDefaultsPhaseTwo(cfg *configapi.EndpointPickerConfig, handle plugins.Han
thePlugins := []configapi.SchedulingPlugin{}
for pluginName, plugin := range allPlugins {
switch plugin.(type) {
case framework.Filter, framework.Picker, framework.PostCycle, framework.Scorer:
case framework.Filter, framework.Scorer, framework.Picker:
thePlugins = append(thePlugins, configapi.SchedulingPlugin{PluginRef: pluginName})
}
}
8 changes: 0 additions & 8 deletions pkg/epp/scheduling/framework/plugins.go
@@ -28,7 +28,6 @@ const (
FilterExtensionPoint = "Filter"
ScorerExtensionPoint = "Scorer"
PickerExtensionPoint = "Picker"
PostCycleExtensionPoint = "PostCycle"
ProcessProfilesResultsExtensionPoint = "ProcessProfilesResults"
)

@@ -69,10 +68,3 @@ type Picker interface {
plugins.Plugin
Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult
}

// PostCycle is called by the scheduler after it selects a targetPod for the request in the SchedulerProfile cycle.
// DEPRECATED - do not use PostCycle. this is in the process of deprecation.
type PostCycle interface {
plugins.Plugin
PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult)
}
61 changes: 33 additions & 28 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
@@ -28,6 +28,7 @@ import (

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
@@ -73,9 +74,10 @@ type Config struct {
}

type Plugin struct {
Config
typedName plugins.TypedName
indexer Indexer
typedName plugins.TypedName
config Config
pluginState *plugins.PluginState
indexer Indexer
}

// podSet holds the pods (servers) that may have a specific prefix hash.
@@ -122,10 +124,10 @@ func (s *SchedulingContextState) Clone() plugins.StateData {

// compile-time type assertion
var _ framework.Scorer = &Plugin{}
var _ framework.PostCycle = &Plugin{}
var _ requestcontrol.PreRequest = &Plugin{}

// PrefixCachePluginFactory defines the factory function for Prefix plugin.
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
parameters := Config{
HashBlockSize: DefaultHashBlockSize,
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
@@ -138,11 +140,11 @@ func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, _ plug
}
}

return New(parameters).WithName(name), nil
return New(handle.Context(), parameters).WithName(name), nil
}

// New initializes a new prefix Plugin and returns its pointer.
func New(config Config) *Plugin {
func New(ctx context.Context, config Config) *Plugin {
capacity := config.LRUCapacityPerServer
if capacity <= 0 {
capacity = DefaultLRUCapacityPerServer
@@ -153,34 +155,35 @@ func New(config Config) *Plugin {
}

return &Plugin{
typedName: plugins.TypedName{Type: PrefixCachePluginType, Name: PrefixCachePluginType},
Config: config,
indexer: newIndexer(capacity),
typedName: plugins.TypedName{Type: PrefixCachePluginType, Name: PrefixCachePluginType},
config: config,
pluginState: plugins.NewPluginState(ctx),
indexer: newIndexer(capacity),
}
}

// TypedName returns the type and name tuple of this plugin instance.
func (m *Plugin) TypedName() plugins.TypedName {
return m.typedName
func (p *Plugin) TypedName() plugins.TypedName {
return p.typedName
}

// WithName sets the name of the plugin.
func (m *Plugin) WithName(name string) *Plugin {
m.typedName.Name = name
return m
func (p *Plugin) WithName(name string) *Plugin {
p.typedName.Name = name
return p
}

// Score returns the scoring result for the given list of pods based on context.
func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
func (p *Plugin) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
// pre score step, hashing prompt and find longest prefix match.
hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch)
hashes := hashPrompt(ctx, request, p.config.HashBlockSize, p.config.MaxPrefixBlocksToMatch)
state := &SchedulingContextState{
PrefixHashes: hashes,
PrefixCacheServers: m.matchLongestPrefix(ctx, hashes),
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
}

cycleState.Write(plugins.StateKey(m.TypedName().Type), state)
p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().Type), state)
loggerTrace.Info(fmt.Sprintf("cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes)
// calculate the scores of pods
scores := make(map[types.Pod]float64, len(pods))
@@ -200,31 +203,33 @@ func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
return scores
}

// PostCycle records in the plugin cache the result of the scheduling selection.
func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult) {
targetPod := res.TargetPods[0].GetPod()
state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType)
// PreRequest records in the plugin cache the result of the scheduling selection.
func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult, _ int) {
primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile

state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, PrefixCachePluginType)
Contributor:

We should also clean up the per-request state after reading, right?

Contributor (Author):

Yes, fixed it. Good catch!

My original intention was to delete the entry explicitly on successful request handling and let the cleanup goroutine remove entries only for requests that failed in the middle.

I think it would be useful to add a generic ReadAndDeletePluginStateKey function to avoid this confusion, but I will leave that for a follow-up to keep this PR tightly scoped on the PostCycle deprecation.
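As a rough illustration, such a helper could look like the sketch below. The name ReadAndDeletePluginStateKey, the StateData constraint, and the Delete method on PluginState are assumptions for illustration only and are not part of this PR's API:

```go
// Hypothetical follow-up helper (not part of this PR): read a typed value from
// the per-request plugin state and, on success, drop that request's entry so
// the caller cannot forget the cleanup step.
// Assumes ReadPluginStateKey keeps its current generic signature and that
// PluginState has (or gains) a Delete method keyed by request ID.
func ReadAndDeletePluginStateKey[T StateData](state *PluginState, requestID string, key StateKey) (T, error) {
	value, err := ReadPluginStateKey[T](state, requestID, key)
	if err != nil {
		var zero T
		return zero, err
	}
	state.Delete(requestID) // assumed cleanup API
	return value, nil
}
```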

if err != nil {
log.FromContext(ctx).Error(err, "failed to read prefix plugin cycle state")
log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
return
}

m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName))
p.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName))

total := len(state.PrefixHashes)
matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize)
metrics.RecordPrefixCacheMatch(matchLen*p.config.HashBlockSize, total*p.config.HashBlockSize)
}

// matchLongestPrefix returns a map of servers and length of prefix that each server caches.
func (m *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map[ServerID]int {
func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map[ServerID]int {
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
res := make(map[ServerID]int)
// Use a greedy strategy to search from the longest prefix.
// NOTE: It's possible to further optimize this with a binary search.
for i := 0; i < len(hashes); i++ {
hash := hashes[i]
cachedServers := m.indexer.Get(hash)
cachedServers := p.indexer.Get(hash)
if len(cachedServers) == 0 {
break
} else {
90 changes: 64 additions & 26 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
@@ -24,10 +24,12 @@ import (
"strings"
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
k8stypes "k8s.io/apimachinery/pkg/types"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

@@ -38,20 +40,20 @@ func TestPrefixPlugin(t *testing.T) {
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
}
plugin := New(config)
plugin := New(context.Background(), config)

pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}
pods := []types.Pod{pod1, pod2}

// First request.
req1 := &types.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "test-model1",
Prompt: "aaaaaa",
}
cycleState1 := types.NewCycleState()
scores := plugin.Score(context.Background(), cycleState1, req1, pods)
state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState1, PrefixCachePluginType)
scores := plugin.Score(context.Background(), nil, req1, pods)
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, PrefixCachePluginType)
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
// Input size is 6, hash block size is 4, the last 2 characters are ignored.
@@ -62,17 +64,23 @@
assert.Equal(t, float64(0), scores[pod2], "score for pod2")

// Simulate pod1 was picked.
plugin.PostCycle(context.Background(), cycleState1, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
schedulingResult := &types.SchedulingResult{
PrimaryProfileName: "default",
ProfileResults: map[string]*types.ProfileRunResult{
"default": {TargetPods: []types.Pod{pod1}},
},
}
plugin.PreRequest(context.Background(), req1, schedulingResult, 0)

// Second request doesn't share any prefix with first one. It should be added to the cache but
// the pod score should be 0.
req2 := &types.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "test-model2",
Prompt: "bbbbbb",
}
cycleState2 := types.NewCycleState()
scores = plugin.Score(context.Background(), cycleState2, req2, pods)
state, err = types.ReadCycleStateKey[*SchedulingContextState](cycleState2, PrefixCachePluginType)
scores = plugin.Score(context.Background(), nil, req2, pods)
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, PrefixCachePluginType)
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
// Input size is 6, hash block size is 4, the last 2 characters are ignored.
@@ -83,16 +91,22 @@
assert.Equal(t, float64(0), scores[pod2], "score for pod2")

// Simulate pod2 was picked.
plugin.PostCycle(context.Background(), cycleState2, &types.ProfileRunResult{TargetPods: []types.Pod{pod2}})
schedulingResult = &types.SchedulingResult{
PrimaryProfileName: "default",
ProfileResults: map[string]*types.ProfileRunResult{
"default": {TargetPods: []types.Pod{pod2}},
},
}
plugin.PreRequest(context.Background(), req2, schedulingResult, 0)

// Third request shares partial prefix with first one.
req3 := &types.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "test-model1",
Prompt: "aaaabbbb",
}
cycleState3 := types.NewCycleState()
scores = plugin.Score(context.Background(), cycleState3, req3, pods)
state, err = types.ReadCycleStateKey[*SchedulingContextState](cycleState3, PrefixCachePluginType)
scores = plugin.Score(context.Background(), nil, req3, pods)
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, PrefixCachePluginType)
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
@@ -102,16 +116,22 @@
assert.Equal(t, float64(2)/float64(3), scores[pod1], "score should be 2/3 - the model and the first prefix block match")
assert.Equal(t, float64(0), scores[pod2], "score for pod2")

plugin.PostCycle(context.Background(), cycleState3, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
schedulingResult = &types.SchedulingResult{
PrimaryProfileName: "default",
ProfileResults: map[string]*types.ProfileRunResult{
"default": {TargetPods: []types.Pod{pod1}},
},
}
plugin.PreRequest(context.Background(), req3, schedulingResult, 0)

// 4th request is same as req3 except the model is different, still no match.
req4 := &types.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "test-model-new",
Prompt: "aaaabbbb",
}
cycleState4 := types.NewCycleState()
scores = plugin.Score(context.Background(), cycleState4, req4, pods)
state, err = types.ReadCycleStateKey[*SchedulingContextState](cycleState4, PrefixCachePluginType)
scores = plugin.Score(context.Background(), nil, req4, pods)
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req4.RequestId, PrefixCachePluginType)
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
@@ -121,16 +141,22 @@
assert.Equal(t, float64(0), scores[pod1], "score for pod1")
assert.Equal(t, float64(0), scores[pod2], "score for pod2")

plugin.PostCycle(context.Background(), cycleState4, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
schedulingResult = &types.SchedulingResult{
PrimaryProfileName: "default",
ProfileResults: map[string]*types.ProfileRunResult{
"default": {TargetPods: []types.Pod{pod1}},
},
}
plugin.PreRequest(context.Background(), req4, schedulingResult, 0)

// 5th request shares partial prefix with 3rd one.
req5 := &types.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "test-model1",
Prompt: "aaaabbbbcccc",
}
cycleState5 := types.NewCycleState()
scores = plugin.Score(context.Background(), cycleState5, req5, pods)
state, err = types.ReadCycleStateKey[*SchedulingContextState](cycleState5, PrefixCachePluginType)
scores = plugin.Score(context.Background(), nil, req5, pods)
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req5.RequestId, PrefixCachePluginType)
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
// Input size is 12, hash block size is 4, so 3 hashes will be calculated.
@@ -140,7 +166,13 @@
assert.Equal(t, 0.75, scores[pod1], "score should be 0.75 - the model and the first 2 prefix blocks match")
assert.Equal(t, float64(0), scores[pod2], "score for pod2")

plugin.PostCycle(context.Background(), cycleState5, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
schedulingResult = &types.SchedulingResult{
PrimaryProfileName: "default",
ProfileResults: map[string]*types.ProfileRunResult{
"default": {TargetPods: []types.Pod{pod1}},
},
}
plugin.PreRequest(context.Background(), req5, schedulingResult, 0)
}

// TestPrefixPluginStress is a stress test for the prefix scoring plugin, using prompts of increasing length.
Expand All @@ -153,7 +185,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
}

plugin := New(config)
plugin := New(context.Background(), config)
types.NewCycleState()
var promptLen []int
for i := 1; i <= 1024; i++ {
@@ -174,17 +206,23 @@

pods := []types.Pod{pod}
req := &types.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "model-stress",
Prompt: prompt,
}

// First cycle: simulate scheduling and insert prefix info into the cache
cycleState := types.NewCycleState()
plugin.Score(context.Background(), cycleState, req, pods)
plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPods: []types.Pod{pod}})
plugin.Score(context.Background(), nil, req, pods)
schedulingResult := &types.SchedulingResult{
PrimaryProfileName: "default",
ProfileResults: map[string]*types.ProfileRunResult{
"default": {TargetPods: []types.Pod{pod}},
},
}
plugin.PreRequest(context.Background(), req, schedulingResult, 0)

// Second cycle: validate internal state
state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType)
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, PrefixCachePluginType)
assert.NoError(b, err)
expectedHashes := int(math.Min(float64(maxPrefixBlocks+1), float64(len(req.Prompt)/blockSize+1))) // the extra one is for the model.
assert.Equal(b, expectedHashes, len(state.PrefixHashes), "number of hashes is incorrect")