diff --git a/pkg/kv-cache/kv_cache.go b/pkg/kv-cache/kv_cache.go
index bb676ed9..11d8b547 100644
--- a/pkg/kv-cache/kv_cache.go
+++ b/pkg/kv-cache/kv_cache.go
@@ -35,7 +35,8 @@ type KVCacheHelper struct {
 	blockSize int
 }
 
-func NewKVCacheHelper(config *common.Configuration, logger logr.Logger, usageChan chan float64) (*KVCacheHelper, error) {
+func NewKVCacheHelper(config *common.Configuration, logger logr.Logger, usageChan chan float64,
+	tokenizer tokenization.Tokenizer) (*KVCacheHelper, error) {
 	tokenProcConfig := kvblock.DefaultTokenProcessorConfig()
 	tokenProcConfig.BlockSize = config.TokenBlockSize
 	if config.HashSeed != "" {
@@ -43,14 +44,6 @@ func NewKVCacheHelper(config *common.Configuration, logger logr.Logger, usageCha
 	}
 	tokensProcessor := kvblock.NewChunkedTokenDatabase(tokenProcConfig)
 
-	tokenizationConfig := tokenization.DefaultConfig()
-	if config.TokenizersCacheDir != "" {
-		tokenizationConfig.TokenizersCacheDir = config.TokenizersCacheDir
-	}
-	tokenizer, err := tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig)
-	if err != nil {
-		return nil, fmt.Errorf("failed to create tokenizer: %w", err)
-	}
 	blockCache, err := newBlockCache(config, logger, usageChan)
 	if err != nil {
 		return nil, fmt.Errorf("failed to create block cache: %w", err)
diff --git a/pkg/llm-d-inference-sim/metrics_test.go b/pkg/llm-d-inference-sim/metrics_test.go
index bc94c460..2fa385ba 100644
--- a/pkg/llm-d-inference-sim/metrics_test.go
+++ b/pkg/llm-d-inference-sim/metrics_test.go
@@ -322,11 +322,10 @@ var _ = Describe("Simulator metrics", Ordered, func() {
 		Expect(err).NotTo(HaveOccurred())
 	})
 
 	It("Should send correct kv cache usage metrics", func() {
-		modelName := "Qwen/Qwen2-0.5B"
 		// Three requests, there are should be two blocks in the kv cache, because
 		// the first and the second prompt share a block.
 		ctx := context.TODO()
-		args := []string{"cmd", "--model", modelName, "--mode", common.ModeRandom,
+		args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom,
 			"--enable-kvcache", "true", "--kv-cache-size", "16", "--block-size", "8",
 			"--time-to-first-token", "5000", "--tokenizers-cache-dir", tmpDir}
@@ -342,19 +341,19 @@ var _ = Describe("Simulator metrics", Ordered, func() {
 				Prompt: openai.CompletionNewParamsPromptUnion{
 					OfString: openai.String("What is the weather like in Haifa today? Is it cold?"),
 				},
-				Model: openai.CompletionNewParamsModel(modelName),
+				Model: openai.CompletionNewParamsModel(qwenModelName),
 			},
 			{
 				Prompt: openai.CompletionNewParamsPromptUnion{
 					OfString: openai.String("What is the weather like in Haifa today?"),
 				},
-				Model: openai.CompletionNewParamsModel(modelName),
+				Model: openai.CompletionNewParamsModel(qwenModelName),
 			},
 			{
 				Prompt: openai.CompletionNewParamsPromptUnion{
 					OfString: openai.String("What is the weather like in New York today?"),
 				},
-				Model: openai.CompletionNewParamsModel(modelName),
+				Model: openai.CompletionNewParamsModel(qwenModelName),
 			},
 		}
 
@@ -385,7 +384,7 @@ var _ = Describe("Simulator metrics", Ordered, func() {
 		Expect(metrics).To(ContainSubstring("vllm:num_requests_waiting{model_name=\"Qwen/Qwen2-0.5B\"} 0"))
 		Expect(metrics).To(ContainSubstring("vllm:gpu_cache_usage_perc{model_name=\"Qwen/Qwen2-0.5B\"} 0.125"))
 
-		time.Sleep(3 * time.Second)
+		time.Sleep(4 * time.Second)
 		metricsResp, err = client.Get(metricsUrl)
 		Expect(err).NotTo(HaveOccurred())
 		Expect(metricsResp.StatusCode).To(Equal(http.StatusOK))
@@ -402,9 +401,8 @@ var _ = Describe("Simulator metrics", Ordered, func() {
 	})
 
 	It("Should send correct kv cache usage metrics for sequentual requests", func() {
-		modelName := "Qwen/Qwen2-0.5B"
 		ctx := context.TODO()
-		args := []string{"cmd", "--model", modelName, "--mode", common.ModeRandom,
+		args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom,
 			"--enable-kvcache", "true", "--kv-cache-size", "16", "--block-size", "8",
 			"--time-to-first-token", "5000", "--tokenizers-cache-dir", tmpDir,
 			"--max-num-seqs", "2"}
@@ -420,19 +418,19 @@ var _ = Describe("Simulator metrics", Ordered, func() {
 				Prompt: openai.CompletionNewParamsPromptUnion{
 					OfString: openai.String("What is the weather like in Haifa today? Is it cold?"),
 				},
-				Model: openai.CompletionNewParamsModel(modelName),
+				Model: openai.CompletionNewParamsModel(qwenModelName),
 			},
 			{
 				Prompt: openai.CompletionNewParamsPromptUnion{
 					OfString: openai.String("What is the weather like in Haifa today?"),
 				},
-				Model: openai.CompletionNewParamsModel(modelName),
+				Model: openai.CompletionNewParamsModel(qwenModelName),
 			},
 			{
 				Prompt: openai.CompletionNewParamsPromptUnion{
 					OfString: openai.String("What is the weather like in New York today?"),
 				},
-				Model: openai.CompletionNewParamsModel(modelName),
+				Model: openai.CompletionNewParamsModel(qwenModelName),
 			},
 		}
 
diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go
index b91057db..32d76ee7 100644
--- a/pkg/llm-d-inference-sim/simulator.go
+++ b/pkg/llm-d-inference-sim/simulator.go
@@ -40,6 +40,7 @@ import (
 	kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
 	openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
 	vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api"
+	"github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization"
 )
 
 const (
@@ -117,6 +118,8 @@ type VllmSimulator struct {
 	namespace string
 	// pod name of simulator
 	pod string
+	// tokenizer is currently used in kv-cache and in /tokenize
+	tokenizer tokenization.Tokenizer
 }
 
 // New creates a new VllmSimulator instance with the given logger
@@ -197,8 +200,17 @@ func (s *VllmSimulator) startSim(ctx context.Context) error {
 		return err
 	}
 
+	tokenizationConfig := tokenization.DefaultConfig()
+	if s.config.TokenizersCacheDir != "" {
+		tokenizationConfig.TokenizersCacheDir = s.config.TokenizersCacheDir
+	}
+	s.tokenizer, err = tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig)
+	if err != nil {
+		return fmt.Errorf("failed to create tokenizer: %w", err)
+	}
+
 	if s.config.EnableKVCache {
-		s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan)
+		s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan, s.tokenizer)
 		if err != nil {
 			return err
 		}
@@ -248,6 +260,7 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener)
 	// supports standard Kubernetes health and readiness checks
 	r.GET("/health", s.HandleHealth)
 	r.GET("/ready", s.HandleReady)
+	r.POST("/tokenize", s.HandleTokenize)
 
 	server := fasthttp.Server{
 		ErrorHandler: s.HandleError,
@@ -339,6 +352,59 @@ func (s *VllmSimulator) HandleTextCompletions(ctx *fasthttp.RequestCtx) {
 	s.handleCompletions(ctx, false)
 }
 
+// readTokenizeRequest reads and parses data from the body of the given request
+func (s *VllmSimulator) readTokenizeRequest(ctx *fasthttp.RequestCtx) (*vllmapi.TokenizeRequest, error) {
+	var tokenizeReq vllmapi.TokenizeRequest
+	if err := json.Unmarshal(ctx.Request.Body(), &tokenizeReq); err != nil {
+		s.logger.Error(err, "failed to unmarshal tokenize request body")
+		return nil, err
+	}
+	return &tokenizeReq, nil
+}
+
+// HandleTokenize is the http handler for /tokenize
+func (s *VllmSimulator) HandleTokenize(ctx *fasthttp.RequestCtx) {
+	s.logger.Info("tokenize request received")
+	req, err := s.readTokenizeRequest(ctx)
+	if err != nil {
+		s.logger.Error(err, "failed to read and parse tokenize request body")
+		ctx.Error("Failed to read and parse tokenize request body, "+err.Error(), fasthttp.StatusBadRequest)
+		return
+	}
+
+	// Check that the request has only one input to tokenize
+	if req.Prompt != "" && req.Messages != nil {
+		s.sendCompletionError(ctx, openaiserverapi.NewCompletionError("both prompt and messages fields in tokenize request",
+			fasthttp.StatusBadRequest, nil), false)
+		return
+	}
+	// Model is optional; if not set, the model from the configuration is used
+	model := req.Model
+	if model == "" {
+		model = s.config.Model
+	}
+
+	tokens, _, err := s.tokenizer.Encode(req.GetPrompt(), model)
+	if err != nil {
+		s.logger.Error(err, "failed to tokenize")
+		ctx.Error("Failed to tokenize, "+err.Error(), fasthttp.StatusInternalServerError)
+		return
+	}
+	resp := vllmapi.TokenizeResponse{
+		Count:       len(tokens),
+		Tokens:      tokens,
+		MaxModelLen: s.config.MaxModelLen,
+	}
+	data, err := json.Marshal(resp)
+	if err != nil {
+		ctx.Error("Response body creation failed, "+err.Error(), fasthttp.StatusInternalServerError)
+		return
+	}
+	ctx.Response.Header.SetContentType("application/json")
+	ctx.Response.Header.SetStatusCode(fasthttp.StatusOK)
+	ctx.Response.SetBody(data)
+}
+
 func (s *VllmSimulator) HandleLoadLora(ctx *fasthttp.RequestCtx) {
 	s.logger.Info("load lora request received")
 	s.loadLora(ctx)
diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go
index b70beb12..8b78cc58 100644
--- a/pkg/llm-d-inference-sim/simulator_test.go
+++ b/pkg/llm-d-inference-sim/simulator_test.go
@@ -18,6 +18,7 @@ package llmdinferencesim
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -29,6 +30,8 @@ import (
 
 	"github.com/llm-d/llm-d-inference-sim/pkg/common"
 	kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
+	vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api"
+	"github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization"
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
 	"github.com/openai/openai-go"
@@ -39,6 +42,7 @@
 )
 
 const model = "my_model"
+const qwenModelName = "Qwen/Qwen2-0.5B"
 const baseURL = "http://localhost/v1"
 const userMessage = "This is a test."
 const invalidMaxTokensErrMsg = "Max completion tokens and max tokens should be positive"
@@ -97,8 +101,17 @@ func startServerWithArgs(ctx context.Context, mode string, args []string, envs m
 		return nil, err
 	}
 
+	tokenizationConfig := tokenization.DefaultConfig()
+	if s.config.TokenizersCacheDir != "" {
+		tokenizationConfig.TokenizersCacheDir = s.config.TokenizersCacheDir
+	}
+	s.tokenizer, err = tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create tokenizer: %w", err)
+	}
+
 	if s.config.EnableKVCache {
-		s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan)
+		s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan, s.tokenizer)
 		if err != nil {
 			return nil, err
 		}
@@ -1065,7 +1078,71 @@ var _ = Describe("Simulator", func() {
 			Expect(factor).To(BeNumerically(">", 1.0))
 			Expect(factor).To(BeNumerically("<", simulator.config.TimeFactorUnderLoad))
 		})
-	})
+
+		Context("tokenize", Ordered, func() {
+			tmpDir := "./tests-tmp/"
+			AfterAll(func() {
+				err := os.RemoveAll(tmpDir)
+				Expect(err).NotTo(HaveOccurred())
+			})
+
+			It("Should return correct response to /tokenize chat", func() {
+				ctx := context.TODO()
+				args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom,
+					"--tokenizers-cache-dir", tmpDir, "--max-model-len", "2048"}
+				client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
+				Expect(err).NotTo(HaveOccurred())
+
+				reqBody := `{
+					"messages": [{"role": "user", "content": "This is a test"}],
+					"model": "Qwen/Qwen2-0.5B"
+				}`
+				resp, err := client.Post("http://localhost/tokenize", "application/json", strings.NewReader(reqBody))
+				Expect(err).NotTo(HaveOccurred())
+				defer func() {
+					err := resp.Body.Close()
+					Expect(err).NotTo(HaveOccurred())
+				}()
+
+				body, err := io.ReadAll(resp.Body)
+				Expect(err).NotTo(HaveOccurred())
+
+				var tokenizeResp vllmapi.TokenizeResponse
+				err = json.Unmarshal(body, &tokenizeResp)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(tokenizeResp.Count).To(Equal(4))
+				Expect(tokenizeResp.Tokens).To(HaveLen(4))
+				Expect(tokenizeResp.MaxModelLen).To(Equal(2048))
+			})
+
+			It("Should return correct response to /tokenize text", func() {
+				ctx := context.TODO()
+				args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom,
+					"--tokenizers-cache-dir", tmpDir, "--max-model-len", "2048"}
+				client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
+				Expect(err).NotTo(HaveOccurred())
+
+				reqBody := `{
+					"prompt": "This is a test",
+					"model": "Qwen/Qwen2-0.5B"
+				}`
+				resp, err := client.Post("http://localhost/tokenize", "application/json", strings.NewReader(reqBody))
+				Expect(err).NotTo(HaveOccurred())
+				defer func() {
+					err := resp.Body.Close()
+					Expect(err).NotTo(HaveOccurred())
+				}()
+
+				body, err := io.ReadAll(resp.Body)
+				Expect(err).NotTo(HaveOccurred())
+
+				var tokenizeResp vllmapi.TokenizeResponse
+				err = json.Unmarshal(body, &tokenizeResp)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(tokenizeResp.Count).To(Equal(4))
+				Expect(tokenizeResp.Tokens).To(HaveLen(4))
+				Expect(tokenizeResp.MaxModelLen).To(Equal(2048))
+			})
+		})
+	})
 })
diff --git a/pkg/vllm-api/tokenize.go b/pkg/vllm-api/tokenize.go
new file mode 100644
index 00000000..55760e9e
--- /dev/null
+++ b/pkg/vllm-api/tokenize.go
@@ -0,0 +1,61 @@
+/*
+Copyright 2025 The llm-d-inference-sim Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package vllmapi
+
+import (
+	"strings"
+
+	openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
+)
+
+// TokenizeRequest is a request to the /tokenize endpoint.
+// Should contain either a prompt or messages, not both.
+type TokenizeRequest struct {
+	// Model is the model for tokenization
+	Model string `json:"model"`
+	// Prompt is the text to tokenize
+	Prompt string `json:"prompt"`
+	// Messages is an array of messages to tokenize
+	Messages []openaiserverapi.Message `json:"messages"`
+}
+
+// GetPrompt returns the text to tokenize, either the text prompt
+// or the concatenation of the messages (we reject requests with both
+// prompt and messages set).
+func (t *TokenizeRequest) GetPrompt() string {
+	if t.Prompt != "" {
+		return t.Prompt
+	}
+
+	messages := make([]string, 0)
+	for _, message := range t.Messages {
+		messages = append(messages, message.Content.PlainText())
+	}
+	return strings.Join(messages, " ")
+}
+
+// TokenizeResponse is a response to a tokenize request
+type TokenizeResponse struct {
+	// MaxModelLen is the maximum model length as defined in the configuration
+	MaxModelLen int `json:"max_model_len"`
+	// Count is the number of returned tokens
+	Count int `json:"count"`
+	// Tokens is an array of tokens - the result of the request tokenization
+	Tokens []uint32 `json:"tokens"`
+	// TokenStrs is currently unsupported, will always be null
+	TokenStrs []int `json:"token_strs"`
+}
diff --git a/pkg/vllm-api/vllm-models.go b/pkg/vllm-api/vllm-models.go
index 6a9eff06..76c564af 100644
--- a/pkg/vllm-api/vllm-models.go
+++ b/pkg/vllm-api/vllm-models.go
@@ -14,8 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 
-// Definitions of all structures used by vLLM simulator
-// Contains the main simulator class and all definitions related to request/response for all supported APIs
 package vllmapi
 
 const (