
Commit 9365b78

guygirvMaroon authored and committed
Completions support
Signed-off-by: Guy Girmonsky <[email protected]>
1 parent 4942be3 commit 9365b78

File tree

14 files changed: +1309 −20 lines changed

Makefile

Lines changed: 20 additions & 0 deletions
```diff
@@ -23,8 +23,12 @@ help: ## Print help
 	@awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m<target>\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST)
 
 ##@ Tokenizer & Linking
+##
 
+
+##
 LDFLAGS ?= -extldflags '-L$(shell pwd)/lib'
+#LDFLAGS ?= -extldflags '-L$(shell pwd)/lib'
 CGO_ENABLED=1
 TOKENIZER_LIB = lib/libtokenizers.a
 TOKENIZER_RELEASE = v1.20.2
@@ -83,6 +87,22 @@ e2e-test: download-tokenizer
 	@printf "\033[33;1m==== Running unit tests ====\033[0m\n"
 	go test -v -ldflags="$(LDFLAGS)" ./tests/...
 
+.PHONY: validate-chat-templates
+validate-chat-templates: ## Run chat template validation tests against vLLM
+	@printf "\033[33;1m==== Running Chat Template Validation Tests ====\033[0m\n"
+	@echo "Running tests for all models..."
+	python3 scripts/run_chat_template_validation.py --model "TroyDoesAI/Llama-3.1-8B-Instruct"
+
+.PHONY: validate-chat-template
+validate-chat-template: ## Run chat template validation for a specific model (usage: make validate-chat-template MODEL=model-name)
+	@printf "\033[33;1m==== Running Chat Template Validation for $(MODEL) ====\033[0m\n"
+	python3 scripts/run_chat_template_validation.py --model "$(MODEL)" --save
+
+.PHONY: validate-chat-template-default
+validate-chat-template-default: ## Run chat template validation for default model
+	@printf "\033[33;1m==== Running Chat Template Validation for Default Model ====\033[0m\n"
+	python3 scripts/run_chat_template_validation.py --model "OpenAI-ChatGPT/ChatGPT-4-Micro"
+
 ##@ Build
 
 .PHONY: build
```

examples/kv_cache_index/main.go

Lines changed: 2 additions & 2 deletions
```diff
@@ -115,7 +115,7 @@ func runPrompts(ctx context.Context, kvCacheIndexer *kvcache.Indexer) error {
 	logger.Info("Started Indexer", "model", modelName)
 
 	// Get pods for the prompt
-	pods, err := kvCacheIndexer.GetPodScores(ctx, testdata.Prompt, modelName, nil)
+	pods, err := kvCacheIndexer.GetPodScores(ctx, testdata.Prompt, modelName, nil, false)
 	if err != nil {
 		return err
 	}
@@ -136,7 +136,7 @@ func runPrompts(ctx context.Context, kvCacheIndexer *kvcache.Indexer) error {
 	time.Sleep(3 * time.Second)
 
 	// Get pods for the prompt
-	pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.Prompt, modelName, nil)
+	pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.Prompt, modelName, nil, false)
 	if err != nil {
 		return err
 	}
```

examples/kv_events/offline/main.go

Lines changed: 3 additions & 3 deletions
```diff
@@ -152,7 +152,7 @@ func runEventsDemo(ctx context.Context, kvCacheIndexer *kvcache.Indexer, publish
 	logger.Info("@@@ Starting KV Events Demo", "model", testdata.ModelName)
 
 	// Initial query - should be empty since no events have been published
-	pods, err := kvCacheIndexer.GetPodScores(ctx, testdata.Prompt, testdata.ModelName, nil)
+	pods, err := kvCacheIndexer.GetPodScores(ctx, testdata.Prompt, testdata.ModelName, nil, false)
 	if err != nil {
 		return err
 	}
@@ -185,7 +185,7 @@ func runEventsDemo(ctx context.Context, kvCacheIndexer *kvcache.Indexer, publish
 	time.Sleep(3 * time.Second)
 
 	// Query again to see the effect of the events
-	pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.Prompt, testdata.ModelName, nil)
+	pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.Prompt, testdata.ModelName, nil, false)
 	if err != nil {
 		return err
 	}
@@ -214,7 +214,7 @@ func runEventsDemo(ctx context.Context, kvCacheIndexer *kvcache.Indexer, publish
 	time.Sleep(3 * time.Second)
 
 	// Final query
-	pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.Prompt, testdata.ModelName, nil)
+	pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.Prompt, testdata.ModelName, nil, false)
 	if err != nil {
 		return err
 	}
```

examples/kv_events/online/main.go

Lines changed: 1 addition & 1 deletion
```diff
@@ -147,7 +147,7 @@ func main() {
 		return
 	}
 
-	pods, err := kvCacheIndexer.GetPodScores(ctx, req.Prompt, modelName, nil)
+	pods, err := kvCacheIndexer.GetPodScores(ctx, req.Prompt, modelName, nil, false)
 	if err != nil {
 		http.Error(w, fmt.Sprintf("error: %v", err), http.StatusInternalServerError)
 		return
```

pkg/kvcache/indexer.go

Lines changed: 47 additions & 1 deletion
```diff
@@ -18,13 +18,15 @@ package kvcache
 
 import (
 	"context"
+	"encoding/json"
 	"fmt"
 
 	"k8s.io/apimachinery/pkg/util/sets"
 	"k8s.io/klog/v2"
 
 	"github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvblock"
 	"github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization"
+	chattemplatego "github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization/chat_template_go"
 	"github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization/prefixstore"
 	"github.com/llm-d/llm-d-kv-cache-manager/pkg/utils/logging"
 )
@@ -115,14 +117,50 @@ func (k *Indexer) KVBlockIndex() kvblock.Index {
 //
 // The function returns a map of pod identifiers to scores.
 func (k *Indexer) GetPodScores(ctx context.Context, prompt, modelName string,
-	podIdentifiers []string,
+	podIdentifiers []string, chatCompletion bool,
 ) (map[string]int, error) {
 	traceLogger := klog.FromContext(ctx).V(logging.TRACE).WithName("kvcache.GetPodScores")
+
+	// Handle chat completion requests
+	if chatCompletion {
+		// Parse the prompt as a ChatTemplateRequest JSON
+		var req chattemplatego.ChatTemplateRequest
+		if err := json.Unmarshal([]byte(prompt), &req); err != nil {
+			return nil, fmt.Errorf("failed to parse chat template request: %w", err)
+		}
+
+		// Create or reuse the CGo wrapper (could be a singleton in production)
+		// TODO: cache, instance management
+		wrapper := chattemplatego.NewChatTemplateCGoWrapper()
+
+		// Fetch the chat template for the model (if not already set)
+		if req.ChatTemplate == "" {
+			getReq := chattemplatego.GetChatTemplateRequest{ModelName: modelName}
+			template, templateVars, err := wrapper.GetModelChatTemplate(getReq)
+			if err != nil {
+				return nil, fmt.Errorf("failed to fetch chat template: %w", err)
+			}
+			req.ChatTemplate = template
+			req.TemplateVars = templateVars
+		}
+
+		// Apply the template to the request
+		resp, err := wrapper.RenderChatTemplate(req)
+		if err != nil {
+			return nil, fmt.Errorf("failed to render chat template: %w", err)
+		}
+		if len(resp.RenderedChats) == 0 {
+			return nil, nil
+		}
+		prompt = resp.RenderedChats[0]
+	}
+
 	// 0. add to tokenizers pool
 	k.tokenizersPool.AddTask(prompt, modelName)
 
 	// 1. get available tokens of longest prefix
 	tokens := k.tokensIndexer.FindLongestContainedTokens(prompt, modelName)
+
 	if len(tokens) == 0 {
 		//nolint:nilnil // no need to return an error
 		return nil, nil
@@ -150,6 +188,14 @@ func (k *Indexer) GetPodScores(ctx context.Context, prompt, modelName string,
 	return podScores, nil
 }
 
+// GetPodScoresDefault is a convenience function for backward compatibility
+// that calls GetPodScores with chatCompletion=false.
+func (k *Indexer) GetPodScoresDefault(ctx context.Context, prompt, modelName string,
+	podIdentifiers []string,
+) (map[string]int, error) {
+	return k.GetPodScores(ctx, prompt, modelName, podIdentifiers, false)
+}
+
 // podsPerKeyPrintHelper formats a map of keys to pod names for printing.
 func podsPerKeyPrintHelper(ks map[kvblock.Key][]string) string {
 	flattened := ""
```
Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@

# HOWTO: Using `GetCompletionsPodScores` for OpenAI-API ChatCompletions Requests with kv-cache-manager

## Overview

`GetCompletionsPodScores` in `indexer.go` enables the kv-cache-manager to support OpenAI-compatible ChatCompletions requests by rendering the full message structure (including tools and documents) into a prompt with a Python Jinja2 template before tokenization and KV block key calculation.

---

## What struct do I need to receive from the router?

You must provide a `chattemplatego.ChatTemplateRequest` with the following fields:

```go
// ChatTemplateRequest represents the request to render a chat template.
type ChatTemplateRequest struct {
	Conversations             [][]ChatMessage        `json:"conversations"`
	Tools                     []interface{}          `json:"tools,omitempty"`
	Documents                 []interface{}          `json:"documents,omitempty"`
	ChatTemplate              string                 `json:"chat_template,omitempty"`
	ReturnAssistantTokensMask bool                   `json:"return_assistant_tokens_mask,omitempty"`
	ContinueFinalMessage      bool                   `json:"continue_final_message,omitempty"`
	AddGenerationPrompt       bool                   `json:"add_generation_prompt,omitempty"`
	TemplateVars              map[string]interface{} `json:"template_vars,omitempty"`
}
```

- **Conversations**: List of message lists (role/content pairs)
- **Tools**: (Optional) List of tool schemas
- **Documents**: (Optional) List of document dicts
- **ChatTemplate**: (Optional) Override for the chat template
- **ReturnAssistantTokensMask**: (Optional) Whether to return assistant token indices
- **ContinueFinalMessage**: (Optional) Whether to continue from the final message
- **AddGenerationPrompt**: (Optional) Whether to add a generation prompt
- **TemplateVars**: (Optional) Special tokens for template rendering

This struct mirrors the OpenAI ChatCompletions request, supporting messages, tools, documents, and advanced template options.
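For concreteness, here is a sketch of populating the struct with one conversation and a tool; the `get_weather` schema is an illustrative OpenAI-style tool definition, not one mandated by the library.

```go
req := chattemplatego.ChatTemplateRequest{
	Conversations: [][]chattemplatego.ChatMessage{{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "What is the weather in Paris?"},
	}},
	// Illustrative OpenAI-style tool schema; it is passed through to the
	// Jinja2 template as-is.
	Tools: []interface{}{
		map[string]interface{}{
			"type": "function",
			"function": map[string]interface{}{
				"name":        "get_weather",
				"description": "Get the current weather for a city",
			},
		},
	},
	AddGenerationPrompt: true,
}
```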
### ChatMessage Struct

The `ChatMessage` struct represents individual messages within conversations:

```go
// ChatMessage represents a single message in a conversation.
type ChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
```

- **Role**: The role of the message sender (e.g., "user", "assistant", "system")
- **Content**: The actual message content/text

**Example usage:**

```go
conversation := []chattemplatego.ChatMessage{
	{Role: "user", Content: "What is the weather in Paris?"},
	{Role: "assistant", Content: "Let me check that for you."},
	{Role: "user", Content: "Thank you!"},
}
```

This structure follows the OpenAI ChatCompletions API format, making it compatible with existing chat-based applications.
---

## How do the three scoring functions differ?

- **`GetPromptPodScores`**:
  Accepts a simple prompt string, tokenizes it, and calculates KV block keys directly.

- **`GetCompletionsPodScores`**:
  Accepts a full `ChatTemplateRequest` (with messages, tools, etc.), uses the Python Jinja2 template (via CGO) to flatten the structure into a prompt, then tokenizes and calculates KV block keys. This ensures the prompt matches what the model would actually see.

- **`GetPodScores`**:
  A unified entry point that dispatches to one of the two paths based on its `chatCompletion` flag:
  - `chatCompletion=false` → the input is treated as a plain prompt string (the `GetPromptPodScores` path)
  - `chatCompletion=true` → the input is parsed as a JSON-serialized `ChatTemplateRequest` (the `GetCompletionsPodScores` path)
  - This provides a single entry point for both simple prompts and complex chat completions; a sketch of the two call styles follows this list.
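As referenced above, a short sketch of the two call styles through the unified entry point, with error handling elided and `indexer`, `ctx`, `modelName`, and `req` assumed to be in scope:

```go
// Plain prompt (chatCompletion=false): the prompt string is tokenized directly.
scores, err := indexer.GetPodScores(ctx, "San Francisco is a", modelName, nil, false)

// ChatCompletions (chatCompletion=true): the prompt argument carries the
// JSON-encoded ChatTemplateRequest, which is rendered before tokenization.
payload, _ := json.Marshal(req) // req is a chattemplatego.ChatTemplateRequest
scores, err = indexer.GetPodScores(ctx, string(payload), modelName, nil, true)
```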
79+
80+
---
81+
82+
## Detailed Flow: `GetCompletionsPodScores` Pipeline
83+
84+
When `indexer.go:GetCompletionsPodScores()` is called, here's the complete flow through files and functions:
85+
86+
```
87+
1. indexer.go:GetCompletionsPodScores(ctx, req, modelName, podIdentifiers)
88+
89+
├── 1.1. **CGO Binding**: chattemplatego.NewChatTemplateCGoWrapper()
90+
│ └── cgo_functions.go:NewChatTemplateCGoWrapper()
91+
│ └── Creates ChatTemplateCGoWrapper struct with initialized=false
92+
93+
├── 1.2. **CGO Binding**: wrapper.GetModelChatTemplate(getReq)
94+
│ ├── cgo_functions.go:GetModelChatTemplate(req)
95+
│ │ ├── Initialize() Python interpreter via CGO
96+
│ │ ├── executePythonCode() - **CGO Binding** to Python
97+
│ │ └── **Python Wrapper**: chat_template_wrapper.py:get_model_chat_template()
98+
│ │ └── Uses Hugging Face AutoTokenizer to fetch model template
99+
│ └── Returns: (template, template_vars)
100+
101+
├── 1.3. **CGO Binding**: wrapper.RenderChatTemplate(req)
102+
│ ├── cgo_functions.go:RenderChatTemplate(req)
103+
│ │ ├── Initialize() Python interpreter via CGO (if not already done)
104+
│ │ ├── executePythonCode() - **CGO Binding** to Python
105+
│ │ └── **Python Wrapper**: chat_template_wrapper.py:render_jinja_template()
106+
│ │ ├── _compile_jinja_template() - Compiles Jinja2 template
107+
│ │ ├── AssistantTracker class - Tracks assistant token indices
108+
│ │ └── Returns: (rendered_chats, generation_indices)
109+
│ └── Returns: ChatTemplateResponse
110+
111+
├── 1.4. Extract prompt from response
112+
│ └── prompt := resp.RenderedChats[0]
113+
114+
├── 1.5. **Tokenization**: k.tokenizersPool.AddTask(prompt, modelName)
115+
│ └── tokenization/pool.go:AddTask() - Queues tokenization task
116+
117+
├── 1.6. **Prefix Store**: k.tokensIndexer.FindLongestContainedTokens(prompt, modelName)
118+
│ └── prefixstore/lru-store.go:FindLongestContainedTokens() - Finds cached tokens
119+
120+
├── 1.7. **Token Processing**: k.tokensProcessor.TokensToKVBlockKeys(tokens, modelName)
121+
│ └── kv-cache/token-processor.go:TokensToKVBlockKeys() - Converts tokens to block keys
122+
123+
├── 1.8. **KV Block Indexing**: k.kvBlockIndexer.GetPodsForKeys(ctx, blockKeys, podSet)
124+
│ └── kv-cache/kvblock-indexer.go:GetPodsForKeys() - Queries Redis for pod mappings
125+
126+
└── 1.9. **Scoring**: k.kvBlockScorer.Score(strBlockKeys, keyToPods)
127+
└── kv-cache/kvblock-scorer.go:Score() - Calculates pod scores
128+
```
129+
130+
### Key Components in the Pipeline:
131+
132+
**🔗 CGO Bindings** (Go → Python):
133+
- `cgo_functions.go` - Provides the bridge between Go and Python
134+
- Uses Python's C API via CGO to call Python functions directly
135+
- Manages Python interpreter lifecycle (Initialize/Finalize)
136+
137+
**📦 Python Wrapper** (Python → Hugging Face):
138+
- `chat_template_wrapper.py` - Wraps Hugging Face's complex template system
139+
- Provides clean API for template rendering and model template fetching
140+
- Handles Jinja2 compilation, assistant tracking, and error handling
141+
142+
**🔄 Data Flow**:
143+
1. **Input**: `ChatTemplateRequest` (messages, tools, documents)
144+
2. **Template Fetching**: Model-specific chat template from Hugging Face
145+
3. **Template Rendering**: Jinja2 template processing with tools/documents
146+
4. **Tokenization**: Convert rendered prompt to tokens
147+
5. **KV Cache Lookup**: Find cached token blocks and associated pods
148+
6. **Scoring**: Calculate pod scores based on cache hits
149+
150+
This pipeline ensures that chat completion requests are properly templated, tokenized, and scored against the KV cache, providing accurate pod recommendations for efficient request routing.
151+
152+
---
153+
154+
## Summary
155+
156+
- The router should send a `ChatTemplateRequest` (not just a prompt string) to the indexer.
157+
- `GetCompletionsPodScores` will handle template rendering and tokenization internally, ensuring correct KV block key calculation for all supported models.
158+
- The integration uses a CGO bridge (`cgo_functions.go`) to call Python (`chat_template_wrapper.py`) for template rendering, matching vLLM and OpenAI API behavior.
