Commit 31ed61a

Support /tokenize endpoint
Signed-off-by: Ira <[email protected]>
1 parent 50bbec0 · commit 31ed61a

File tree

pkg/kv-cache/kv_cache.go
pkg/llm-d-inference-sim/metrics_test.go
pkg/llm-d-inference-sim/simulator.go
pkg/llm-d-inference-sim/simulator_test.go
pkg/vllm-api/tokenize.go
pkg/vllm-api/vllm-models.go

6 files changed: +206 -14 lines changed

pkg/kv-cache/kv_cache.go

Lines changed: 2 additions & 9 deletions
@@ -35,22 +35,15 @@ type KVCacheHelper struct {
     blockSize int
 }
 
-func NewKVCacheHelper(config *common.Configuration, logger logr.Logger, usageChan chan float64) (*KVCacheHelper, error) {
+func NewKVCacheHelper(config *common.Configuration, logger logr.Logger, usageChan chan float64,
+    tokenizer tokenization.Tokenizer) (*KVCacheHelper, error) {
     tokenProcConfig := kvblock.DefaultTokenProcessorConfig()
     tokenProcConfig.BlockSize = config.TokenBlockSize
     if config.HashSeed != "" {
         tokenProcConfig.HashSeed = config.HashSeed
     }
     tokensProcessor := kvblock.NewChunkedTokenDatabase(tokenProcConfig)
 
-    tokenizationConfig := tokenization.DefaultConfig()
-    if config.TokenizersCacheDir != "" {
-        tokenizationConfig.TokenizersCacheDir = config.TokenizersCacheDir
-    }
-    tokenizer, err := tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig)
-    if err != nil {
-        return nil, fmt.Errorf("failed to create tokenizer: %w", err)
-    }
     blockCache, err := newBlockCache(config, logger, usageChan)
     if err != nil {
         return nil, fmt.Errorf("failed to create block cache: %w", err)

pkg/llm-d-inference-sim/metrics_test.go

Lines changed: 1 addition & 1 deletion
@@ -385,7 +385,7 @@ var _ = Describe("Simulator metrics", Ordered, func() {
         Expect(metrics).To(ContainSubstring("vllm:num_requests_waiting{model_name=\"Qwen/Qwen2-0.5B\"} 0"))
         Expect(metrics).To(ContainSubstring("vllm:gpu_cache_usage_perc{model_name=\"Qwen/Qwen2-0.5B\"} 0.125"))
 
-        time.Sleep(3 * time.Second)
+        time.Sleep(4 * time.Second)
         metricsResp, err = client.Get(metricsUrl)
         Expect(err).NotTo(HaveOccurred())
         Expect(metricsResp.StatusCode).To(Equal(http.StatusOK))

pkg/llm-d-inference-sim/simulator.go

Lines changed: 67 additions & 1 deletion
@@ -40,6 +40,7 @@ import (
     kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
     openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
     vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api"
+    "github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization"
 )
 
 const (
@@ -117,6 +118,8 @@ type VllmSimulator struct {
     namespace string
     // pod name of simulator
     pod string
+    // tokenizer is currently used in kv-cache and in /tokenize
+    tokenizer tokenization.Tokenizer
 }
 
 // New creates a new VllmSimulator instance with the given logger
@@ -197,8 +200,17 @@ func (s *VllmSimulator) startSim(ctx context.Context) error {
         return err
     }
 
+    tokenizationConfig := tokenization.DefaultConfig()
+    if s.config.TokenizersCacheDir != "" {
+        tokenizationConfig.TokenizersCacheDir = s.config.TokenizersCacheDir
+    }
+    s.tokenizer, err = tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig)
+    if err != nil {
+        return fmt.Errorf("failed to create tokenizer: %w", err)
+    }
+
     if s.config.EnableKVCache {
-        s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan)
+        s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan, s.tokenizer)
         if err != nil {
             return err
         }
@@ -248,6 +260,7 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener)
     // supports standard Kubernetes health and readiness checks
     r.GET("/health", s.HandleHealth)
     r.GET("/ready", s.HandleReady)
+    r.POST("/tokenize", s.HandleTokenize)
 
     server := fasthttp.Server{
         ErrorHandler: s.HandleError,
@@ -339,6 +352,59 @@ func (s *VllmSimulator) HandleTextCompletions(ctx *fasthttp.RequestCtx) {
     s.handleCompletions(ctx, false)
 }
 
+// readTokenizeRequest reads and parses data from the body of the given request
+func (s *VllmSimulator) readTokenizeRequest(ctx *fasthttp.RequestCtx) (*vllmapi.TokenizeRequest, error) {
+    var tokenizeReq vllmapi.TokenizeRequest
+    if err := json.Unmarshal(ctx.Request.Body(), &tokenizeReq); err != nil {
+        s.logger.Error(err, "failed to unmarshal tokenize request body")
+        return nil, err
+    }
+    return &tokenizeReq, nil
+}
+
+// HandleTokenize http handler for /tokenize
+func (s *VllmSimulator) HandleTokenize(ctx *fasthttp.RequestCtx) {
+    s.logger.Info("tokenize request received")
+    req, err := s.readTokenizeRequest(ctx)
+    if err != nil {
+        s.logger.Error(err, "failed to read and parse tokenize request body")
+        ctx.Error("Failed to read and parse tokenize request body, "+err.Error(), fasthttp.StatusBadRequest)
+        return
+    }
+
+    // Check that the request has only one input to tokenize
+    if req.Prompt != "" && req.Messages != nil {
+        s.sendCompletionError(ctx, openaiserverapi.NewCompletionError("both prompt and messages fields in tokenize request",
+            fasthttp.StatusBadRequest, nil), false)
+        return
+    }
+    // Model is optional, if not set, the model from the configuration will be used
+    model := req.Model
+    if model == "" {
+        model = s.config.Model
+    }
+
+    tokens, _, err := s.tokenizer.Encode(req.GetPrompt(), model)
+    if err != nil {
+        s.logger.Error(err, "failed to tokenize")
+        ctx.Error("Failed to tokenize, "+err.Error(), fasthttp.StatusInternalServerError)
+        return
+    }
+    resp := vllmapi.TokenizeResponse{
+        Count:       len(tokens),
+        Tokens:      tokens,
+        MaxModelLen: s.config.MaxModelLen,
+    }
+    data, err := json.Marshal(resp)
+    if err != nil {
+        ctx.Error("Response body creation failed, "+err.Error(), fasthttp.StatusInternalServerError)
+        return
+    }
+    ctx.Response.Header.SetContentType("application/json")
+    ctx.Response.Header.SetStatusCode(fasthttp.StatusOK)
+    ctx.Response.SetBody(data)
+}
+
 func (s *VllmSimulator) HandleLoadLora(ctx *fasthttp.RequestCtx) {
     s.logger.Info("load lora request received")
     s.loadLora(ctx)
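
For orientation, the handler above follows vLLM's /tokenize contract: the JSON body carries either prompt or messages plus an optional model, and the reply holds the token IDs, their count, and the configured maximum model length. A minimal Go client sketch, assuming the simulator is listening on localhost:8000 (the port, prompt, and printed fields are illustrative, not part of this commit):

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "net/http"
)

// tokenizeResponse mirrors the fields returned by the simulator's /tokenize endpoint.
type tokenizeResponse struct {
    MaxModelLen int      `json:"max_model_len"`
    Count       int      `json:"count"`
    Tokens      []uint32 `json:"tokens"`
}

func main() {
    // Either "prompt" or "messages" may be set, never both.
    reqBody := []byte(`{"model": "Qwen/Qwen2-0.5B", "prompt": "This is a test"}`)
    resp, err := http.Post("http://localhost:8000/tokenize", "application/json", bytes.NewReader(reqBody))
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    var out tokenizeResponse
    if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
        panic(err)
    }
    fmt.Printf("count=%d tokens=%v max_model_len=%d\n", out.Count, out.Tokens, out.MaxModelLen)
}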

pkg/llm-d-inference-sim/simulator_test.go

Lines changed: 75 additions & 1 deletion
@@ -18,6 +18,7 @@ package llmdinferencesim
 
 import (
     "context"
+    "encoding/json"
     "errors"
     "fmt"
     "io"
@@ -29,6 +30,8 @@ import (
 
     "github.com/llm-d/llm-d-inference-sim/pkg/common"
     kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
+    vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api"
+    "github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization"
     . "github.com/onsi/ginkgo/v2"
     . "github.com/onsi/gomega"
     "github.com/openai/openai-go"
@@ -97,8 +100,17 @@ func startServerWithArgs(ctx context.Context, mode string, args []string, envs m
         return nil, err
     }
 
+    tokenizationConfig := tokenization.DefaultConfig()
+    if s.config.TokenizersCacheDir != "" {
+        tokenizationConfig.TokenizersCacheDir = s.config.TokenizersCacheDir
+    }
+    s.tokenizer, err = tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig)
+    if err != nil {
+        return nil, fmt.Errorf("failed to create tokenizer: %w", err)
+    }
+
     if s.config.EnableKVCache {
-        s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan)
+        s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan, s.tokenizer)
         if err != nil {
             return nil, err
         }
@@ -1065,7 +1077,69 @@ var _ = Describe("Simulator", func() {
             Expect(factor).To(BeNumerically(">", 1.0))
             Expect(factor).To(BeNumerically("<", simulator.config.TimeFactorUnderLoad))
         })
+    })
+
+    It("Should return correct response to /tokenize chat", func() {
+        ctx := context.TODO()
+        modelName := "Qwen/Qwen2-0.5B"
+        tmpDir := "./tests-tmp/"
+        defer os.RemoveAll(tmpDir)
+        args := []string{"cmd", "--model", modelName, "--mode", common.ModeRandom,
+            "--tokenizers-cache-dir", tmpDir}
+        client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
+        Expect(err).NotTo(HaveOccurred())
+
+        reqBody := `{
+            "messages": [{"role": "user", "content": "This is a test"}],
+            "model": "Qwen/Qwen2-0.5B"
+        }`
+        resp, err := client.Post("http://localhost/tokenize", "application/json", strings.NewReader(reqBody))
+        Expect(err).NotTo(HaveOccurred())
+        defer func() {
+            err := resp.Body.Close()
+            Expect(err).NotTo(HaveOccurred())
+        }()
+
+        body, err := io.ReadAll(resp.Body)
+        Expect(err).NotTo(HaveOccurred())
 
+        var tokenizeResp vllmapi.TokenizeResponse
+        err = json.Unmarshal(body, &tokenizeResp)
+        Expect(err).NotTo(HaveOccurred())
+        Expect(tokenizeResp.Count).To(Equal(4))
+        Expect(tokenizeResp.Tokens).To(HaveLen(4))
+        Expect(tokenizeResp.MaxModelLen).To(Equal(1024))
     })
 
+    It("Should return correct response to /tokenize text", func() {
+        ctx := context.TODO()
+        modelName := "Qwen/Qwen2-0.5B"
+        tmpDir := "./tests-tmp/"
+        defer os.RemoveAll(tmpDir)
+        args := []string{"cmd", "--model", modelName, "--mode", common.ModeRandom,
+            "--tokenizers-cache-dir", tmpDir}
+        client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
+        Expect(err).NotTo(HaveOccurred())
+
+        reqBody := `{
+            "prompt": "This is a test",
+            "model": "Qwen/Qwen2-0.5B"
+        }`
+        resp, err := client.Post("http://localhost/tokenize", "application/json", strings.NewReader(reqBody))
+        Expect(err).NotTo(HaveOccurred())
+        defer func() {
+            err := resp.Body.Close()
+            Expect(err).NotTo(HaveOccurred())
+        }()
+
+        body, err := io.ReadAll(resp.Body)
+        Expect(err).NotTo(HaveOccurred())
+
+        var tokenizeResp vllmapi.TokenizeResponse
+        err = json.Unmarshal(body, &tokenizeResp)
+        Expect(err).NotTo(HaveOccurred())
+        Expect(tokenizeResp.Count).To(Equal(4))
+        Expect(tokenizeResp.Tokens).To(HaveLen(4))
+        Expect(tokenizeResp.MaxModelLen).To(Equal(1024))
+    })
 })

pkg/vllm-api/tokenize.go

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+/*
+Copyright 2025 The llm-d-inference-sim Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package vllmapi
+
+import (
+    "strings"
+
+    openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
+)
+
+// TokenizeRequest is a request to the /tokenize endpoint.
+// It should contain either a prompt or messages, not both.
+type TokenizeRequest struct {
+    // Model is the model to use for tokenization
+    Model string `json:"model"`
+    // Prompt is the text to tokenize
+    Prompt string `json:"prompt"`
+    // Messages is an array of messages to tokenize
+    Messages []openaiserverapi.Message `json:"messages"`
+}
+
+// GetPrompt returns the text to tokenize, either the text prompt
+// or the concatenation of the messages (we reject requests with both
+// prompt and messages set).
+func (t *TokenizeRequest) GetPrompt() string {
+    if t.Prompt != "" {
+        return t.Prompt
+    }
+
+    var messages []string
+    for _, message := range t.Messages {
+        messages = append(messages, message.Content.PlainText())
+    }
+    return strings.Join(messages, " ")
+}
+
+// TokenizeResponse is the response to a tokenize request
+type TokenizeResponse struct {
+    // MaxModelLen is the maximum model length as defined in the configuration
+    MaxModelLen int `json:"max_model_len"`
+    // Count is the number of returned tokens
+    Count int `json:"count"`
+    // Tokens is the array of tokens produced by tokenizing the request
+    Tokens []uint32 `json:"tokens"`
+    // TokenStrs is currently unsupported and will always be null
+    TokenStrs []int `json:"token_strs"`
+}
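
For illustration only (not part of the commit): these structs describe the /tokenize wire format, so a response body decodes directly into TokenizeResponse. The body and token IDs below are invented, and token_strs stays null because the simulator never fills it. The sketch assumes it lives in package vllmapi with "encoding/json" imported.

// exampleDecodeTokenizeResponse is a hypothetical helper showing how a
// /tokenize response body maps onto TokenizeResponse.
func exampleDecodeTokenizeResponse() (*TokenizeResponse, error) {
    // Invented response body; the token IDs are illustrative only.
    body := []byte(`{"max_model_len": 1024, "count": 4, "tokens": [1986, 374, 264, 1273], "token_strs": null}`)
    var resp TokenizeResponse
    if err := json.Unmarshal(body, &resp); err != nil {
        return nil, err
    }
    // resp.Count == 4, len(resp.Tokens) == 4, resp.TokenStrs stays nil.
    return &resp, nil
}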

pkg/vllm-api/vllm-models.go

Lines changed: 0 additions & 2 deletions
@@ -14,8 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 
-// Definitions of all structures used by vLLM simulator
-// Contains the main simulator class and all definitions related to request/response for all supported APIs
 package vllmapi
 
 const (
