Skip to content

Commit 68f5cf0

Browse files
committed
Add AutoTune config to prefix scorer to make it explicit when auto vs. manual config is used
1 parent 3836d3b commit 68f5cf0

File tree

2 files changed

+44
-44
lines changed

2 files changed

+44
-44
lines changed

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,19 @@ const (
6767
)
6868

6969
var DefaultConfig = Config{
70-
DefaultBlockSize: DefaultBlockSize,
70+
AutoTune: true,
71+
BlockSize: DefaultBlockSize,
7172
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
7273
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
7374
}
7475

7576
type Config struct {
77+
// If set to true, the plugin will automatically adjust the configuration based on various
78+
// metrics from the model servers.
79+
AutoTune bool `json:"autoTune"`
7680
// The input prompt is broken into sizes of BlockSize to calculate block hashes. Requests
7781
// with length shorter than the block size will be ignored.
78-
DefaultBlockSize int `json:"blockSize"`
82+
BlockSize int `json:"blockSize"`
7983
// MaxPrefixBlocksToMatch is the maximum number of prefix blocks to match. Input beyond this limit will
8084
// be ignored.
8185
MaxPrefixBlocksToMatch int `json:"maxPrefixBlocksToMatch"`
@@ -148,11 +152,7 @@ var (
148152

149153
// PrefixCachePluginFactory defines the factory function for Prefix plugin.
150154
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
151-
parameters := Config{
152-
DefaultBlockSize: DefaultBlockSize,
153-
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
154-
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
155-
}
155+
parameters := DefaultConfig
156156

157157
if rawParameters != nil {
158158
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
@@ -167,40 +167,31 @@ func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle
167167

168168
// New initializes a new prefix Plugin and returns its pointer.
169169
func New(ctx context.Context, config Config) *Plugin {
170-
capacity := config.LRUCapacityPerServer
171-
if capacity <= 0 {
172-
capacity = DefaultLRUCapacityPerServer
170+
if config.LRUCapacityPerServer <= 0 {
171+
config.LRUCapacityPerServer = DefaultLRUCapacityPerServer
173172
log.FromContext(ctx).V(logutil.DEFAULT).Info(
174173
"LRUCapacityPerServer is not positive, using default value",
175174
"defaultCapacity", DefaultLRUCapacityPerServer,
176175
)
177176
}
178177

179-
blockSize := config.DefaultBlockSize
180-
if blockSize <= 0 {
181-
blockSize = DefaultBlockSize
182-
log.FromContext(ctx).V(logutil.DEFAULT).Info("DefaultBlockSize is not positive, using default value",
178+
if config.BlockSize <= 0 {
179+
config.BlockSize = DefaultBlockSize
180+
log.FromContext(ctx).V(logutil.DEFAULT).Info("BlockSize is not positive, using default value",
183181
"default", DefaultBlockSize)
184182
}
185183

186-
maxPrefixBlocks := config.MaxPrefixBlocksToMatch
187-
if maxPrefixBlocks <= 0 {
188-
maxPrefixBlocks = DefaultMaxPrefixBlocks
184+
if config.MaxPrefixBlocksToMatch <= 0 {
185+
config.MaxPrefixBlocksToMatch = DefaultMaxPrefixBlocks
189186
log.FromContext(ctx).V(logutil.DEFAULT).Info("MaxPrefixBlocksToMatch is not positive, using default value",
190187
"default", DefaultMaxPrefixBlocks)
191188
}
192189

193-
validConfig := Config{
194-
DefaultBlockSize: blockSize,
195-
MaxPrefixBlocksToMatch: maxPrefixBlocks,
196-
LRUCapacityPerServer: capacity,
197-
}
198-
199190
return &Plugin{
200191
typedName: plugins.TypedName{Type: PrefixCachePluginType, Name: PrefixCachePluginType},
201-
config: validConfig,
192+
config: config,
202193
pluginState: plugins.NewPluginState(ctx),
203-
indexer: newIndexer(ctx, capacity),
194+
indexer: newIndexer(ctx, config.LRUCapacityPerServer),
204195
}
205196
}
206197

@@ -218,7 +209,7 @@ func (p *Plugin) WithName(name string) *Plugin {
218209
// Score returns the scoring result for the given list of pods based on context.
219210
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
220211
// pre score step, hashing prompt and find longest prefix match.
221-
hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
212+
hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config), p.config.MaxPrefixBlocksToMatch)
222213
state := &SchedulingContextState{
223214
PrefixHashes: hashes,
224215
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
@@ -248,8 +239,12 @@ func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
248239
// PreRequest records in the plugin cache the result of the scheduling selection.
249240
func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) {
250241
primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
251-
targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile
252-
gpuBlocks := primaryProfileResult.TargetPods[0].GetMetrics().CacheNumGPUBlocks
242+
targetPod := primaryProfileResult.TargetPods[0] // get the first pod of the primary profile
243+
244+
gpuBlocks := p.config.LRUCapacityPerServer
245+
if p.config.AutoTune && targetPod.GetMetrics().CacheNumGPUBlocks > 0 {
246+
gpuBlocks = targetPod.GetMetrics().CacheNumGPUBlocks
247+
}
253248

254249
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
255250
p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it
@@ -265,16 +260,16 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
265260
p.wg.Add(1)
266261
go func() {
267262
p.indexer.Add(state.PrefixHashes, Server{
268-
ServerID(targetPod.NamespacedName),
263+
ServerID(targetPod.GetPod().NamespacedName),
269264
gpuBlocks,
270265
})
271266
p.wg.Done()
272267
}()
273268

274269
total := len(state.PrefixHashes)
275-
matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
270+
matchLen := state.PrefixCacheServers[ServerID(targetPod.GetPod().NamespacedName)]
276271

277-
blockSize := getBlockSize(primaryProfileResult.TargetPods, p.config.DefaultBlockSize)
272+
blockSize := getBlockSize(primaryProfileResult.TargetPods, p.config)
278273
metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize)
279274
}
280275

@@ -388,9 +383,14 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
388383
return json.Marshal(request.Body.ChatCompletions.Messages)
389384
}
390385

391-
func getBlockSize(pods []types.Pod, defaultBlockSize int) int {
386+
func getBlockSize(pods []types.Pod, config Config) int {
387+
if !config.AutoTune {
388+
return config.BlockSize
389+
}
390+
391+
// Fallback to BlockSize if no metrics are available.
392392
if len(pods) == 0 {
393-
return defaultBlockSize
393+
return config.BlockSize
394394
}
395395

396396
// Since all PODs originate from the same inference pool, they are considered to have identical configurations.
@@ -401,5 +401,5 @@ func getBlockSize(pods []types.Pod, defaultBlockSize int) int {
401401
return cacheBlockSize * averageCharactersPerToken
402402
}
403403
}
404-
return defaultBlockSize
404+
return config.BlockSize
405405
}

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import (
3535

3636
func TestPrefixPluginCompletion(t *testing.T) {
3737
config := Config{
38-
DefaultBlockSize: 4,
38+
BlockSize: 4,
3939
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
4040
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
4141
}
@@ -201,7 +201,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
201201

202202
func TestPrefixPluginChatCompletions(t *testing.T) {
203203
config := Config{
204-
DefaultBlockSize: 4,
204+
BlockSize: 4,
205205
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
206206
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
207207
}
@@ -235,7 +235,7 @@ func TestPrefixPluginChatCompletions(t *testing.T) {
235235

236236
func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
237237
config := Config{
238-
DefaultBlockSize: 8, // Use larger block size for more predictable JSON marshaling
238+
BlockSize: 8, // Use larger block size for more predictable JSON marshaling
239239
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
240240
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
241241
}
@@ -349,7 +349,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
349349
blockSize := 4
350350
maxPrefixBlocks := 50000
351351
config := Config{
352-
DefaultBlockSize: blockSize,
352+
BlockSize: blockSize,
353353
MaxPrefixBlocksToMatch: maxPrefixBlocks,
354354
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
355355
}
@@ -409,7 +409,7 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
409409
{
410410
name: "all zero",
411411
config: Config{
412-
DefaultBlockSize: 0,
412+
BlockSize: 0,
413413
MaxPrefixBlocksToMatch: 0,
414414
LRUCapacityPerServer: 0,
415415
},
@@ -420,7 +420,7 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
420420
{
421421
name: "negative values",
422422
config: Config{
423-
DefaultBlockSize: -5,
423+
BlockSize: -5,
424424
MaxPrefixBlocksToMatch: -10,
425425
LRUCapacityPerServer: -100,
426426
},
@@ -431,7 +431,7 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
431431
{
432432
name: "mixed valid and invalid",
433433
config: Config{
434-
DefaultBlockSize: 32, // valid
434+
BlockSize: 32, // valid
435435
MaxPrefixBlocksToMatch: -1, // invalid
436436
LRUCapacityPerServer: 50000, // valid
437437
},
@@ -442,7 +442,7 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
442442
{
443443
name: "all valid",
444444
config: Config{
445-
DefaultBlockSize: 64,
445+
BlockSize: 64,
446446
MaxPrefixBlocksToMatch: 200,
447447
LRUCapacityPerServer: 30000,
448448
},
@@ -459,7 +459,7 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
459459

460460
assert.NotEmpty(t, plugin)
461461
assert.NotEmpty(t, plugin.indexer)
462-
assert.Equal(t, tt.expectBlock, plugin.config.DefaultBlockSize)
462+
assert.Equal(t, tt.expectBlock, plugin.config.BlockSize)
463463
assert.Equal(t, tt.expectMaxMatch, plugin.config.MaxPrefixBlocksToMatch)
464464
assert.Equal(t, tt.expectCapacity, plugin.config.LRUCapacityPerServer)
465465
})
@@ -481,7 +481,7 @@ func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) {
481481
blockSize := 8
482482
maxPrefixBlocks := 50000
483483
config := Config{
484-
DefaultBlockSize: blockSize,
484+
BlockSize: blockSize,
485485
MaxPrefixBlocksToMatch: maxPrefixBlocks,
486486
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
487487
}

0 commit comments

Comments
 (0)