11 changes: 4 additions & 7 deletions go.mod
@@ -21,12 +21,7 @@ require (
golang.org/x/sync v0.12.0
gopkg.in/yaml.v3 v3.0.1
k8s.io/klog/v2 v2.130.1
)

require (
github.com/dgraph-io/ristretto/v2 v2.3.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
go.uber.org/multierr v1.11.0 // indirect
sigs.k8s.io/controller-runtime v0.21.0
)

require (
@@ -35,7 +30,9 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/daulet/tokenizers v1.22.1 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgraph-io/ristretto/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/emicklei/go-restful/v3 v3.11.0 // indirect
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
github.com/go-openapi/jsonpointer v0.21.0 // indirect
@@ -68,6 +65,7 @@ require (
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/x448/float16 v0.8.4 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/oauth2 v0.27.0 // indirect
golang.org/x/sys v0.35.0 // indirect
@@ -83,7 +81,6 @@ require (
k8s.io/client-go v0.33.0 // indirect
k8s.io/kube-openapi v0.0.0-20250318190949-c8a335a9a2ff // indirect
k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 // indirect
sigs.k8s.io/controller-runtime v0.21.0
sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 // indirect
sigs.k8s.io/randfill v1.0.0 // indirect
sigs.k8s.io/structured-merge-diff/v4 v4.6.0 // indirect
4 changes: 4 additions & 0 deletions pkg/common/config.go
@@ -223,6 +223,9 @@ type Configuration struct {

// EnableSleepMode enables sleep mode
EnableSleepMode bool `yaml:"enable-sleep-mode" json:"enable-sleep-mode"`

// EnableRequestIDHeaders enables including the X-Request-Id header in responses
EnableRequestIDHeaders bool `yaml:"enable-request-id-headers" json:"enable-request-id-headers"`
}

type Metrics struct {
@@ -749,6 +752,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
f.BoolVar(&config.DatasetInMemory, "dataset-in-memory", config.DatasetInMemory, "Load the entire dataset into memory for faster access")

f.BoolVar(&config.EnableSleepMode, "enable-sleep-mode", config.EnableSleepMode, "Enable sleep mode")
f.BoolVar(&config.EnableRequestIDHeaders, "enable-request-id-headers", config.EnableRequestIDHeaders, "Enable including the X-Request-Id header in responses")

f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures")
failureTypes := getParamValueFromArgs("failure-types")
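For reference, a minimal standalone sketch of how the new enable-request-id-headers yaml tag would be read from a configuration file with gopkg.in/yaml.v3 (already a dependency in go.mod). This is not the simulator's actual config loader; the trimmed-down struct and inline snippet are illustrative only.

package main

import (
	"fmt"

	"gopkg.in/yaml.v3"
)

// Stand-in for the relevant slice of the Configuration struct above.
type config struct {
	EnableSleepMode        bool `yaml:"enable-sleep-mode"`
	EnableRequestIDHeaders bool `yaml:"enable-request-id-headers"`
}

func main() {
	// Hypothetical config snippet using the yaml tag introduced in this PR.
	raw := []byte("enable-request-id-headers: true\n")

	var c config
	if err := yaml.Unmarshal(raw, &c); err != nil {
		panic(err)
	}
	fmt.Println(c.EnableRequestIDHeaders) // true
}

The same option is exposed on the command line as --enable-request-id-headers via the BoolVar registration above.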
18 changes: 17 additions & 1 deletion pkg/llm-d-inference-sim/server.go
@@ -109,9 +109,20 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener)
}
}

// getRequestID returns the request ID from the X-Request-Id header when request ID headers are enabled, or generates a new one otherwise
func (s *VllmSimulator) getRequestID(ctx *fasthttp.RequestCtx) string {
if s.config.EnableRequestIDHeaders {
requestID := string(ctx.Request.Header.Peek(requestIDHeader))
if requestID != "" {
return requestID
}
}
return s.random.GenerateUUIDString()
}

// readRequest reads and parses data from the body of the given request according to the type defined by isChatCompletion
func (s *VllmSimulator) readRequest(ctx *fasthttp.RequestCtx, isChatCompletion bool) (openaiserverapi.CompletionRequest, error) {
requestID := s.random.GenerateUUIDString()
requestID := s.getRequestID(ctx)

if isChatCompletion {
var req openaiserverapi.ChatCompletionRequest
@@ -266,6 +277,11 @@ func (s *VllmSimulator) sendCompletionResponse(ctx *fasthttp.RequestCtx, resp op
if s.namespace != "" {
ctx.Response.Header.Add(namespaceHeader, s.namespace)
}
if s.config.EnableRequestIDHeaders {
if requestID := resp.GetRequestID(); requestID != "" {
ctx.Response.Header.Add(requestIDHeader, requestID)
}
}
ctx.Response.SetBody(data)
}

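To make the new flow concrete, here is a minimal standalone fasthttp handler that mirrors the peek-or-generate pattern introduced above: reuse the caller's X-Request-Id when the feature is enabled and the header is present, otherwise mint a fresh UUID, and echo the chosen value back on the response. github.com/google/uuid is used only as a stand-in for the simulator's internal s.random.GenerateUUIDString(); the handler, port, and response body are hypothetical.

package main

import (
	"github.com/google/uuid"
	"github.com/valyala/fasthttp"
)

const requestIDHeader = "X-Request-Id"

// handler reuses the incoming X-Request-Id if present (and the feature is
// enabled), otherwise generates a new UUID, and echoes the value back on
// the response so clients can correlate requests and responses.
func handler(enableRequestIDHeaders bool) fasthttp.RequestHandler {
	return func(ctx *fasthttp.RequestCtx) {
		requestID := uuid.NewString() // stand-in for the simulator's UUID generator
		if enableRequestIDHeaders {
			if id := string(ctx.Request.Header.Peek(requestIDHeader)); id != "" {
				requestID = id
			}
			ctx.Response.Header.Add(requestIDHeader, requestID)
		}
		ctx.SetBodyString(`{"id":"chatcmpl-` + requestID + `"}`)
	}
}

func main() {
	// Hypothetical standalone server for illustration only.
	_ = fasthttp.ListenAndServe(":8080", handler(true))
}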
111 changes: 109 additions & 2 deletions pkg/llm-d-inference-sim/server_test.go
@@ -25,11 +25,13 @@ import (
"strings"
"time"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/valyala/fasthttp"

"github.com/llm-d/llm-d-inference-sim/pkg/common"
kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

const tmpDir = "./tests-tmp/"
@@ -212,6 +214,111 @@ var _ = Describe("Server", func() {

})

Context("request ID headers", func() {
testRequestIDHeader := func(enableRequestID bool, endpoint, reqBody, inputRequestID string, expectRequestID *string, validateBody func([]byte)) {
ctx := context.TODO()
args := []string{"cmd", "--model", testModel, "--mode", common.ModeEcho}
if enableRequestID {
args = append(args, "--enable-request-id-headers")
}
client, err := startServerWithArgs(ctx, args)
Expect(err).NotTo(HaveOccurred())

req, err := http.NewRequest("POST", "http://localhost"+endpoint, strings.NewReader(reqBody))
Expect(err).NotTo(HaveOccurred())
req.Header.Set(fasthttp.HeaderContentType, "application/json")
if inputRequestID != "" {
req.Header.Set(requestIDHeader, inputRequestID)
}

resp, err := client.Do(req)
Expect(err).NotTo(HaveOccurred())
defer func() {
err := resp.Body.Close()
Expect(err).NotTo(HaveOccurred())
}()

Expect(resp.StatusCode).To(Equal(http.StatusOK))

if expectRequestID != nil {
actualRequestID := resp.Header.Get(requestIDHeader)
if *expectRequestID != "" {
// When a request ID is provided, it should be echoed back
Expect(actualRequestID).To(Equal(*expectRequestID))
} else {
// When no request ID is provided, a UUID should be generated
Expect(actualRequestID).NotTo(BeEmpty())
Expect(len(actualRequestID)).To(BeNumerically(">", 30))
}
} else {
// When request ID headers are disabled, the header should be empty
Expect(resp.Header.Get(requestIDHeader)).To(BeEmpty())
}

if validateBody != nil {
body, err := io.ReadAll(resp.Body)
Expect(err).NotTo(HaveOccurred())
validateBody(body)
}
}

DescribeTable("request ID behavior",
testRequestIDHeader,
Entry("includes X-Request-Id when enabled",
true,
"/v1/chat/completions",
`{"messages": [{"role": "user", "content": "Hello"}], "model": "`+testModel+`", "max_tokens": 5}`,
"test-request-id-123",
ptr("test-request-id-123"),
nil,
),
Entry("excludes X-Request-Id when disabled",
false,
"/v1/chat/completions",
`{"messages": [{"role": "user", "content": "Hello"}], "model": "`+testModel+`", "max_tokens": 5}`,
"test-request-id-456",
nil,
nil,
),
Entry("includes X-Request-Id in streaming response",
true,
"/v1/chat/completions",
`{"messages": [{"role": "user", "content": "Hello"}], "model": "`+testModel+`", "max_tokens": 5, "stream": true}`,
"test-streaming-789",
ptr("test-streaming-789"),
nil,
),
Entry("works with text completions endpoint",
true,
"/v1/completions",
`{"prompt": "Hello world", "model": "`+testModel+`", "max_tokens": 5}`,
"text-request-111",
ptr("text-request-111"),
nil,
),
Entry("generates UUID when no request ID provided",
true,
"/v1/chat/completions",
`{"messages": [{"role": "user", "content": "Hello"}], "model": "`+testModel+`", "max_tokens": 5}`,
"",
ptr(""),
nil,
),
Entry("uses request ID in response body ID field",
true,
"/v1/chat/completions",
`{"messages": [{"role": "user", "content": "Hello"}], "model": "`+testModel+`", "max_tokens": 5}`,
"body-test-999",
ptr("body-test-999"),
func(body []byte) {
var resp map[string]any
Expect(json.Unmarshal(body, &resp)).To(Succeed())
Expect(resp["id"]).To(Equal("chatcmpl-body-test-999"))
},
),
)
})

Context("sleep mode", Ordered, func() {
AfterAll(func() {
err := os.RemoveAll(tmpDir)
10 changes: 6 additions & 4 deletions pkg/llm-d-inference-sim/simulator.go
@@ -51,6 +51,7 @@ const (
podHeader = "x-inference-pod"
portHeader = "x-inference-port"
namespaceHeader = "x-inference-namespace"
requestIDHeader = "X-Request-Id"
podNameEnv = "POD_NAME"
podNsEnv = "POD_NAMESPACE"
)
@@ -581,9 +582,9 @@ func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool
// modelName - display name returned to the client and used in metrics. It is either the first alias
// from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request).
func (s *VllmSimulator) createCompletionResponse(logprobs *int, isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall,
finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool) openaiserverapi.CompletionResponse {
baseResp := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
time.Now().Unix(), modelName, usageData)
finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool, requestID string) openaiserverapi.CompletionResponse {
baseResp := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+requestID,
time.Now().Unix(), modelName, usageData, requestID)

if doRemoteDecode {
baseResp.KVParams = &openaiserverapi.KVTransferParams{}
@@ -663,9 +664,10 @@ func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, r
if toolCalls == nil {
logprobs = reqCtx.CompletionReq.GetLogprobs()
}
requestID := reqCtx.CompletionReq.GetRequestID()

resp := s.createCompletionResponse(logprobs, reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName,
reqCtx.CompletionReq.IsDoRemoteDecode())
reqCtx.CompletionReq.IsDoRemoteDecode(), requestID)

// calculate how long to wait before returning the response, time is based on number of tokens
nCachedPromptTokens := reqCtx.CompletionReq.GetNumberOfCachedPromptTokens()
15 changes: 9 additions & 6 deletions pkg/llm-d-inference-sim/streaming.go
@@ -60,6 +60,9 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons
if s.namespace != "" {
context.ctx.Response.Header.Add(namespaceHeader, s.namespace)
}
if s.config.EnableRequestIDHeaders {
context.ctx.Response.Header.Add(requestIDHeader, context.requestID)
}

context.ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
context.creationTime = time.Now().Unix()
@@ -176,8 +179,8 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ
// createUsageChunk creates and returns a CompletionRespChunk with usage data, a single chunk of streamed completion API response,
// supports both modes (text and chat)
func (s *VllmSimulator) createUsageChunk(context *streamingContext, usageData *openaiserverapi.Usage) openaiserverapi.CompletionRespChunk {
baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
context.creationTime, context.model, usageData)
baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+context.requestID,
context.creationTime, context.model, usageData, context.requestID)

if context.isChatCompletion {
baseChunk.Object = chatCompletionChunkObject
@@ -191,8 +194,8 @@ func (s *VllmSimulator) createTextCompletionChunk(context *streamingContext, tok
// createTextCompletionChunk creates and returns a CompletionRespChunk, a single chunk of streamed completion API response,
// for text completion.
func (s *VllmSimulator) createTextCompletionChunk(context *streamingContext, token string, finishReason *string) openaiserverapi.CompletionRespChunk {
baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
context.creationTime, context.model, nil)
baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+context.requestID,
context.creationTime, context.model, nil, context.requestID)
baseChunk.Object = textCompletionObject

choice := openaiserverapi.CreateTextRespChoice(openaiserverapi.CreateBaseResponseChoice(0, finishReason), token)
@@ -214,8 +217,8 @@ func (s *VllmSimulator) createChatCompletionChunk(context *streamingContext, tok
// API response, for chat completion. It sets either role, or token, or tool call info in the message.
func (s *VllmSimulator) createChatCompletionChunk(context *streamingContext, token string, tool *openaiserverapi.ToolCall,
role string, finishReason *string) openaiserverapi.CompletionRespChunk {
baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
context.creationTime, context.model, nil)
baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+context.requestID,
context.creationTime, context.model, nil, context.requestID)
baseChunk.Object = chatCompletionChunkObject
chunk := openaiserverapi.CreateChatCompletionRespChunk(baseChunk,
[]openaiserverapi.ChatRespChunkChoice{
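Note on behavior: before this change, each streamed chunk's id was built from a freshly generated UUID, so chunks of the same response carried different identifiers. With the request ID threaded through streamingContext, every chunk of a response (including the final usage chunk) now shares the same chatcmpl-<request ID> value, matching the non-streaming path in simulator.go.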
4 changes: 4 additions & 0 deletions pkg/llm-d-inference-sim/test_utils.go
@@ -541,3 +541,7 @@ func checkSimSleeping(client *http.Client, expectedToSleep bool) {
expect := fmt.Sprintf("{\"is_sleeping\":%t}", expectedToSleep)
gomega.Expect(string(body)).To(gomega.Equal(expect))
}

func ptr[T any](v T) *T {
return &v
}
15 changes: 12 additions & 3 deletions pkg/openai-server-api/response.go
@@ -26,7 +26,9 @@ import (
)

// CompletionResponse interface representing both completion response types (text and chat)
type CompletionResponse interface{}
type CompletionResponse interface {
GetRequestID() string
}

// baseCompletionResponse contains base completion response related information
type baseCompletionResponse struct {
@@ -42,6 +44,8 @@ type baseCompletionResponse struct {
Object string `json:"object"`
// KVParams kv transfer related fields
KVParams *KVTransferParams `json:"kv_transfer_params"`
// RequestID is the unique request ID used for tracking; it is not serialized into the JSON body
RequestID string `json:"-"`
}

// Usage contains token Usage statistics
@@ -303,8 +307,13 @@ func CreateTextRespChoice(base baseResponseChoice, text string) TextRespChoice {
return TextRespChoice{baseResponseChoice: base, Text: text, Logprobs: nil}
}

func CreateBaseCompletionResponse(id string, created int64, model string, usage *Usage) baseCompletionResponse {
return baseCompletionResponse{ID: id, Created: created, Model: model, Usage: usage}
func CreateBaseCompletionResponse(id string, created int64, model string, usage *Usage, requestID string) baseCompletionResponse {
return baseCompletionResponse{ID: id, Created: created, Model: model, Usage: usage, RequestID: requestID}
}

// GetRequestID returns the request ID from the response
func (b baseCompletionResponse) GetRequestID() string {
return b.RequestID
}

func CreateChatCompletionResponse(base baseCompletionResponse, choices []ChatRespChoice) *ChatCompletionResponse {
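A small self-contained sketch of why RequestID carries the json:"-" tag: the value travels with the response object so the server can emit the X-Request-Id header, but it never appears in the serialized body, where the client-visible identifier stays in the id field. The trimmed-down struct below is illustrative, not the package's real type; the sample values mirror the "uses request ID in response body ID field" test entry above.

package main

import (
	"encoding/json"
	"fmt"
)

type response struct {
	ID        string `json:"id"`
	RequestID string `json:"-"` // kept for the response header, excluded from the body
}

func (r response) GetRequestID() string { return r.RequestID }

func main() {
	r := response{ID: "chatcmpl-body-test-999", RequestID: "body-test-999"}

	body, _ := json.Marshal(r)
	fmt.Println(string(body))     // {"id":"chatcmpl-body-test-999"} — no request ID field
	fmt.Println(r.GetRequestID()) // body-test-999, available for the X-Request-Id header
}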