diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
index d30bc13f5..2947785d2 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
@@ -67,15 +67,19 @@ const (
 )
 
 var DefaultConfig = Config{
-	DefaultBlockSize:       DefaultBlockSize,
+	AutoTune:               true,
+	BlockSize:              DefaultBlockSize,
 	MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
 	LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
 }
 
 type Config struct {
+	// If set to true, the plugin will automatically adjust the configuration based on various
+	// metrics from the model servers.
+	AutoTune bool `json:"autoTune"`
 	// The input prompt is broken into sizes of BlockSize to calculate block hashes . Requests
 	// with length shorter than the block size will be ignored.
-	DefaultBlockSize int `json:"blockSize"`
+	BlockSize int `json:"blockSize"`
 	// MaxPrefixBlocksToMatch is the maximum number of prefix blocks to match. Input beyond this limit will
 	// be ignored.
 	MaxPrefixBlocksToMatch int `json:"maxPrefixBlocksToMatch"`
@@ -148,11 +152,7 @@ var (
 
 // PrefixCachePluginFactory defines the factory function for Prefix plugin.
 func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
-	parameters := Config{
-		DefaultBlockSize:       DefaultBlockSize,
-		MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
-		LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
-	}
+	parameters := DefaultConfig
 
 	if rawParameters != nil {
 		if err := json.Unmarshal(rawParameters, &parameters); err != nil {
@@ -167,40 +167,32 @@ func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle
 
 // New initializes a new prefix Plugin and returns its pointer.
 func New(ctx context.Context, config Config) *Plugin {
-	capacity := config.LRUCapacityPerServer
-	if capacity <= 0 {
-		capacity = DefaultLRUCapacityPerServer
+	if config.LRUCapacityPerServer <= 0 {
+		config.LRUCapacityPerServer = DefaultLRUCapacityPerServer
 		log.FromContext(ctx).V(logutil.DEFAULT).Info(
 			"LRUCapacityPerServer is not positive, using default value",
 			"defaultCapacity", DefaultLRUCapacityPerServer,
 		)
 	}
 
-	blockSize := config.DefaultBlockSize
-	if blockSize <= 0 {
-		blockSize = DefaultBlockSize
-		log.FromContext(ctx).V(logutil.DEFAULT).Info("DefaultBlockSize is not positive, using default value",
+	if config.BlockSize <= 0 {
+		config.BlockSize = DefaultBlockSize
+		log.FromContext(ctx).V(logutil.DEFAULT).Info("BlockSize is not positive, using default value",
 			"default", DefaultBlockSize)
 	}
 
-	maxPrefixBlocks := config.MaxPrefixBlocksToMatch
-	if maxPrefixBlocks <= 0 {
-		maxPrefixBlocks = DefaultMaxPrefixBlocks
+	if config.MaxPrefixBlocksToMatch <= 0 {
+		config.MaxPrefixBlocksToMatch = DefaultMaxPrefixBlocks
 		log.FromContext(ctx).V(logutil.DEFAULT).Info("MaxPrefixBlocksToMatch is not positive, using default value",
 			"default", DefaultMaxPrefixBlocks)
 	}
 
-	validConfig := Config{
-		DefaultBlockSize:       blockSize,
-		MaxPrefixBlocksToMatch: maxPrefixBlocks,
-		LRUCapacityPerServer:   capacity,
-	}
-
+	log.FromContext(ctx).V(logutil.DEFAULT).Info("PrefixCachePlugin initialized", "config", config)
 	return &Plugin{
 		typedName:   plugins.TypedName{Type: PrefixCachePluginType, Name: PrefixCachePluginType},
-		config:      validConfig,
+		config:      config,
 		pluginState: plugins.NewPluginState(ctx),
-		indexer:     newIndexer(ctx, capacity),
+		indexer:     newIndexer(ctx, config.LRUCapacityPerServer),
 	}
 }
 
@@ -218,7 +210,7 @@ func (p *Plugin) WithName(name string) *Plugin {
 // Score returns the scoring result for the given list of pods based on context.
 func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
 	// pre score step, hashing prompt and find longest prefix match.
-	hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
+	hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config), p.config.MaxPrefixBlocksToMatch)
 	state := &SchedulingContextState{
 		PrefixHashes:       hashes,
 		PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
@@ -248,8 +240,12 @@ func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
 
 // PreRequest records in the plugin cache the result of the scheduling selection.
 func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) {
 	primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
-	targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile
-	gpuBlocks := primaryProfileResult.TargetPods[0].GetMetrics().CacheNumGPUBlocks
+	targetPod := primaryProfileResult.TargetPods[0] // get the first pod of the primary profile
+
+	gpuBlocks := p.config.LRUCapacityPerServer
+	if p.config.AutoTune && targetPod.GetMetrics().CacheNumGPUBlocks > 0 {
+		gpuBlocks = targetPod.GetMetrics().CacheNumGPUBlocks
+	}
 
 	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
 	p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it
@@ -265,16 +261,16 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
 	p.wg.Add(1)
 	go func() {
 		p.indexer.Add(state.PrefixHashes, Server{
-			ServerID(targetPod.NamespacedName),
+			ServerID(targetPod.GetPod().NamespacedName),
 			gpuBlocks,
 		})
 		p.wg.Done()
 	}()
 
 	total := len(state.PrefixHashes)
-	matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
+	matchLen := state.PrefixCacheServers[ServerID(targetPod.GetPod().NamespacedName)]
 
-	blockSize := getBlockSize(primaryProfileResult.TargetPods, p.config.DefaultBlockSize)
+	blockSize := getBlockSize(primaryProfileResult.TargetPods, p.config)
 	metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize)
 }
 
@@ -388,9 +384,14 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
 	return json.Marshal(request.Body.ChatCompletions.Messages)
 }
 
-func getBlockSize(pods []types.Pod, defaultBlockSize int) int {
+func getBlockSize(pods []types.Pod, config Config) int {
+	if !config.AutoTune {
+		return config.BlockSize
+	}
+
+	// Fallback to BlockSize if no metrics are available.
 	if len(pods) == 0 {
-		return defaultBlockSize
+		return config.BlockSize
 	}
 	// Since all PODs originate from the same inference pool, they are considered to have identical configurations.
@@ -401,5 +402,5 @@ func getBlockSize(pods []types.Pod, defaultBlockSize int) int {
 			return cacheBlockSize * averageCharactersPerToken
 		}
 	}
-	return defaultBlockSize
+	return config.BlockSize
 }
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
index 1bb0794de..f0feeef68 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
@@ -35,7 +35,7 @@ import (
 
 func TestPrefixPluginCompletion(t *testing.T) {
 	config := Config{
-		DefaultBlockSize:       4,
+		BlockSize:              4,
 		MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
 		LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
 	}
@@ -201,7 +201,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
 
 func TestPrefixPluginChatCompletions(t *testing.T) {
 	config := Config{
-		DefaultBlockSize:       4,
+		BlockSize:              4,
 		MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
 		LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
 	}
@@ -235,7 +235,7 @@ func TestPrefixPluginChatCompletions(t *testing.T) {
 
 func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 	config := Config{
-		DefaultBlockSize:       8, // Use larger block size for more predictable JSON marshaling
+		BlockSize:              8, // Use larger block size for more predictable JSON marshaling
 		MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
 		LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
 	}
@@ -349,7 +349,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
 	blockSize := 4
 	maxPrefixBlocks := 50000
 	config := Config{
-		DefaultBlockSize:       blockSize,
+		BlockSize:              blockSize,
 		MaxPrefixBlocksToMatch: maxPrefixBlocks,
 		LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
 	}
@@ -409,7 +409,7 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
 		{
 			name: "all zero",
 			config: Config{
-				DefaultBlockSize:       0,
+				BlockSize:              0,
 				MaxPrefixBlocksToMatch: 0,
 				LRUCapacityPerServer:   0,
 			},
@@ -420,7 +420,7 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
 		{
 			name: "negative values",
 			config: Config{
-				DefaultBlockSize:       -5,
+				BlockSize:              -5,
 				MaxPrefixBlocksToMatch: -10,
 				LRUCapacityPerServer:   -100,
 			},
@@ -431,7 +431,7 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
 		{
 			name: "mixed valid and invalid",
 			config: Config{
-				DefaultBlockSize:       32,    // valid
+				BlockSize:              32,    // valid
 				MaxPrefixBlocksToMatch: -1,    // invalid
 				LRUCapacityPerServer:   50000, // valid
 			},
@@ -442,7 +442,7 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
 		{
 			name: "all valid",
 			config: Config{
-				DefaultBlockSize:       64,
+				BlockSize:              64,
 				MaxPrefixBlocksToMatch: 200,
 				LRUCapacityPerServer:   30000,
 			},
@@ -459,13 +459,108 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
 
 			assert.NotEmpty(t, plugin)
 			assert.NotEmpty(t, plugin.indexer)
-			assert.Equal(t, tt.expectBlock, plugin.config.DefaultBlockSize)
+			assert.Equal(t, tt.expectBlock, plugin.config.BlockSize)
 			assert.Equal(t, tt.expectMaxMatch, plugin.config.MaxPrefixBlocksToMatch)
 			assert.Equal(t, tt.expectCapacity, plugin.config.LRUCapacityPerServer)
 		})
 	}
 }
 
+func TestPrefixPluginAutoTune(t *testing.T) {
+	// Setup common test data
+	podName := "pod-autotune"
+	pod := &types.PodMetrics{
+		Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: podName}},
+		MetricsState: &backendmetrics.MetricsState{
+			CacheBlockSize:    16,   // 16 tokens * 4 chars/token = 64 chars per block
+			CacheNumGPUBlocks: 1000, // 1000 blocks capacity
+		},
+	}
+	pods := []types.Pod{pod}
+
+	req := &types.LLMRequest{
+		RequestId:   uuid.NewString(),
+		TargetModel: "test-model",
+		Body: &types.LLMRequestBody{
+			Completions: &types.CompletionsRequest{
+				// Length 128 chars.
+				// If AutoTune=true (block size 64): 2 blocks
+				// If AutoTune=false (block size 32): 4 blocks
+				Prompt: strings.Repeat("a", 128),
+			},
+		},
+	}
+
+	t.Run("AutoTune Enabled", func(t *testing.T) {
+		config := Config{
+			AutoTune:               true,
+			BlockSize:              32, // Should be ignored in favor of pod metrics (64)
+			MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
+			// Should be ignored in favor of pod metrics (1000)
+			LRUCapacityPerServer: 1,
+		}
+		plugin := New(context.Background(), config)
+
+		// 1. Verify Score uses pod metrics for block size
+		scores := plugin.Score(context.Background(), types.NewCycleState(), req, pods)
+		_ = scores
+
+		state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, plugins.StateKey(plugin.TypedName().String()))
+		assert.NoError(t, err)
+		// Block size from pod is 16 tokens * 4 = 64 chars.
+		// Prompt is 128 chars.
+		// Expected blocks: 128/64 = 2 hashes (model hash is used as seed but not returned as a block)
+		assert.Equal(t, 2, len(state.PrefixHashes), "Should use pod block size (64 chars) -> 2 body blocks")
+
+		// 2. Verify PreRequest uses pod metrics for capacity
+		schedulingResult := &types.SchedulingResult{
+			PrimaryProfileName: "default",
+			ProfileResults: map[string]*types.ProfileRunResult{
+				"default": {TargetPods: []types.Pod{pod}},
+			},
+		}
+		plugin.PreRequest(context.Background(), req, schedulingResult)
+		plugin.wg.Wait()
+
+		// Check indexer state
+		assert.Contains(t, plugin.indexer.Pods(), ServerID(pod.GetPod().NamespacedName))
+	})
+
+	t.Run("AutoTune Disabled", func(t *testing.T) {
+		config := Config{
+			AutoTune:               false,
+			BlockSize:              32, // Should be used (32 chars)
+			MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
+			LRUCapacityPerServer:   1, // Should be used, and the first hash should be evicted due to the small capacity
+		}
+		plugin := New(context.Background(), config)
+
+		// 1. Verify Score uses config BlockSize
+		req.RequestId = uuid.NewString() // New request ID
+		scores := plugin.Score(context.Background(), types.NewCycleState(), req, pods)
+		_ = scores
+
+		state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, plugins.StateKey(plugin.TypedName().String()))
+		assert.NoError(t, err)
+		// Block size from config is 32 chars.
+		// Prompt is 128 chars.
+		// 128 / 32 = 4 chunks.
+		assert.Equal(t, 4, len(state.PrefixHashes), "Should use config block size (32 chars) -> 4 body blocks")
+
+		// 2. Verify PreRequest uses config LRUCapacityPerServer
+		schedulingResult := &types.SchedulingResult{
+			PrimaryProfileName: "default",
+			ProfileResults: map[string]*types.ProfileRunResult{
+				"default": {TargetPods: []types.Pod{pod}},
+			},
+		}
+		plugin.PreRequest(context.Background(), req, schedulingResult)
+		plugin.wg.Wait()
+
+		assert.Contains(t, plugin.indexer.Pods(), ServerID(pod.GetPod().NamespacedName))
+	})
+}
+
 // randomPrompt generates a pseudo-random string of length n using lowercase letters.
 func randomPrompt(n int) string {
 	runes := []rune("abcdefghijklmnopqrstuvwxyz")
@@ -481,7 +576,7 @@ func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) {
 	blockSize := 8
 	maxPrefixBlocks := 50000
 	config := Config{
-		DefaultBlockSize:       blockSize,
+		BlockSize:              blockSize,
 		MaxPrefixBlocksToMatch: maxPrefixBlocks,
 		LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
 	}
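
Note on the factory change: because PrefixCachePluginFactory now starts from DefaultConfig, which sets AutoTune to true, auto-tuning stays enabled for JSON-configured plugins unless the parameters explicitly set "autoTune": false. Below is a minimal, self-contained sketch of that unmarshal-over-defaults behavior; the local struct only mirrors the json tags shown in this diff, and the struct name and default values are illustrative, not the package's exported API.

package main

import (
	"encoding/json"
	"fmt"
)

// config mirrors the AutoTune and BlockSize fields and their json tags from the diff (illustrative only).
type config struct {
	AutoTune  bool `json:"autoTune"`
	BlockSize int  `json:"blockSize"`
}

func main() {
	// Stand-in for DefaultConfig: AutoTune on, arbitrary illustrative block size.
	defaults := config{AutoTune: true, BlockSize: 64}

	// Mirrors `parameters := DefaultConfig` followed by json.Unmarshal(rawParameters, &parameters):
	// keys absent from the raw JSON keep their default values.
	cfg := defaults
	if err := json.Unmarshal([]byte(`{"blockSize": 32}`), &cfg); err != nil {
		panic(err)
	}
	fmt.Println(cfg.AutoTune, cfg.BlockSize) // true 32: AutoTune stays enabled when omitted

	cfg = defaults
	if err := json.Unmarshal([]byte(`{"autoTune": false, "blockSize": 32}`), &cfg); err != nil {
		panic(err)
	}
	fmt.Println(cfg.AutoTune, cfg.BlockSize) // false 32: explicit opt-out
}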