diff --git a/go.mod b/go.mod
index 45f4866..1bd79a2 100644
--- a/go.mod
+++ b/go.mod
@@ -21,12 +21,7 @@ require (
 	golang.org/x/sync v0.12.0
 	gopkg.in/yaml.v3 v3.0.1
 	k8s.io/klog/v2 v2.130.1
-)
-
-require (
-	github.com/dgraph-io/ristretto/v2 v2.3.0 // indirect
-	github.com/dustin/go-humanize v1.0.1 // indirect
-	go.uber.org/multierr v1.11.0 // indirect
+	sigs.k8s.io/controller-runtime v0.21.0
 )
 
 require (
@@ -35,7 +30,9 @@ require (
 	github.com/cespare/xxhash/v2 v2.3.0 // indirect
 	github.com/daulet/tokenizers v1.22.1 // indirect
 	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
+	github.com/dgraph-io/ristretto/v2 v2.3.0 // indirect
 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
+	github.com/dustin/go-humanize v1.0.1 // indirect
 	github.com/emicklei/go-restful/v3 v3.11.0 // indirect
 	github.com/fxamacker/cbor/v2 v2.7.0 // indirect
 	github.com/go-openapi/jsonpointer v0.21.0 // indirect
@@ -68,6 +65,7 @@ require (
 	github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
 	github.com/x448/float16 v0.8.4 // indirect
 	go.uber.org/automaxprocs v1.6.0 // indirect
+	go.uber.org/multierr v1.11.0 // indirect
 	golang.org/x/net v0.38.0 // indirect
 	golang.org/x/oauth2 v0.27.0 // indirect
 	golang.org/x/sys v0.35.0 // indirect
@@ -83,7 +81,6 @@ require (
 	k8s.io/client-go v0.33.0 // indirect
 	k8s.io/kube-openapi v0.0.0-20250318190949-c8a335a9a2ff // indirect
 	k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 // indirect
-	sigs.k8s.io/controller-runtime v0.21.0
 	sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 // indirect
 	sigs.k8s.io/randfill v1.0.0 // indirect
 	sigs.k8s.io/structured-merge-diff/v4 v4.6.0 // indirect
diff --git a/pkg/common/config.go b/pkg/common/config.go
index bc51be9..35ebfcb 100644
--- a/pkg/common/config.go
+++ b/pkg/common/config.go
@@ -223,6 +223,9 @@ type Configuration struct {
 
 	// EnableSleepMode enables sleep mode
 	EnableSleepMode bool `yaml:"enable-sleep-mode" json:"enable-sleep-mode"`
+
+	// EnableRequestIDHeaders enables including X-Request-Id header in responses
+	EnableRequestIDHeaders bool `yaml:"enable-request-id-headers" json:"enable-request-id-headers"`
 }
 
 type Metrics struct {
@@ -749,6 +752,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
 	f.BoolVar(&config.DatasetInMemory, "dataset-in-memory", config.DatasetInMemory, "Load the entire dataset into memory for faster access")
 
 	f.BoolVar(&config.EnableSleepMode, "enable-sleep-mode", config.EnableSleepMode, "Enable sleep mode")
+	f.BoolVar(&config.EnableRequestIDHeaders, "enable-request-id-headers", config.EnableRequestIDHeaders, "Enable including X-Request-Id header in responses")
 
 	f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures")
 	failureTypes := getParamValueFromArgs("failure-types")
diff --git a/pkg/llm-d-inference-sim/server.go b/pkg/llm-d-inference-sim/server.go
index 0faf6a3..e353a8a 100644
--- a/pkg/llm-d-inference-sim/server.go
+++ b/pkg/llm-d-inference-sim/server.go
@@ -109,9 +109,20 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener)
 	}
 }
 
+// getRequestID retrieves the request ID from the X-Request-Id header or generates a new one if not present
+func (s *VllmSimulator) getRequestID(ctx *fasthttp.RequestCtx) string {
+	if s.config.EnableRequestIDHeaders {
+		requestID := string(ctx.Request.Header.Peek(requestIDHeader))
+		if requestID != "" {
+			return requestID
+		}
+	}
+	return s.random.GenerateUUIDString()
+}
+
 // readRequest reads and parses data from the body of the given request according the type defined by isChatCompletion
 func (s *VllmSimulator) readRequest(ctx *fasthttp.RequestCtx, isChatCompletion bool) (openaiserverapi.CompletionRequest, error) {
-	requestID := s.random.GenerateUUIDString()
+	requestID := s.getRequestID(ctx)
 
 	if isChatCompletion {
 		var req openaiserverapi.ChatCompletionRequest
@@ -266,6 +277,11 @@ func (s *VllmSimulator) sendCompletionResponse(ctx *fasthttp.RequestCtx, resp op
 	if s.namespace != "" {
 		ctx.Response.Header.Add(namespaceHeader, s.namespace)
 	}
+	if s.config.EnableRequestIDHeaders {
+		if requestID := resp.GetRequestID(); requestID != "" {
+			ctx.Response.Header.Add(requestIDHeader, requestID)
+		}
+	}
 
 	ctx.Response.SetBody(data)
 }
diff --git a/pkg/llm-d-inference-sim/server_test.go b/pkg/llm-d-inference-sim/server_test.go
index 631b1ce..2e2cab9 100644
--- a/pkg/llm-d-inference-sim/server_test.go
+++ b/pkg/llm-d-inference-sim/server_test.go
@@ -25,11 +25,13 @@ import (
 	"strings"
 	"time"
 
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+	"github.com/valyala/fasthttp"
+
 	"github.com/llm-d/llm-d-inference-sim/pkg/common"
 	kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
 	vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api"
-	. "github.com/onsi/ginkgo/v2"
-	. "github.com/onsi/gomega"
 )
 
 const tmpDir = "./tests-tmp/"
@@ -212,6 +214,111 @@ var _ = Describe("Server", func() {
 
 	})
 
+	Context("request ID headers", func() {
+		testRequestIDHeader := func(enableRequestID bool, endpoint, reqBody, inputRequestID string, expectRequestID *string, validateBody func([]byte)) {
+			ctx := context.TODO()
+			args := []string{"cmd", "--model", testModel, "--mode", common.ModeEcho}
+			if enableRequestID {
+				args = append(args, "--enable-request-id-headers")
+			}
+			client, err := startServerWithArgs(ctx, args)
+			Expect(err).NotTo(HaveOccurred())
+
+			req, err := http.NewRequest("POST", "http://localhost"+endpoint, strings.NewReader(reqBody))
+			Expect(err).NotTo(HaveOccurred())
+			req.Header.Set(fasthttp.HeaderContentType, "application/json")
+			if inputRequestID != "" {
+				req.Header.Set(requestIDHeader, inputRequestID)
+			}
+
+			resp, err := client.Do(req)
+			Expect(err).NotTo(HaveOccurred())
+			defer func() {
+				err := resp.Body.Close()
+				Expect(err).NotTo(HaveOccurred())
+			}()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+			if expectRequestID != nil {
+				actualRequestID := resp.Header.Get(requestIDHeader)
+				if *expectRequestID != "" {
+					// When a request ID is provided, it should be echoed back
+					Expect(actualRequestID).To(Equal(*expectRequestID))
+				} else {
+					// When no request ID is provided, a UUID should be generated
+					Expect(actualRequestID).NotTo(BeEmpty())
+					Expect(len(actualRequestID)).To(BeNumerically(">", 30))
+				}
+			} else {
+				// When request ID headers are disabled, the header should be empty
+				Expect(resp.Header.Get(requestIDHeader)).To(BeEmpty())
+			}
+
+			if validateBody != nil {
+				body, err := io.ReadAll(resp.Body)
+				Expect(err).NotTo(HaveOccurred())
+				validateBody(body)
+			}
+		}
+
+		DescribeTable("request ID behavior",
+			testRequestIDHeader,
+			Entry("includes X-Request-Id when enabled",
+				true,
+				"/v1/chat/completions",
+				`{"messages": [{"role": "user", "content": "Hello"}], "model": "`+testModel+`", "max_tokens": 5}`,
+				"test-request-id-123",
+				ptr("test-request-id-123"),
+				nil,
+			),
+			Entry("excludes X-Request-Id when disabled",
+				false,
+				"/v1/chat/completions",
+				`{"messages": [{"role": "user", "content": "Hello"}], "model": "`+testModel+`", "max_tokens": 5}`,
+				"test-request-id-456",
+				nil,
+				nil,
+			),
+			Entry("includes X-Request-Id in streaming response",
+				true,
+				"/v1/chat/completions",
+				`{"messages": [{"role": "user", "content": "Hello"}], "model": "`+testModel+`", "max_tokens": 5, "stream": true}`,
+				"test-streaming-789",
+				ptr("test-streaming-789"),
+				nil,
+			),
+			Entry("works with text completions endpoint",
+				true,
+				"/v1/completions",
+				`{"prompt": "Hello world", "model": "`+testModel+`", "max_tokens": 5}`,
+				"text-request-111",
+				ptr("text-request-111"),
+				nil,
+			),
+			Entry("generates UUID when no request ID provided",
+				true,
+				"/v1/chat/completions",
+				`{"messages": [{"role": "user", "content": "Hello"}], "model": "`+testModel+`", "max_tokens": 5}`,
+				"",
+				ptr(""),
+				nil,
+			),
+			Entry("uses request ID in response body ID field",
+				true,
+				"/v1/chat/completions",
+				`{"messages": [{"role": "user", "content": "Hello"}], "model": "`+testModel+`", "max_tokens": 5}`,
+				"body-test-999",
+				ptr("body-test-999"),
+				func(body []byte) {
+					var resp map[string]any
+					Expect(json.Unmarshal(body, &resp)).To(Succeed())
+					Expect(resp["id"]).To(Equal("chatcmpl-body-test-999"))
+				},
+			),
+		)
+	})
+
 	Context("sleep mode", Ordered, func() {
 		AfterAll(func() {
 			err := os.RemoveAll(tmpDir)
diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go
index 2c08e5d..1d55356 100644
--- a/pkg/llm-d-inference-sim/simulator.go
+++ b/pkg/llm-d-inference-sim/simulator.go
@@ -51,6 +51,7 @@ const (
 	podHeader       = "x-inference-pod"
 	portHeader      = "x-inference-port"
 	namespaceHeader = "x-inference-namespace"
+	requestIDHeader = "X-Request-Id"
 	podNameEnv      = "POD_NAME"
 	podNsEnv        = "POD_NAMESPACE"
 )
@@ -581,9 +582,9 @@ func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool
 // modelName - display name returned to the client and used in metrics. It is either the first alias
 // from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request).
 func (s *VllmSimulator) createCompletionResponse(logprobs *int, isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall,
-	finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool) openaiserverapi.CompletionResponse {
-	baseResp := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
-		time.Now().Unix(), modelName, usageData)
+	finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool, requestID string) openaiserverapi.CompletionResponse {
+	baseResp := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+requestID,
+		time.Now().Unix(), modelName, usageData, requestID)
 
 	if doRemoteDecode {
 		baseResp.KVParams = &openaiserverapi.KVTransferParams{}
@@ -663,9 +664,10 @@ func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, r
 	if toolCalls == nil {
 		logprobs = reqCtx.CompletionReq.GetLogprobs()
 	}
+	requestID := reqCtx.CompletionReq.GetRequestID()
 
 	resp := s.createCompletionResponse(logprobs, reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName,
-		reqCtx.CompletionReq.IsDoRemoteDecode())
+		reqCtx.CompletionReq.IsDoRemoteDecode(), requestID)
 
 	// calculate how long to wait before returning the response, time is based on number of tokens
 	nCachedPromptTokens := reqCtx.CompletionReq.GetNumberOfCachedPromptTokens()
diff --git a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go
index 7c96fa9..dce8b74 100644
--- a/pkg/llm-d-inference-sim/streaming.go
+++ b/pkg/llm-d-inference-sim/streaming.go
@@ -60,6 +60,9 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons
 	if s.namespace != "" {
 		context.ctx.Response.Header.Add(namespaceHeader, s.namespace)
 	}
+	if s.config.EnableRequestIDHeaders {
+		context.ctx.Response.Header.Add(requestIDHeader, context.requestID)
+	}
 
 	context.ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
 		context.creationTime = time.Now().Unix()
@@ -176,8 +179,8 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ
 // createUsageChunk creates and returns a CompletionRespChunk with usage data, a single chunk of streamed completion API response,
 // supports both modes (text and chat)
 func (s *VllmSimulator) createUsageChunk(context *streamingContext, usageData *openaiserverapi.Usage) openaiserverapi.CompletionRespChunk {
-	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
-		context.creationTime, context.model, usageData)
+	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+context.requestID,
+		context.creationTime, context.model, usageData, context.requestID)
 
 	if context.isChatCompletion {
 		baseChunk.Object = chatCompletionChunkObject
@@ -191,8 +194,8 @@ func (s *VllmSimulator) createUsageChunk(context *streamingContext, usageData *o
 // createTextCompletionChunk creates and returns a CompletionRespChunk, a single chunk of streamed completion API response,
 // for text completion.
 func (s *VllmSimulator) createTextCompletionChunk(context *streamingContext, token string, finishReason *string) openaiserverapi.CompletionRespChunk {
-	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
-		context.creationTime, context.model, nil)
+	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+context.requestID,
+		context.creationTime, context.model, nil, context.requestID)
 	baseChunk.Object = textCompletionObject
 
 	choice := openaiserverapi.CreateTextRespChoice(openaiserverapi.CreateBaseResponseChoice(0, finishReason), token)
@@ -214,8 +217,8 @@ func (s *VllmSimulator) createTextCompletionChunk(context *streamingContext, tok
 // API response, for chat completion. It sets either role, or token, or tool call info in the message.
 func (s *VllmSimulator) createChatCompletionChunk(context *streamingContext, token string, tool *openaiserverapi.ToolCall,
 	role string, finishReason *string) openaiserverapi.CompletionRespChunk {
-	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
-		context.creationTime, context.model, nil)
+	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+context.requestID,
+		context.creationTime, context.model, nil, context.requestID)
 	baseChunk.Object = chatCompletionChunkObject
 
 	chunk := openaiserverapi.CreateChatCompletionRespChunk(baseChunk, []openaiserverapi.ChatRespChunkChoice{
diff --git a/pkg/llm-d-inference-sim/test_utils.go b/pkg/llm-d-inference-sim/test_utils.go
index 0810c04..025f8ca 100644
--- a/pkg/llm-d-inference-sim/test_utils.go
+++ b/pkg/llm-d-inference-sim/test_utils.go
@@ -541,3 +541,7 @@ func checkSimSleeping(client *http.Client, expectedToSleep bool) {
 	expect := fmt.Sprintf("{\"is_sleeping\":%t}", expectedToSleep)
 	gomega.Expect(string(body)).To(gomega.Equal(expect))
 }
+
+func ptr[T any](v T) *T {
+	return &v
+}
diff --git a/pkg/openai-server-api/response.go b/pkg/openai-server-api/response.go
index b05ba20..9b466fa 100644
--- a/pkg/openai-server-api/response.go
+++ b/pkg/openai-server-api/response.go
@@ -26,7 +26,9 @@ import (
 )
 
 // CompletionResponse interface representing both completion response types (text and chat)
-type CompletionResponse interface{}
+type CompletionResponse interface {
+	GetRequestID() string
+}
 
 // baseCompletionResponse contains base completion response related information
 type baseCompletionResponse struct {
@@ -42,6 +44,8 @@ type baseCompletionResponse struct {
 	Object string `json:"object"`
 	// KVParams kv transfer related fields
 	KVParams *KVTransferParams `json:"kv_transfer_params"`
+	// RequestID is the unique request ID for tracking
+	RequestID string `json:"-"`
 }
 
 // Usage contains token Usage statistics
@@ -303,8 +307,13 @@ func CreateTextRespChoice(base baseResponseChoice, text string) TextRespChoice {
 	return TextRespChoice{baseResponseChoice: base, Text: text, Logprobs: nil}
 }
 
-func CreateBaseCompletionResponse(id string, created int64, model string, usage *Usage) baseCompletionResponse {
-	return baseCompletionResponse{ID: id, Created: created, Model: model, Usage: usage}
+func CreateBaseCompletionResponse(id string, created int64, model string, usage *Usage, requestID string) baseCompletionResponse {
+	return baseCompletionResponse{ID: id, Created: created, Model: model, Usage: usage, RequestID: requestID}
+}
+
+// GetRequestID returns the request ID from the response
+func (b baseCompletionResponse) GetRequestID() string {
+	return b.RequestID
 }
 
 func CreateChatCompletionResponse(base baseCompletionResponse, choices []ChatRespChoice) *ChatCompletionResponse {