
Commit 699452c

Support /tokenize endpoint (#198)
* Support /tokenize endpoint
* Fix lint errors
* Fix test

---------

Signed-off-by: Ira <[email protected]>
1 parent 50bbec0 commit 699452c

File tree

6 files changed: +218 -25 lines changed


pkg/kv-cache/kv_cache.go

Lines changed: 2 additions & 9 deletions
@@ -35,22 +35,15 @@ type KVCacheHelper struct {
 	blockSize int
 }

-func NewKVCacheHelper(config *common.Configuration, logger logr.Logger, usageChan chan float64) (*KVCacheHelper, error) {
+func NewKVCacheHelper(config *common.Configuration, logger logr.Logger, usageChan chan float64,
+	tokenizer tokenization.Tokenizer) (*KVCacheHelper, error) {
 	tokenProcConfig := kvblock.DefaultTokenProcessorConfig()
 	tokenProcConfig.BlockSize = config.TokenBlockSize
 	if config.HashSeed != "" {
 		tokenProcConfig.HashSeed = config.HashSeed
 	}
 	tokensProcessor := kvblock.NewChunkedTokenDatabase(tokenProcConfig)

-	tokenizationConfig := tokenization.DefaultConfig()
-	if config.TokenizersCacheDir != "" {
-		tokenizationConfig.TokenizersCacheDir = config.TokenizersCacheDir
-	}
-	tokenizer, err := tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig)
-	if err != nil {
-		return nil, fmt.Errorf("failed to create tokenizer: %w", err)
-	}
 	blockCache, err := newBlockCache(config, logger, usageChan)
 	if err != nil {
 		return nil, fmt.Errorf("failed to create block cache: %w", err)
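With this change the constructor no longer builds its own tokenizer; the caller creates one and passes it in. Below is a minimal sketch of the new call site, mirroring the wiring this commit adds in simulator.go; the function wireKVCache and the variable names are hypothetical, for illustration only.

	package main

	import (
		"fmt"

		"github.com/go-logr/logr"
		"github.com/llm-d/llm-d-inference-sim/pkg/common"
		kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
		"github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization"
	)

	// wireKVCache is a hypothetical helper showing the new wiring: the caller
	// builds the tokenizer and hands it to NewKVCacheHelper.
	func wireKVCache(cfg *common.Configuration, log logr.Logger, usageChan chan float64) (*kvcache.KVCacheHelper, error) {
		tokenizationConfig := tokenization.DefaultConfig()
		if cfg.TokenizersCacheDir != "" {
			tokenizationConfig.TokenizersCacheDir = cfg.TokenizersCacheDir
		}
		tokenizer, err := tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig)
		if err != nil {
			return nil, fmt.Errorf("failed to create tokenizer: %w", err)
		}
		// The tokenizer is now an explicit dependency of the KV-cache helper.
		return kvcache.NewKVCacheHelper(cfg, log, usageChan, tokenizer)
	}

Sharing a single tokenizer between the KV-cache helper and the new /tokenize handler avoids loading the HF tokenizer twice, which is the stated reason for moving its construction up to the simulator.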

pkg/llm-d-inference-sim/metrics_test.go

Lines changed: 9 additions & 11 deletions
@@ -322,11 +322,10 @@ var _ = Describe("Simulator metrics", Ordered, func() {
 		Expect(err).NotTo(HaveOccurred())
 	})
 	It("Should send correct kv cache usage metrics", func() {
-		modelName := "Qwen/Qwen2-0.5B"
 		// Three requests, there are should be two blocks in the kv cache, because
 		// the first and the second prompt share a block.
 		ctx := context.TODO()
-		args := []string{"cmd", "--model", modelName, "--mode", common.ModeRandom,
+		args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom,
 			"--enable-kvcache", "true", "--kv-cache-size", "16", "--block-size", "8",
 			"--time-to-first-token", "5000", "--tokenizers-cache-dir", tmpDir}

@@ -342,19 +341,19 @@ var _ = Describe("Simulator metrics", Ordered, func() {
 			Prompt: openai.CompletionNewParamsPromptUnion{
 				OfString: openai.String("What is the weather like in Haifa today? Is it cold?"),
 			},
-			Model: openai.CompletionNewParamsModel(modelName),
+			Model: openai.CompletionNewParamsModel(qwenModelName),
 		},
 		{
 			Prompt: openai.CompletionNewParamsPromptUnion{
 				OfString: openai.String("What is the weather like in Haifa today?"),
 			},
-			Model: openai.CompletionNewParamsModel(modelName),
+			Model: openai.CompletionNewParamsModel(qwenModelName),
 		},
 		{
 			Prompt: openai.CompletionNewParamsPromptUnion{
 				OfString: openai.String("What is the weather like in New York today?"),
 			},
-			Model: openai.CompletionNewParamsModel(modelName),
+			Model: openai.CompletionNewParamsModel(qwenModelName),
 		},
 	}

@@ -385,7 +384,7 @@ var _ = Describe("Simulator metrics", Ordered, func() {
 		Expect(metrics).To(ContainSubstring("vllm:num_requests_waiting{model_name=\"Qwen/Qwen2-0.5B\"} 0"))
 		Expect(metrics).To(ContainSubstring("vllm:gpu_cache_usage_perc{model_name=\"Qwen/Qwen2-0.5B\"} 0.125"))

-		time.Sleep(3 * time.Second)
+		time.Sleep(4 * time.Second)
 		metricsResp, err = client.Get(metricsUrl)
 		Expect(err).NotTo(HaveOccurred())
 		Expect(metricsResp.StatusCode).To(Equal(http.StatusOK))
@@ -402,9 +401,8 @@ var _ = Describe("Simulator metrics", Ordered, func() {
 	})

 	It("Should send correct kv cache usage metrics for sequentual requests", func() {
-		modelName := "Qwen/Qwen2-0.5B"
 		ctx := context.TODO()
-		args := []string{"cmd", "--model", modelName, "--mode", common.ModeRandom,
+		args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom,
 			"--enable-kvcache", "true", "--kv-cache-size", "16", "--block-size", "8",
 			"--time-to-first-token", "5000", "--tokenizers-cache-dir", tmpDir, "--max-num-seqs", "2"}

@@ -420,19 +418,19 @@ var _ = Describe("Simulator metrics", Ordered, func() {
 			Prompt: openai.CompletionNewParamsPromptUnion{
 				OfString: openai.String("What is the weather like in Haifa today? Is it cold?"),
 			},
-			Model: openai.CompletionNewParamsModel(modelName),
+			Model: openai.CompletionNewParamsModel(qwenModelName),
 		},
 		{
 			Prompt: openai.CompletionNewParamsPromptUnion{
 				OfString: openai.String("What is the weather like in Haifa today?"),
 			},
-			Model: openai.CompletionNewParamsModel(modelName),
+			Model: openai.CompletionNewParamsModel(qwenModelName),
 		},
 		{
 			Prompt: openai.CompletionNewParamsPromptUnion{
 				OfString: openai.String("What is the weather like in New York today?"),
 			},
-			Model: openai.CompletionNewParamsModel(modelName),
+			Model: openai.CompletionNewParamsModel(qwenModelName),
 		},
 	}

pkg/llm-d-inference-sim/simulator.go

Lines changed: 67 additions & 1 deletion
@@ -40,6 +40,7 @@ import (
4040
kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
4141
openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
4242
vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api"
43+
"github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization"
4344
)
4445

4546
const (
@@ -117,6 +118,8 @@ type VllmSimulator struct {
117118
namespace string
118119
// pod name of simulator
119120
pod string
121+
// tokenizer is currently used in kv-cache and in /tokenize
122+
tokenizer tokenization.Tokenizer
120123
}
121124

122125
// New creates a new VllmSimulator instance with the given logger
@@ -197,8 +200,17 @@ func (s *VllmSimulator) startSim(ctx context.Context) error {
197200
return err
198201
}
199202

203+
tokenizationConfig := tokenization.DefaultConfig()
204+
if s.config.TokenizersCacheDir != "" {
205+
tokenizationConfig.TokenizersCacheDir = s.config.TokenizersCacheDir
206+
}
207+
s.tokenizer, err = tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig)
208+
if err != nil {
209+
return fmt.Errorf("failed to create tokenizer: %w", err)
210+
}
211+
200212
if s.config.EnableKVCache {
201-
s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan)
213+
s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan, s.tokenizer)
202214
if err != nil {
203215
return err
204216
}
@@ -248,6 +260,7 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener)
248260
// supports standard Kubernetes health and readiness checks
249261
r.GET("/health", s.HandleHealth)
250262
r.GET("/ready", s.HandleReady)
263+
r.POST("/tokenize", s.HandleTokenize)
251264

252265
server := fasthttp.Server{
253266
ErrorHandler: s.HandleError,
@@ -339,6 +352,59 @@ func (s *VllmSimulator) HandleTextCompletions(ctx *fasthttp.RequestCtx) {
339352
s.handleCompletions(ctx, false)
340353
}
341354

355+
// readTokenizeRequest reads and parses data from the body of the given request
356+
func (s *VllmSimulator) readTokenizeRequest(ctx *fasthttp.RequestCtx) (*vllmapi.TokenizeRequest, error) {
357+
var tokenizeReq vllmapi.TokenizeRequest
358+
if err := json.Unmarshal(ctx.Request.Body(), &tokenizeReq); err != nil {
359+
s.logger.Error(err, "failed to unmarshal tokenize request body")
360+
return nil, err
361+
}
362+
return &tokenizeReq, nil
363+
}
364+
365+
// HandleTokenize http handler for /tokenize
366+
func (s *VllmSimulator) HandleTokenize(ctx *fasthttp.RequestCtx) {
367+
s.logger.Info("tokenize request received")
368+
req, err := s.readTokenizeRequest(ctx)
369+
if err != nil {
370+
s.logger.Error(err, "failed to read and parse tokenize request body")
371+
ctx.Error("Failed to read and parse tokenize request body, "+err.Error(), fasthttp.StatusBadRequest)
372+
return
373+
}
374+
375+
// Check that the request has only one input to tokenize
376+
if req.Prompt != "" && req.Messages != nil {
377+
s.sendCompletionError(ctx, openaiserverapi.NewCompletionError("both prompt and messages fields in tokenize request",
378+
fasthttp.StatusBadRequest, nil), false)
379+
return
380+
}
381+
// Model is optional, if not set, the model from the configuration will be used
382+
model := req.Model
383+
if model == "" {
384+
model = s.config.Model
385+
}
386+
387+
tokens, _, err := s.tokenizer.Encode(req.GetPrompt(), model)
388+
if err != nil {
389+
s.logger.Error(err, "failed to tokenize")
390+
ctx.Error("Failed to tokenize, "+err.Error(), fasthttp.StatusInternalServerError)
391+
return
392+
}
393+
resp := vllmapi.TokenizeResponse{
394+
Count: len(tokens),
395+
Tokens: tokens,
396+
MaxModelLen: s.config.MaxModelLen,
397+
}
398+
data, err := json.Marshal(resp)
399+
if err != nil {
400+
ctx.Error("Response body creation failed, "+err.Error(), fasthttp.StatusInternalServerError)
401+
return
402+
}
403+
ctx.Response.Header.SetContentType("application/json")
404+
ctx.Response.Header.SetStatusCode(fasthttp.StatusOK)
405+
ctx.Response.SetBody(data)
406+
}
407+
342408
func (s *VllmSimulator) HandleLoadLora(ctx *fasthttp.RequestCtx) {
343409
s.logger.Info("load lora request received")
344410
s.loadLora(ctx)
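For reference, the new endpoint can be exercised with any HTTP client, not only the Go test client used below. A minimal standalone sketch follows; the listen address localhost:8000 is an assumption for illustration, so substitute the address and port the simulator actually serves on.

	package main

	import (
		"bytes"
		"encoding/json"
		"fmt"
		"net/http"
	)

	func main() {
		// A prompt-only tokenize request; the model field is optional and
		// falls back to the simulator's configured model when omitted.
		body := []byte(`{"prompt": "This is a test", "model": "Qwen/Qwen2-0.5B"}`)
		// NOTE: the base URL is an assumption; use the simulator's real address.
		resp, err := http.Post("http://localhost:8000/tokenize", "application/json", bytes.NewReader(body))
		if err != nil {
			panic(err)
		}
		defer resp.Body.Close()

		// Decode only the fields returned by the simulator's TokenizeResponse.
		var out struct {
			Count       int      `json:"count"`
			Tokens      []uint32 `json:"tokens"`
			MaxModelLen int      `json:"max_model_len"`
		}
		if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
			panic(err)
		}
		fmt.Printf("count=%d max_model_len=%d tokens=%v\n", out.Count, out.MaxModelLen, out.Tokens)
	}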

pkg/llm-d-inference-sim/simulator_test.go

Lines changed: 79 additions & 2 deletions
@@ -18,6 +18,7 @@ package llmdinferencesim

 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -29,6 +30,8 @@ import (

 	"github.com/llm-d/llm-d-inference-sim/pkg/common"
 	kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
+	vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api"
+	"github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization"
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
 	"github.com/openai/openai-go"
@@ -39,6 +42,7 @@ import (
 )

 const model = "my_model"
+const qwenModelName = "Qwen/Qwen2-0.5B"
 const baseURL = "http://localhost/v1"
 const userMessage = "This is a test."
 const invalidMaxTokensErrMsg = "Max completion tokens and max tokens should be positive"
@@ -97,8 +101,17 @@ func startServerWithArgs(ctx context.Context, mode string, args []string, envs m
 		return nil, err
 	}

+	tokenizationConfig := tokenization.DefaultConfig()
+	if s.config.TokenizersCacheDir != "" {
+		tokenizationConfig.TokenizersCacheDir = s.config.TokenizersCacheDir
+	}
+	s.tokenizer, err = tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create tokenizer: %w", err)
+	}
+
 	if s.config.EnableKVCache {
-		s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan)
+		s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan, s.tokenizer)
 		if err != nil {
 			return nil, err
 		}
@@ -1065,7 +1078,71 @@ var _ = Describe("Simulator", func() {
 			Expect(factor).To(BeNumerically(">", 1.0))
 			Expect(factor).To(BeNumerically("<", simulator.config.TimeFactorUnderLoad))
 		})
-
 	})

+	Context("tokenize", Ordered, func() {
+		tmpDir := "./tests-tmp/"
+		AfterAll(func() {
+			err := os.RemoveAll(tmpDir)
+			Expect(err).NotTo(HaveOccurred())
+		})
+
+		It("Should return correct response to /tokenize chat", func() {
+			ctx := context.TODO()
+			args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom,
+				"--tokenizers-cache-dir", tmpDir, "--max-model-len", "2048"}
+			client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
+			Expect(err).NotTo(HaveOccurred())
+
+			reqBody := `{
+				"messages": [{"role": "user", "content": "This is a test"}],
+				"model": "Qwen/Qwen2-0.5B"
+			}`
+			resp, err := client.Post("http://localhost/tokenize", "application/json", strings.NewReader(reqBody))
+			Expect(err).NotTo(HaveOccurred())
+			defer func() {
+				err := resp.Body.Close()
+				Expect(err).NotTo(HaveOccurred())
+			}()
+
+			body, err := io.ReadAll(resp.Body)
+			Expect(err).NotTo(HaveOccurred())
+
+			var tokenizeResp vllmapi.TokenizeResponse
+			err = json.Unmarshal(body, &tokenizeResp)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(tokenizeResp.Count).To(Equal(4))
+			Expect(tokenizeResp.Tokens).To(HaveLen(4))
+			Expect(tokenizeResp.MaxModelLen).To(Equal(2048))
+		})
+
+		It("Should return correct response to /tokenize text", func() {
+			ctx := context.TODO()
+			args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom,
+				"--tokenizers-cache-dir", tmpDir, "--max-model-len", "2048"}
+			client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
+			Expect(err).NotTo(HaveOccurred())
+
+			reqBody := `{
+				"prompt": "This is a test",
+				"model": "Qwen/Qwen2-0.5B"
+			}`
+			resp, err := client.Post("http://localhost/tokenize", "application/json", strings.NewReader(reqBody))
+			Expect(err).NotTo(HaveOccurred())
+			defer func() {
+				err := resp.Body.Close()
+				Expect(err).NotTo(HaveOccurred())
+			}()
+
+			body, err := io.ReadAll(resp.Body)
+			Expect(err).NotTo(HaveOccurred())
+
+			var tokenizeResp vllmapi.TokenizeResponse
+			err = json.Unmarshal(body, &tokenizeResp)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(tokenizeResp.Count).To(Equal(4))
+			Expect(tokenizeResp.Tokens).To(HaveLen(4))
+			Expect(tokenizeResp.MaxModelLen).To(Equal(2048))
+		})
+	})
 })

pkg/vllm-api/tokenize.go

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+/*
+Copyright 2025 The llm-d-inference-sim Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package vllmapi
+
+import (
+	"strings"
+
+	openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
+)
+
+// TokenizeRequest is a request to /tokenize endpoint.
+// Should contain either a prompt or messages, not both.
+type TokenizeRequest struct {
+	// Model is the model for tokenization
+	Model string `json:"model"`
+	// Prompt is the text to tokenize
+	Prompt string `json:"prompt"`
+	// Messages is an array of messages to tokenize
+	Messages []openaiserverapi.Message `json:"messages"`
+}
+
+// GetPrompt returns the text to tokenize, either the text prompt
+// or the concatenation of the messages (we reject requests with both
+// prompt and messages set).
+func (t *TokenizeRequest) GetPrompt() string {
+	if t.Prompt != "" {
+		return t.Prompt
+	}
+
+	messages := make([]string, 0)
+	for _, message := range t.Messages {
+		messages = append(messages, message.Content.PlainText())
+	}
+	return strings.Join(messages, " ")
+}
+
+// TokenizeResponse is a response for tokenize request
+type TokenizeResponse struct {
+	// MaxModelLen is max model length as defined in the configuration
+	MaxModelLen int `json:"max_model_len"`
+	// Count is the number of returned tokens
+	Count int `json:"count"`
+	// Tokens are an array of tokens - the result of request tokenization
+	Tokens []uint32 `json:"tokens"`
+	// TokenStrs is currently unsupported, will always be null
+	TokenStrs []int `json:"token_strs"`
+}
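Putting the JSON tags together, a prompt-only request and its response take roughly the following shape; the count and max_model_len values match the tests above, while the actual token IDs depend on the model's tokenizer and are omitted here.

	Request:  {"model": "Qwen/Qwen2-0.5B", "prompt": "This is a test"}
	Response: {"max_model_len": 2048, "count": 4, "tokens": [/* four token IDs */], "token_strs": null}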

pkg/vllm-api/vllm-models.go

Lines changed: 0 additions & 2 deletions
@@ -14,8 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */

-// Definitions of all structures used by vLLM simulator
-// Contains the main simulator class and all definitions related to request/response for all supported APIs
 package vllmapi

 const (
