Skip to content

Commit 743c103

Browse files
committed
Update precise-prefix-cache-scorer to latest llm-d-kv-cache
The new API separates tokenization from scoring, requiring explicit token processor initialization and a two-step flow: tokenize first, then get pod scores.

Signed-off-by: Antonio Cardace <acardace@redhat.com>
1 parent fa48006 commit 743c103

File tree

1 file changed

+89
-23
lines changed

1 file changed

+89
-23
lines changed

pkg/plugins/scorer/precise_prefix_cache.go

Lines changed: 89 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@ package scorer
33
import (
44
"context"
55
"encoding/json"
6-
"errors"
76
"fmt"
87
"os"
98
"time"
@@ -39,6 +38,9 @@ type PrecisePrefixCachePluginConfig struct {
3938
// used to subscribe to KV-cache events and update the internal KV-cache
4039
// index state.
4140
KVEventsConfig *kvevents.Config `json:"kvEventsConfig"`
41+
// TokenProcessorConfig holds the configuration for the token processor
42+
// used to convert tokens to KV block keys.
43+
TokenProcessorConfig *kvblock.TokenProcessorConfig `json:"tokenProcessorConfig"`
4244
}
4345

4446
// compile-time type assertion
@@ -54,8 +56,9 @@ func PrecisePrefixCachePluginFactory(name string, rawParameters json.RawMessage,
5456
}
5557

5658
parameters := PrecisePrefixCachePluginConfig{
57-
IndexerConfig: indexerConfig,
58-
KVEventsConfig: kvevents.DefaultConfig(),
59+
IndexerConfig: indexerConfig,
60+
KVEventsConfig: kvevents.DefaultConfig(),
61+
TokenProcessorConfig: kvblock.DefaultTokenProcessorConfig(),
5962
}
6063

6164
if rawParameters != nil {
@@ -96,10 +99,7 @@ func PrecisePrefixCachePluginFactory(name string, rawParameters json.RawMessage,
9699
// If the configuration is invalid or if the indexer fails to initialize,
97100
// an error is returned.
98101
func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePrefixCacheScorer, error) {
99-
if config.TokenProcessorConfig == nil {
100-
config.TokenProcessorConfig = kvblock.DefaultTokenProcessorConfig()
101-
}
102-
102+
// initialize the token processor
103103
tokenProcessor := kvblock.NewChunkedTokenDatabase(config.TokenProcessorConfig)
104104

105105
// initialize the indexer
@@ -110,9 +110,8 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr
110110

111111
go kvCacheIndexer.Run(ctx)
112112

113-
// initialize the KV-events pool
114-
pool := kvevents.NewPool(config.KVEventsConfig, kvCacheIndexer.KVBlockIndex(), tokenProcessor)
115-
pool.Start(ctx)
113+
// initialize and start the KV-events pool
114+
kvevents.NewPool(config.KVEventsConfig, kvCacheIndexer.KVBlockIndex(), tokenProcessor).Start(ctx)
116115

117116
subscribersManager := kvevents.NewSubscriberManager(pool)
118117
var subscribersCache *ttlcache.Cache[string, struct{}]
@@ -180,6 +179,59 @@ func (s *PrecisePrefixCacheScorer) WithName(name string) *PrecisePrefixCacheScor
180179
return s
181180
}
182181

182+
func (s *PrecisePrefixCacheScorer) buildPrompt(ctx context.Context, request *types.LLMRequest) (string, *preprocessing.ApplyChatTemplateRequest) {
183+
logger := log.FromContext(ctx).WithName(s.typedName.String())
184+
traceLogger := logger.V(logutil.TRACE)
185+
186+
traceLogger.Info("Getting scores",
187+
"isChatCompletions", request.Body != nil && request.Body.ChatCompletions != nil,
188+
"isCompletions", request.Body != nil && request.Body.Completions != nil)
189+
190+
// The upstream parser guarantees exactly one body is populated, but we defensively prioritize chat completions.
191+
// If an unexpected dual payload slips through (parser regression/new client), log it and use chat semantics.
192+
if request.Body != nil && request.Body.ChatCompletions != nil {
193+
if request.Body.Completions != nil {
194+
traceLogger.Info("Both chat/completions and completions present; defaulting to chat/completions")
195+
}
196+
197+
// Convert messages to the format expected by the renderer
198+
conversation := make([]preprocessing.Conversation, len(request.Body.ChatCompletions.Messages))
199+
for i, msg := range request.Body.ChatCompletions.Messages {
200+
conversation[i] = preprocessing.Conversation{
201+
Role: msg.Role,
202+
Content: msg.Content.Raw,
203+
}
204+
}
205+
206+
renderReq := &preprocessing.ApplyChatTemplateRequest{
207+
Conversation: [][]preprocessing.Conversation{conversation},
208+
Tools: request.Body.ChatCompletions.Tools,
209+
Documents: request.Body.ChatCompletions.Documents,
210+
ChatTemplate: request.Body.ChatCompletions.ChatTemplate,
211+
ReturnAssistantTokensMask: request.Body.ChatCompletions.ReturnAssistantTokensMask,
212+
ContinueFinalMessage: request.Body.ChatCompletions.ContinueFinalMessage,
213+
AddGenerationPrompt: request.Body.ChatCompletions.AddGenerationPrompt,
214+
ChatTemplateKWArgs: request.Body.ChatCompletions.ChatTemplateKWArgs,
215+
}
216+
217+
traceLogger.Info("Processing chat completion request",
218+
"messagesCount", len(conversation),
219+
"toolsCount", len(renderReq.Tools),
220+
"documentsCount", len(renderReq.Documents))
221+
222+
return "", renderReq
223+
}
224+
225+
// For regular completions, use the prompt directly
226+
if request.Body != nil && request.Body.Completions != nil {
227+
traceLogger.Info("Using completion prompt directly", "promptLength", len(request.Body.Completions.Prompt))
228+
return request.Body.Completions.Prompt, nil
229+
}
230+
231+
traceLogger.Error(fmt.Errorf("Both chat and completions are empty"), "error building prompt")
232+
return "", nil
233+
}
234+
183235
// Score scores the provided pod based on the KVCache index state.
184236
// The returned scores are normalized to a range of 0-1.
185237
func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
@@ -211,11 +263,24 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *types.
211263
return nil
212264
}
213265

214-
scores, err := s.getScores(ctx, request)
266+
prompt, renderReq := s.buildPrompt(ctx, request)
267+
if prompt == "" && renderReq == nil {
268+
logger.V(logutil.DEFAULT).Info("No valid prompt, skipping scoring")
269+
return nil
270+
}
271+
272+
tokens, err := s.kvCacheIndexer.Tokenize(renderReq, prompt)
273+
if err != nil {
274+
logger.Error(err, "Failed to tokenize prompt")
275+
return nil
276+
}
277+
278+
scores, err := s.kvCacheIndexer.GetPodScores(ctx, tokens, request.TargetModel, nil)
215279
if err != nil {
216280
logger.Error(err, "Failed to get pod scores")
217281
return nil
218282
}
283+
219284
debugLogger.Info("Got pod scores", "scores", scores)
220285

221286
podToKey := func(pod types.Pod) (string, bool) {
@@ -242,6 +307,7 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *types.
242307

243308
return indexedScoresToNormalizedScoredPods(pods, podToKey, scores)
244309
}
310+
==== BASE ====
245311

246312
// getScores retrieves the pod scores from the KV-cache indexer
247313
// based on the provided LLM request.
@@ -262,17 +328,8 @@ func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types
262328
traceLogger.Info("Both chat/completions and completions present; defaulting to chat/completions")
263329
}
264330

265-
// Convert messages to conversation format
266-
conversations := make([]preprocessing.Conversation, len(request.Body.ChatCompletions.Messages))
267-
for i, msg := range request.Body.ChatCompletions.Messages {
268-
conversations[i] = preprocessing.Conversation{
269-
Role: msg.Role,
270-
Content: msg.Content.Raw,
271-
}
272-
}
273-
274-
renderReq := &preprocessing.ApplyChatTemplateRequest{
275-
Conversation: [][]preprocessing.Conversation{conversations},
331+
renderReq := &preprocessing.RenderJinjaTemplateRequest{
332+
Conversations: make([]preprocessing.ChatMessage, 0),
276333
Tools: request.Body.ChatCompletions.Tools,
277334
Documents: request.Body.ChatCompletions.Documents,
278335
ChatTemplate: request.Body.ChatCompletions.ChatTemplate,
@@ -282,8 +339,16 @@ func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types
282339
ChatTemplateKWArgs: request.Body.ChatCompletions.ChatTemplateKWArgs,
283340
}
284341

342+
// Convert messages to the format expected by the renderer
343+
for _, msg := range request.Body.ChatCompletions.Messages {
344+
renderReq.Conversations = append(renderReq.Conversations, preprocessing.ChatMessage{
345+
Role: msg.Role,
346+
Content: msg.Content.Raw,
347+
})
348+
}
349+
285350
traceLogger.Info("Processing chat completion request",
286-
"messagesCount", len(conversations),
351+
"messagesCount", len(renderReq.Conversations),
287352
"toolsCount", len(renderReq.Tools),
288353
"documentsCount", len(renderReq.Documents))
289354

@@ -308,3 +373,4 @@ func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types
308373

309374
return nil, errors.New("no valid input found in request")
310375
}
376+
==== BASE ====

0 commit comments

Comments (0)