11 changes: 2 additions & 9 deletions pkg/kv-cache/kv_cache.go
@@ -35,22 +35,15 @@ type KVCacheHelper struct {
blockSize int
}

func NewKVCacheHelper(config *common.Configuration, logger logr.Logger, usageChan chan float64) (*KVCacheHelper, error) {
func NewKVCacheHelper(config *common.Configuration, logger logr.Logger, usageChan chan float64,
tokenizer tokenization.Tokenizer) (*KVCacheHelper, error) {
tokenProcConfig := kvblock.DefaultTokenProcessorConfig()
tokenProcConfig.BlockSize = config.TokenBlockSize
if config.HashSeed != "" {
tokenProcConfig.HashSeed = config.HashSeed
}
tokensProcessor := kvblock.NewChunkedTokenDatabase(tokenProcConfig)

tokenizationConfig := tokenization.DefaultConfig()
if config.TokenizersCacheDir != "" {
tokenizationConfig.TokenizersCacheDir = config.TokenizersCacheDir
}
tokenizer, err := tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig)
if err != nil {
return nil, fmt.Errorf("failed to create tokenizer: %w", err)
}
blockCache, err := newBlockCache(config, logger, usageChan)
if err != nil {
return nil, fmt.Errorf("failed to create block cache: %w", err)
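The constructor no longer builds a tokenizer internally; callers now create one and inject it. A minimal sketch of the new call pattern, mirroring the code this PR adds to startSim in simulator.go below (the wrapper function name is illustrative and not part of this PR):

func newHelperWithTokenizer(config *common.Configuration, logger logr.Logger,
	usageChan chan float64) (*kvcache.KVCacheHelper, error) {
	// Build the shared tokenizer once, honoring the configured cache directory.
	tokenizationConfig := tokenization.DefaultConfig()
	if config.TokenizersCacheDir != "" {
		tokenizationConfig.TokenizersCacheDir = config.TokenizersCacheDir
	}
	tokenizer, err := tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to create tokenizer: %w", err)
	}
	// Inject the tokenizer instead of letting the KV-cache helper create its own.
	return kvcache.NewKVCacheHelper(config, logger, usageChan, tokenizer)
}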
20 changes: 9 additions & 11 deletions pkg/llm-d-inference-sim/metrics_test.go
@@ -322,11 +322,10 @@ var _ = Describe("Simulator metrics", Ordered, func() {
Expect(err).NotTo(HaveOccurred())
})
It("Should send correct kv cache usage metrics", func() {
modelName := "Qwen/Qwen2-0.5B"
// Three requests; there should be two blocks in the kv cache, because
// the first and the second prompt share a block.
ctx := context.TODO()
args := []string{"cmd", "--model", modelName, "--mode", common.ModeRandom,
args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom,
"--enable-kvcache", "true", "--kv-cache-size", "16", "--block-size", "8",
"--time-to-first-token", "5000", "--tokenizers-cache-dir", tmpDir}

@@ -342,19 +341,19 @@
Prompt: openai.CompletionNewParamsPromptUnion{
OfString: openai.String("What is the weather like in Haifa today? Is it cold?"),
},
Model: openai.CompletionNewParamsModel(modelName),
Model: openai.CompletionNewParamsModel(qwenModelName),
},
{
Prompt: openai.CompletionNewParamsPromptUnion{
OfString: openai.String("What is the weather like in Haifa today?"),
},
Model: openai.CompletionNewParamsModel(modelName),
Model: openai.CompletionNewParamsModel(qwenModelName),
},
{
Prompt: openai.CompletionNewParamsPromptUnion{
OfString: openai.String("What is the weather like in New York today?"),
},
Model: openai.CompletionNewParamsModel(modelName),
Model: openai.CompletionNewParamsModel(qwenModelName),
},
}

@@ -385,7 +384,7 @@ var _ = Describe("Simulator metrics", Ordered, func() {
Expect(metrics).To(ContainSubstring("vllm:num_requests_waiting{model_name=\"Qwen/Qwen2-0.5B\"} 0"))
Expect(metrics).To(ContainSubstring("vllm:gpu_cache_usage_perc{model_name=\"Qwen/Qwen2-0.5B\"} 0.125"))

time.Sleep(3 * time.Second)
time.Sleep(4 * time.Second)
metricsResp, err = client.Get(metricsUrl)
Expect(err).NotTo(HaveOccurred())
Expect(metricsResp.StatusCode).To(Equal(http.StatusOK))
@@ -402,9 +401,8 @@
})

It("Should send correct kv cache usage metrics for sequentual requests", func() {
modelName := "Qwen/Qwen2-0.5B"
ctx := context.TODO()
args := []string{"cmd", "--model", modelName, "--mode", common.ModeRandom,
args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom,
"--enable-kvcache", "true", "--kv-cache-size", "16", "--block-size", "8",
"--time-to-first-token", "5000", "--tokenizers-cache-dir", tmpDir, "--max-num-seqs", "2"}

@@ -420,19 +418,19 @@
Prompt: openai.CompletionNewParamsPromptUnion{
OfString: openai.String("What is the weather like in Haifa today? Is it cold?"),
},
Model: openai.CompletionNewParamsModel(modelName),
Model: openai.CompletionNewParamsModel(qwenModelName),
},
{
Prompt: openai.CompletionNewParamsPromptUnion{
OfString: openai.String("What is the weather like in Haifa today?"),
},
Model: openai.CompletionNewParamsModel(modelName),
Model: openai.CompletionNewParamsModel(qwenModelName),
},
{
Prompt: openai.CompletionNewParamsPromptUnion{
OfString: openai.String("What is the weather like in New York today?"),
},
Model: openai.CompletionNewParamsModel(modelName),
Model: openai.CompletionNewParamsModel(qwenModelName),
},
}

68 changes: 67 additions & 1 deletion pkg/llm-d-inference-sim/simulator.go
@@ -40,6 +40,7 @@ import (
kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api"
"github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization"
)

const (
@@ -117,6 +118,8 @@ type VllmSimulator struct {
namespace string
// pod name of simulator
pod string
// tokenizer is currently used in kv-cache and in /tokenize
tokenizer tokenization.Tokenizer
}

// New creates a new VllmSimulator instance with the given logger
@@ -197,8 +200,17 @@ func (s *VllmSimulator) startSim(ctx context.Context) error {
return err
}

tokenizationConfig := tokenization.DefaultConfig()
if s.config.TokenizersCacheDir != "" {
tokenizationConfig.TokenizersCacheDir = s.config.TokenizersCacheDir
}
s.tokenizer, err = tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig)
if err != nil {
return fmt.Errorf("failed to create tokenizer: %w", err)
}

if s.config.EnableKVCache {
s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan)
s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan, s.tokenizer)
if err != nil {
return err
}
@@ -248,6 +260,7 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener)
// supports standard Kubernetes health and readiness checks
r.GET("/health", s.HandleHealth)
r.GET("/ready", s.HandleReady)
r.POST("/tokenize", s.HandleTokenize)

server := fasthttp.Server{
ErrorHandler: s.HandleError,
@@ -339,6 +352,59 @@ func (s *VllmSimulator) HandleTextCompletions(ctx *fasthttp.RequestCtx) {
s.handleCompletions(ctx, false)
}

// readTokenizeRequest reads and parses data from the body of the given request
func (s *VllmSimulator) readTokenizeRequest(ctx *fasthttp.RequestCtx) (*vllmapi.TokenizeRequest, error) {
var tokenizeReq vllmapi.TokenizeRequest
if err := json.Unmarshal(ctx.Request.Body(), &tokenizeReq); err != nil {
s.logger.Error(err, "failed to unmarshal tokenize request body")
return nil, err
}
return &tokenizeReq, nil
}

// HandleTokenize http handler for /tokenize
func (s *VllmSimulator) HandleTokenize(ctx *fasthttp.RequestCtx) {
s.logger.Info("tokenize request received")
req, err := s.readTokenizeRequest(ctx)
if err != nil {
s.logger.Error(err, "failed to read and parse tokenize request body")
ctx.Error("Failed to read and parse tokenize request body, "+err.Error(), fasthttp.StatusBadRequest)
return
}

// Check that the request has only one input to tokenize
if req.Prompt != "" && req.Messages != nil {
s.sendCompletionError(ctx, openaiserverapi.NewCompletionError("both prompt and messages fields in tokenize request",
fasthttp.StatusBadRequest, nil), false)
return
}
// Model is optional; if not set, the model from the configuration is used
model := req.Model
if model == "" {
model = s.config.Model
}

tokens, _, err := s.tokenizer.Encode(req.GetPrompt(), model)
if err != nil {
s.logger.Error(err, "failed to tokenize")
ctx.Error("Failed to tokenize, "+err.Error(), fasthttp.StatusInternalServerError)
return
}
resp := vllmapi.TokenizeResponse{
Count: len(tokens),
Tokens: tokens,
MaxModelLen: s.config.MaxModelLen,
}
data, err := json.Marshal(resp)
if err != nil {
ctx.Error("Response body creation failed, "+err.Error(), fasthttp.StatusInternalServerError)
return
}
ctx.Response.Header.SetContentType("application/json")
ctx.Response.Header.SetStatusCode(fasthttp.StatusOK)
ctx.Response.SetBody(data)
}

func (s *VllmSimulator) HandleLoadLora(ctx *fasthttp.RequestCtx) {
s.logger.Info("load lora request received")
s.loadLora(ctx)
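For reference, a hedged client-side sketch of the new /tokenize endpoint, modeled on the tests added in simulator_test.go below; the host, port, and response struct name here are illustrative assumptions, not part of this PR:

package main

import (
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"strings"
)

// tokenizeResponse mirrors the fields of vllmapi.TokenizeResponse used in this example.
type tokenizeResponse struct {
	MaxModelLen int      `json:"max_model_len"`
	Count       int      `json:"count"`
	Tokens      []uint32 `json:"tokens"`
}

func main() {
	// A tokenize request carries either "prompt" or "messages", not both.
	body := `{"prompt": "This is a test", "model": "Qwen/Qwen2-0.5B"}`
	resp, err := http.Post("http://localhost:8000/tokenize", "application/json", strings.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	var tr tokenizeResponse
	if err := json.NewDecoder(resp.Body).Decode(&tr); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("count=%d tokens=%v max_model_len=%d\n", tr.Count, tr.Tokens, tr.MaxModelLen)
}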
81 changes: 79 additions & 2 deletions pkg/llm-d-inference-sim/simulator_test.go
@@ -18,6 +18,7 @@ package llmdinferencesim

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
@@ -29,6 +30,8 @@

"github.com/llm-d/llm-d-inference-sim/pkg/common"
kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api"
"github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/openai/openai-go"
@@ -39,6 +42,7 @@
)

const model = "my_model"
const qwenModelName = "Qwen/Qwen2-0.5B"
const baseURL = "http://localhost/v1"
const userMessage = "This is a test."
const invalidMaxTokensErrMsg = "Max completion tokens and max tokens should be positive"
@@ -97,8 +101,17 @@ func startServerWithArgs(ctx context.Context, mode string, args []string, envs m
return nil, err
}

tokenizationConfig := tokenization.DefaultConfig()
if s.config.TokenizersCacheDir != "" {
tokenizationConfig.TokenizersCacheDir = s.config.TokenizersCacheDir
}
s.tokenizer, err = tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig)
if err != nil {
return nil, fmt.Errorf("failed to create tokenizer: %w", err)
}

if s.config.EnableKVCache {
s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan)
s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan, s.tokenizer)
if err != nil {
return nil, err
}
@@ -1065,7 +1078,71 @@ var _ = Describe("Simulator", func() {
Expect(factor).To(BeNumerically(">", 1.0))
Expect(factor).To(BeNumerically("<", simulator.config.TimeFactorUnderLoad))
})

})

Context("tokenize", Ordered, func() {
tmpDir := "./tests-tmp/"
AfterAll(func() {
err := os.RemoveAll(tmpDir)
Expect(err).NotTo(HaveOccurred())
})

It("Should return correct response to /tokenize chat", func() {
ctx := context.TODO()
args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom,
"--tokenizers-cache-dir", tmpDir, "--max-model-len", "2048"}
client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
Expect(err).NotTo(HaveOccurred())

reqBody := `{
"messages": [{"role": "user", "content": "This is a test"}],
"model": "Qwen/Qwen2-0.5B"
}`
resp, err := client.Post("http://localhost/tokenize", "application/json", strings.NewReader(reqBody))
Expect(err).NotTo(HaveOccurred())
defer func() {
err := resp.Body.Close()
Expect(err).NotTo(HaveOccurred())
}()

body, err := io.ReadAll(resp.Body)
Expect(err).NotTo(HaveOccurred())

var tokenizeResp vllmapi.TokenizeResponse
err = json.Unmarshal(body, &tokenizeResp)
Expect(err).NotTo(HaveOccurred())
Expect(tokenizeResp.Count).To(Equal(4))
Expect(tokenizeResp.Tokens).To(HaveLen(4))
Expect(tokenizeResp.MaxModelLen).To(Equal(2048))
})

It("Should return correct response to /tokenize text", func() {
ctx := context.TODO()
args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom,
"--tokenizers-cache-dir", tmpDir, "--max-model-len", "2048"}
client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
Expect(err).NotTo(HaveOccurred())

reqBody := `{
"prompt": "This is a test",
"model": "Qwen/Qwen2-0.5B"
}`
resp, err := client.Post("http://localhost/tokenize", "application/json", strings.NewReader(reqBody))
Expect(err).NotTo(HaveOccurred())
defer func() {
err := resp.Body.Close()
Expect(err).NotTo(HaveOccurred())
}()

body, err := io.ReadAll(resp.Body)
Expect(err).NotTo(HaveOccurred())

var tokenizeResp vllmapi.TokenizeResponse
err = json.Unmarshal(body, &tokenizeResp)
Expect(err).NotTo(HaveOccurred())
Expect(tokenizeResp.Count).To(Equal(4))
Expect(tokenizeResp.Tokens).To(HaveLen(4))
Expect(tokenizeResp.MaxModelLen).To(Equal(2048))
})
})
})
61 changes: 61 additions & 0 deletions pkg/vllm-api/tokenize.go
@@ -0,0 +1,61 @@
/*
Copyright 2025 The llm-d-inference-sim Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package vllmapi

import (
"strings"

openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
)

// TokenizeRequest is a request to the /tokenize endpoint.
// Should contain either a prompt or messages, not both.
type TokenizeRequest struct {
// Model is the model for tokenization
Model string `json:"model"`
// Prompt is the text to tokenize
Prompt string `json:"prompt"`
// Messages is an array of messages to tokenize
Messages []openaiserverapi.Message `json:"messages"`
}

// GetPrompt returns the text to tokenize, either the text prompt
// or the concatenation of the messages (we reject requests with both
// prompt and messages set).
func (t *TokenizeRequest) GetPrompt() string {
if t.Prompt != "" {
return t.Prompt
}

messages := make([]string, 0)
for _, message := range t.Messages {
messages = append(messages, message.Content.PlainText())
}
return strings.Join(messages, " ")
}

// TokenizeResponse is the response to a tokenize request
type TokenizeResponse struct {
// MaxModelLen is the maximum model length as defined in the configuration
MaxModelLen int `json:"max_model_len"`
// Count is the number of returned tokens
Count int `json:"count"`
// Tokens is an array of tokens - the result of request tokenization
Tokens []uint32 `json:"tokens"`
// TokenStrs is currently unsupported and will always be null
TokenStrs []int `json:"token_strs"`
}
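Illustrative only: given the struct tags above, marshaling a response with the values asserted in the tests produces JSON of the following shape (the token IDs are made-up placeholders; assumes "encoding/json" and "fmt" are imported):

resp := TokenizeResponse{MaxModelLen: 2048, Count: 4, Tokens: []uint32{101, 102, 103, 104}}
data, _ := json.Marshal(resp)
fmt.Println(string(data))
// Prints: {"max_model_len":2048,"count":4,"tokens":[101,102,103,104],"token_strs":null}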
2 changes: 0 additions & 2 deletions pkg/vllm-api/vllm-models.go
@@ -14,8 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

// Definitions of all structures used by vLLM simulator
// Contains the main simulator class and all definitions related to request/response for all supported APIs
package vllmapi

const (