Skip to content

Commit f357ece

Browse files
authored
Add AutoTune config to prefix scorer to make it explicit when auto vs. manual config is used (#1888)
1 parent 3c8aba1 commit f357ece

File tree

2 files changed

+140
-44
lines changed

2 files changed

+140
-44
lines changed

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,19 @@ const (
6767
)
6868

6969
var DefaultConfig = Config{
70-
DefaultBlockSize: DefaultBlockSize,
70+
AutoTune: true,
71+
BlockSize: DefaultBlockSize,
7172
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
7273
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
7374
}
7475

7576
type Config struct {
77+
// If set to true, the plugin will automatically adjust the configuration based on various
78+
// metrics from the model servers.
79+
AutoTune bool `json:"autoTune"`
7680
// The input prompt is broken into sizes of BlockSize to calculate block hashes. Requests
7781
// with length shorter than the block size will be ignored.
78-
DefaultBlockSize int `json:"blockSize"`
82+
BlockSize int `json:"blockSize"`
7983
// MaxPrefixBlocksToMatch is the maximum number of prefix blocks to match. Input beyond this limit will
8084
// be ignored.
8185
MaxPrefixBlocksToMatch int `json:"maxPrefixBlocksToMatch"`
@@ -148,11 +152,7 @@ var (
148152

149153
// PrefixCachePluginFactory defines the factory function for Prefix plugin.
150154
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
151-
parameters := Config{
152-
DefaultBlockSize: DefaultBlockSize,
153-
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
154-
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
155-
}
155+
parameters := DefaultConfig
156156

157157
if rawParameters != nil {
158158
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
@@ -167,40 +167,32 @@ func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle
167167

168168
// New initializes a new prefix Plugin and returns its pointer.
169169
func New(ctx context.Context, config Config) *Plugin {
170-
capacity := config.LRUCapacityPerServer
171-
if capacity <= 0 {
172-
capacity = DefaultLRUCapacityPerServer
170+
if config.LRUCapacityPerServer <= 0 {
171+
config.LRUCapacityPerServer = DefaultLRUCapacityPerServer
173172
log.FromContext(ctx).V(logutil.DEFAULT).Info(
174173
"LRUCapacityPerServer is not positive, using default value",
175174
"defaultCapacity", DefaultLRUCapacityPerServer,
176175
)
177176
}
178177

179-
blockSize := config.DefaultBlockSize
180-
if blockSize <= 0 {
181-
blockSize = DefaultBlockSize
182-
log.FromContext(ctx).V(logutil.DEFAULT).Info("DefaultBlockSize is not positive, using default value",
178+
if config.BlockSize <= 0 {
179+
config.BlockSize = DefaultBlockSize
180+
log.FromContext(ctx).V(logutil.DEFAULT).Info("BlockSize is not positive, using default value",
183181
"default", DefaultBlockSize)
184182
}
185183

186-
maxPrefixBlocks := config.MaxPrefixBlocksToMatch
187-
if maxPrefixBlocks <= 0 {
188-
maxPrefixBlocks = DefaultMaxPrefixBlocks
184+
if config.MaxPrefixBlocksToMatch <= 0 {
185+
config.MaxPrefixBlocksToMatch = DefaultMaxPrefixBlocks
189186
log.FromContext(ctx).V(logutil.DEFAULT).Info("MaxPrefixBlocksToMatch is not positive, using default value",
190187
"default", DefaultMaxPrefixBlocks)
191188
}
192189

193-
validConfig := Config{
194-
DefaultBlockSize: blockSize,
195-
MaxPrefixBlocksToMatch: maxPrefixBlocks,
196-
LRUCapacityPerServer: capacity,
197-
}
198-
190+
log.FromContext(ctx).V(logutil.DEFAULT).Info("PrefixCachePlugin initialized", "config", config)
199191
return &Plugin{
200192
typedName: plugins.TypedName{Type: PrefixCachePluginType, Name: PrefixCachePluginType},
201-
config: validConfig,
193+
config: config,
202194
pluginState: plugins.NewPluginState(ctx),
203-
indexer: newIndexer(ctx, capacity),
195+
indexer: newIndexer(ctx, config.LRUCapacityPerServer),
204196
}
205197
}
206198

@@ -218,7 +210,7 @@ func (p *Plugin) WithName(name string) *Plugin {
218210
// Score returns the scoring result for the given list of pods based on context.
219211
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
220212
// Pre-score step: hash the prompt and find the longest prefix match.
221-
hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
213+
hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config), p.config.MaxPrefixBlocksToMatch)
222214
state := &SchedulingContextState{
223215
PrefixHashes: hashes,
224216
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
@@ -248,8 +240,12 @@ func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
248240
// PreRequest records in the plugin cache the result of the scheduling selection.
249241
func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) {
250242
primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
251-
targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile
252-
gpuBlocks := primaryProfileResult.TargetPods[0].GetMetrics().CacheNumGPUBlocks
243+
targetPod := primaryProfileResult.TargetPods[0] // get the first pod of the primary profile
244+
245+
gpuBlocks := p.config.LRUCapacityPerServer
246+
if p.config.AutoTune && targetPod.GetMetrics().CacheNumGPUBlocks > 0 {
247+
gpuBlocks = targetPod.GetMetrics().CacheNumGPUBlocks
248+
}
253249

254250
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
255251
p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it
@@ -265,16 +261,16 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
265261
p.wg.Add(1)
266262
go func() {
267263
p.indexer.Add(state.PrefixHashes, Server{
268-
ServerID(targetPod.NamespacedName),
264+
ServerID(targetPod.GetPod().NamespacedName),
269265
gpuBlocks,
270266
})
271267
p.wg.Done()
272268
}()
273269

274270
total := len(state.PrefixHashes)
275-
matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
271+
matchLen := state.PrefixCacheServers[ServerID(targetPod.GetPod().NamespacedName)]
276272

277-
blockSize := getBlockSize(primaryProfileResult.TargetPods, p.config.DefaultBlockSize)
273+
blockSize := getBlockSize(primaryProfileResult.TargetPods, p.config)
278274
metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize)
279275
}
280276

@@ -388,9 +384,14 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
388384
return json.Marshal(request.Body.ChatCompletions.Messages)
389385
}
390386

391-
func getBlockSize(pods []types.Pod, defaultBlockSize int) int {
387+
func getBlockSize(pods []types.Pod, config Config) int {
388+
if !config.AutoTune {
389+
return config.BlockSize
390+
}
391+
392+
// Fall back to BlockSize if no metrics are available.
392393
if len(pods) == 0 {
393-
return defaultBlockSize
394+
return config.BlockSize
394395
}
395396

396397
// Since all PODs originate from the same inference pool, they are considered to have identical configurations.
@@ -401,5 +402,5 @@ func getBlockSize(pods []types.Pod, defaultBlockSize int) int {
401402
return cacheBlockSize * averageCharactersPerToken
402403
}
403404
}
404-
return defaultBlockSize
405+
return config.BlockSize
405406
}

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 105 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import (
3535

3636
func TestPrefixPluginCompletion(t *testing.T) {
3737
config := Config{
38-
DefaultBlockSize: 4,
38+
BlockSize: 4,
3939
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
4040
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
4141
}
@@ -201,7 +201,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
201201

202202
func TestPrefixPluginChatCompletions(t *testing.T) {
203203
config := Config{
204-
DefaultBlockSize: 4,
204+
BlockSize: 4,
205205
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
206206
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
207207
}
@@ -235,7 +235,7 @@ func TestPrefixPluginChatCompletions(t *testing.T) {
235235

236236
func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
237237
config := Config{
238-
DefaultBlockSize: 8, // Use larger block size for more predictable JSON marshaling
238+
BlockSize: 8, // Use larger block size for more predictable JSON marshaling
239239
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
240240
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
241241
}
@@ -349,7 +349,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
349349
blockSize := 4
350350
maxPrefixBlocks := 50000
351351
config := Config{
352-
DefaultBlockSize: blockSize,
352+
BlockSize: blockSize,
353353
MaxPrefixBlocksToMatch: maxPrefixBlocks,
354354
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
355355
}
@@ -409,7 +409,7 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
409409
{
410410
name: "all zero",
411411
config: Config{
412-
DefaultBlockSize: 0,
412+
BlockSize: 0,
413413
MaxPrefixBlocksToMatch: 0,
414414
LRUCapacityPerServer: 0,
415415
},
@@ -420,7 +420,7 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
420420
{
421421
name: "negative values",
422422
config: Config{
423-
DefaultBlockSize: -5,
423+
BlockSize: -5,
424424
MaxPrefixBlocksToMatch: -10,
425425
LRUCapacityPerServer: -100,
426426
},
@@ -431,7 +431,7 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
431431
{
432432
name: "mixed valid and invalid",
433433
config: Config{
434-
DefaultBlockSize: 32, // valid
434+
BlockSize: 32, // valid
435435
MaxPrefixBlocksToMatch: -1, // invalid
436436
LRUCapacityPerServer: 50000, // valid
437437
},
@@ -442,7 +442,7 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
442442
{
443443
name: "all valid",
444444
config: Config{
445-
DefaultBlockSize: 64,
445+
BlockSize: 64,
446446
MaxPrefixBlocksToMatch: 200,
447447
LRUCapacityPerServer: 30000,
448448
},
@@ -459,13 +459,108 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
459459

460460
assert.NotEmpty(t, plugin)
461461
assert.NotEmpty(t, plugin.indexer)
462-
assert.Equal(t, tt.expectBlock, plugin.config.DefaultBlockSize)
462+
assert.Equal(t, tt.expectBlock, plugin.config.BlockSize)
463463
assert.Equal(t, tt.expectMaxMatch, plugin.config.MaxPrefixBlocksToMatch)
464464
assert.Equal(t, tt.expectCapacity, plugin.config.LRUCapacityPerServer)
465465
})
466466
}
467467
}
468468

469+
func TestPrefixPluginAutoTune(t *testing.T) {
470+
// Setup common test data
471+
podName := "pod-autotune"
472+
pod := &types.PodMetrics{
473+
Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: podName}},
474+
MetricsState: &backendmetrics.MetricsState{
475+
CacheBlockSize: 16, // 16 tokens * 4 chars/token = 64 chars per block
476+
CacheNumGPUBlocks: 1000, // 1000 blocks capacity
477+
},
478+
}
479+
pods := []types.Pod{pod}
480+
481+
req := &types.LLMRequest{
482+
RequestId: uuid.NewString(),
483+
TargetModel: "test-model",
484+
Body: &types.LLMRequestBody{
485+
Completions: &types.CompletionsRequest{
486+
// Length 128 chars.
487+
// If AutoTune=true (block size 64): 2 blocks
488+
// If AutoTune=false (block size 32): 4 blocks
489+
Prompt: strings.Repeat("a", 128),
490+
},
491+
},
492+
}
493+
494+
t.Run("AutoTune Enabled", func(t *testing.T) {
495+
config := Config{
496+
AutoTune: true,
497+
BlockSize: 32, // Should be ignored in favor of pod metrics (64)
498+
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
499+
// Should be ignored in favor of pod metrics (1000)
500+
LRUCapacityPerServer: 1,
501+
}
502+
plugin := New(context.Background(), config)
503+
504+
// 1. Verify Score uses pod metrics for block size
505+
scores := plugin.Score(context.Background(), types.NewCycleState(), req, pods)
506+
_ = scores
507+
508+
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, plugins.StateKey(plugin.TypedName().String()))
509+
assert.NoError(t, err)
510+
// Block size from pod is 16 tokens * 4 = 64 chars.
511+
// Prompt is 128 chars.
512+
// Expected blocks: 128/64 = 2 hashes (model hash is used as seed but not returned as a block)
513+
assert.Equal(t, 2, len(state.PrefixHashes), "Should use pod block size (64 chars) -> 2 body blocks")
514+
515+
// 2. Verify PreRequest uses pod metrics for capacity
516+
schedulingResult := &types.SchedulingResult{
517+
PrimaryProfileName: "default",
518+
ProfileResults: map[string]*types.ProfileRunResult{
519+
"default": {TargetPods: []types.Pod{pod}},
520+
},
521+
}
522+
plugin.PreRequest(context.Background(), req, schedulingResult)
523+
plugin.wg.Wait()
524+
525+
// Check indexer state
526+
assert.Contains(t, plugin.indexer.Pods(), ServerID(pod.GetPod().NamespacedName))
527+
})
528+
529+
t.Run("AutoTune Disabled", func(t *testing.T) {
530+
config := Config{
531+
AutoTune: false,
532+
BlockSize: 32, // Should be used (32 chars)
533+
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
534+
LRUCapacityPerServer: 1, // Should be used, and the first hash should be evicted due to the small capacity.
535+
}
536+
plugin := New(context.Background(), config)
537+
538+
// 1. Verify Score uses config BlockSize
539+
req.RequestId = uuid.NewString() // New request ID
540+
scores := plugin.Score(context.Background(), types.NewCycleState(), req, pods)
541+
_ = scores
542+
543+
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, plugins.StateKey(plugin.TypedName().String()))
544+
assert.NoError(t, err)
545+
// Block size from config is 32 chars.
546+
// Prompt is 128 chars.
547+
// 128 / 32 = 4 chunks.
548+
assert.Equal(t, 4, len(state.PrefixHashes), "Should use config block size (32 chars) -> 4 body blocks")
549+
550+
// 2. Verify PreRequest uses config LRUCapacityPerServer
551+
schedulingResult := &types.SchedulingResult{
552+
PrimaryProfileName: "default",
553+
ProfileResults: map[string]*types.ProfileRunResult{
554+
"default": {TargetPods: []types.Pod{pod}},
555+
},
556+
}
557+
plugin.PreRequest(context.Background(), req, schedulingResult)
558+
plugin.wg.Wait()
559+
560+
assert.Contains(t, plugin.indexer.Pods(), ServerID(pod.GetPod().NamespacedName))
561+
})
562+
}
563+
469564
// randomPrompt generates a pseudo-random string of length n using lowercase letters.
470565
func randomPrompt(n int) string {
471566
runes := []rune("abcdefghijklmnopqrstuvwxyz")
@@ -481,7 +576,7 @@ func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) {
481576
blockSize := 8
482577
maxPrefixBlocks := 50000
483578
config := Config{
484-
DefaultBlockSize: blockSize,
579+
BlockSize: blockSize,
485580
MaxPrefixBlocksToMatch: maxPrefixBlocks,
486581
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
487582
}

0 commit comments

Comments
 (0)