Commit 3823ef9

rudeigerc committed
feat: add support for X-Request-Id header in responses and logs
Signed-off-by: rudeigerc <[email protected]>
1 parent: cecbf25
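
In practical terms: when the simulator is started with --enable-request-id-headers, a caller-supplied X-Request-Id is used as the request ID and echoed back on the response; without the flag the simulator keeps generating its own UUIDs and omits the header. A minimal client-side sketch, assuming a simulator listening on localhost:8000 with a model named "my-model" (host, port, and model name are illustrative, not part of this commit):

    package main

    import (
        "bytes"
        "fmt"
        "net/http"
    )

    func main() {
        body := bytes.NewBufferString(`{"model": "my-model", "messages": [{"role": "user", "content": "Hello"}]}`)
        req, err := http.NewRequest(http.MethodPost, "http://localhost:8000/v1/chat/completions", body)
        if err != nil {
            panic(err)
        }
        req.Header.Set("Content-Type", "application/json")
        // Caller-supplied request ID; echoed back when --enable-request-id-headers is set.
        req.Header.Set("X-Request-Id", "my-request-id-123")

        resp, err := http.DefaultClient.Do(req)
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()

        // Expected: "my-request-id-123" when the flag is enabled, empty otherwise.
        fmt.Println(resp.Header.Get("X-Request-Id"))
    }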

File tree

8 files changed (+159 lines, -17 lines)

go.mod
pkg/common/config.go
pkg/llm-d-inference-sim/server.go
pkg/llm-d-inference-sim/server_test.go
pkg/llm-d-inference-sim/simulator.go
pkg/llm-d-inference-sim/streaming.go
pkg/llm-d-inference-sim/test_utils.go
pkg/openai-server-api/response.go

go.mod

Lines changed: 1 addition & 1 deletion
@@ -21,6 +21,7 @@ require (
     golang.org/x/sync v0.12.0
     gopkg.in/yaml.v3 v3.0.1
     k8s.io/klog/v2 v2.130.1
+    sigs.k8s.io/controller-runtime v0.21.0
 )

 require (
@@ -77,7 +78,6 @@ require (
     k8s.io/client-go v0.33.0 // indirect
     k8s.io/kube-openapi v0.0.0-20250318190949-c8a335a9a2ff // indirect
     k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 // indirect
-    sigs.k8s.io/controller-runtime v0.21.0 // indirect
     sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 // indirect
     sigs.k8s.io/randfill v1.0.0 // indirect
     sigs.k8s.io/structured-merge-diff/v4 v4.6.0 // indirect

pkg/common/config.go

Lines changed: 4 additions & 0 deletions
@@ -223,6 +223,9 @@ type Configuration struct {

     // EnableSleepMode enables sleep mode
     EnableSleepMode bool `yaml:"enable-sleep-mode" json:"enable-sleep-mode"`
+
+    // EnableRequestIDHeaders enables including X-Request-Id header in responses
+    EnableRequestIDHeaders bool `yaml:"enable-request-id-headers" json:"enable-request-id-headers"`
 }

 type Metrics struct {
@@ -749,6 +752,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
     f.BoolVar(&config.DatasetInMemory, "dataset-in-memory", config.DatasetInMemory, "Load the entire dataset into memory for faster access")

     f.BoolVar(&config.EnableSleepMode, "enable-sleep-mode", config.EnableSleepMode, "Enable sleep mode")
+    f.BoolVar(&config.EnableRequestIDHeaders, "enable-request-id-headers", config.EnableRequestIDHeaders, "Enable including X-Request-Id header in responses")

     f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures")
     failureTypes := getParamValueFromArgs("failure-types")
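
Based on the yaml tag and flag registration above, the new option can be switched on either from the command line or in the YAML configuration file; the binary name below is illustrative:

    ./llm-d-inference-sim --model my-model --enable-request-id-headers

    # equivalently, in the YAML config file:
    enable-request-id-headers: true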

pkg/llm-d-inference-sim/server.go

Lines changed: 17 additions & 1 deletion
@@ -109,9 +109,20 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener)
     }
 }

+// getRequestID retrieves the request ID from the X-Request-Id header or generates a new one if not present
+func (s *VllmSimulator) getRequestID(ctx *fasthttp.RequestCtx) string {
+    if s.config.EnableRequestIDHeaders {
+        requestID := string(ctx.Request.Header.Peek(requestIDHeader))
+        if requestID != "" {
+            return requestID
+        }
+    }
+    return s.random.GenerateUUIDString()
+}
+
 // readRequest reads and parses data from the body of the given request according the type defined by isChatCompletion
 func (s *VllmSimulator) readRequest(ctx *fasthttp.RequestCtx, isChatCompletion bool) (openaiserverapi.CompletionRequest, error) {
-    requestID := s.random.GenerateUUIDString()
+    requestID := s.getRequestID(ctx)

     if isChatCompletion {
         var req openaiserverapi.ChatCompletionRequest
@@ -266,6 +277,11 @@ func (s *VllmSimulator) sendCompletionResponse(ctx *fasthttp.RequestCtx, resp op
     if s.namespace != "" {
         ctx.Response.Header.Add(namespaceHeader, s.namespace)
     }
+    if s.config.EnableRequestIDHeaders {
+        if requestID := resp.GetRequestID(); requestID != "" {
+            ctx.Response.Header.Add(requestIDHeader, requestID)
+        }
+    }
     ctx.Response.SetBody(data)
 }
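
For illustration, a standalone sketch of the precedence getRequestID implements: the caller-supplied header wins only when the feature is enabled, otherwise a fresh UUID is generated. requestIDFrom and the google/uuid fallback are stand-ins here; the simulator itself uses s.random.GenerateUUIDString():

    package main

    import (
        "fmt"

        "github.com/google/uuid"
        "github.com/valyala/fasthttp"
    )

    // requestIDFrom prefers the caller-supplied X-Request-Id header when the
    // feature is enabled and otherwise falls back to a freshly generated UUID.
    func requestIDFrom(ctx *fasthttp.RequestCtx, headersEnabled bool) string {
        if headersEnabled {
            if id := string(ctx.Request.Header.Peek("X-Request-Id")); id != "" {
                return id
            }
        }
        return uuid.NewString()
    }

    func main() {
        var ctx fasthttp.RequestCtx
        ctx.Request.Header.Set("X-Request-Id", "abc-123")
        fmt.Println(requestIDFrom(&ctx, true))  // prints "abc-123"
        fmt.Println(requestIDFrom(&ctx, false)) // prints a random UUID
    }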

pkg/llm-d-inference-sim/server_test.go

Lines changed: 106 additions & 2 deletions
@@ -25,11 +25,13 @@ import (
     "strings"
     "time"

+    . "github.com/onsi/ginkgo/v2"
+    . "github.com/onsi/gomega"
+    "github.com/valyala/fasthttp"
+
     "github.com/llm-d/llm-d-inference-sim/pkg/common"
     kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
     vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api"
-    . "github.com/onsi/ginkgo/v2"
-    . "github.com/onsi/gomega"
 )

 const tmpDir = "./tests-tmp/"
@@ -212,6 +214,108 @@ var _ = Describe("Server", func() {

     })

+    Context("request ID headers", func() {
+        testRequestIDHeader := func(enableRequestID bool, endpoint, reqBody, inputRequestID string, expectRequestID *string, validateBody func([]byte)) {
+            ctx := context.TODO()
+            args := []string{"cmd", "--model", testModel, "--mode", common.ModeEcho}
+            if enableRequestID {
+                args = append(args, "--enable-request-id-headers")
+            }
+            client, err := startServerWithArgs(ctx, args)
+            Expect(err).NotTo(HaveOccurred())
+
+            req, err := http.NewRequest("POST", "http://localhost"+endpoint, strings.NewReader(reqBody))
+            Expect(err).NotTo(HaveOccurred())
+            req.Header.Set(fasthttp.HeaderContentType, "application/json")
+            if inputRequestID != "" {
+                req.Header.Set(requestIDHeader, inputRequestID)
+            }
+
+            resp, err := client.Do(req)
+            Expect(err).NotTo(HaveOccurred())
+            defer resp.Body.Close()
+
+            Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+            if expectRequestID != nil {
+                actualRequestID := resp.Header.Get(requestIDHeader)
+                if *expectRequestID != "" {
+                    // When a request ID is provided, it should be echoed back
+                    Expect(actualRequestID).To(Equal(*expectRequestID))
+                } else {
+                    // When no request ID is provided, a UUID should be generated
+                    Expect(actualRequestID).NotTo(BeEmpty())
+                    Expect(len(actualRequestID)).To(BeNumerically(">", 30))
+                }
+            } else {
+                // When request ID headers are disabled, the header should be empty
+                Expect(resp.Header.Get(requestIDHeader)).To(BeEmpty())
+            }
+
+            if validateBody != nil {
+                body, err := io.ReadAll(resp.Body)
+                Expect(err).NotTo(HaveOccurred())
+                validateBody(body)
+            }
+        }
+
+        DescribeTable("request ID behavior",
+            testRequestIDHeader,
+            Entry("includes X-Request-Id when enabled",
+                true,
+                "/v1/chat/completions",
+                `{"messages": [{"role": "user", "content": "Hello"}], "model": "`+testModel+`", "max_tokens": 5}`,
+                "test-request-id-123",
+                ptr("test-request-id-123"),
+                nil,
+            ),
+            Entry("excludes X-Request-Id when disabled",
+                false,
+                "/v1/chat/completions",
+                `{"messages": [{"role": "user", "content": "Hello"}], "model": "`+testModel+`", "max_tokens": 5}`,
+                "test-request-id-456",
+                nil,
+                nil,
+            ),
+            Entry("includes X-Request-Id in streaming response",
+                true,
+                "/v1/chat/completions",
+                `{"messages": [{"role": "user", "content": "Hello"}], "model": "`+testModel+`", "max_tokens": 5, "stream": true}`,
+                "test-streaming-789",
+                ptr("test-streaming-789"),
+                nil,
+            ),
+            Entry("works with text completions endpoint",
+                true,
+                "/v1/completions",
+                `{"prompt": "Hello world", "model": "`+testModel+`", "max_tokens": 5}`,
+                "text-request-111",
+                ptr("text-request-111"),
+                nil,
+            ),
+            Entry("generates UUID when no request ID provided",
+                true,
+                "/v1/chat/completions",
+                `{"messages": [{"role": "user", "content": "Hello"}], "model": "`+testModel+`", "max_tokens": 5}`,
+                "",
+                ptr(""),
+                nil,
+            ),
+            Entry("uses request ID in response body ID field",
+                true,
+                "/v1/chat/completions",
+                `{"messages": [{"role": "user", "content": "Hello"}], "model": "`+testModel+`", "max_tokens": 5}`,
+                "body-test-999",
+                ptr("body-test-999"),
+                func(body []byte) {
+                    var resp map[string]any
+                    Expect(json.Unmarshal(body, &resp)).To(Succeed())
+                    Expect(resp["id"]).To(Equal("chatcmpl-body-test-999"))
+                },
+            ),
+        )
+    })
+
     Context("sleep mode", Ordered, func() {
         AfterAll(func() {
             err := os.RemoveAll(tmpDir)

pkg/llm-d-inference-sim/simulator.go

Lines changed: 6 additions & 4 deletions
@@ -51,6 +51,7 @@ const (
     podHeader = "x-inference-pod"
     portHeader = "x-inference-port"
     namespaceHeader = "x-inference-namespace"
+    requestIDHeader = "X-Request-Id"
     podNameEnv = "POD_NAME"
     podNsEnv = "POD_NAMESPACE"
 )
@@ -573,9 +574,9 @@ func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool
 // modelName - display name returned to the client and used in metrics. It is either the first alias
 // from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request).
 func (s *VllmSimulator) createCompletionResponse(logprobs *int, isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall,
-    finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool) openaiserverapi.CompletionResponse {
-    baseResp := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
-        time.Now().Unix(), modelName, usageData)
+    finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool, requestID string) openaiserverapi.CompletionResponse {
+    baseResp := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+requestID,
+        time.Now().Unix(), modelName, usageData, requestID)

     if doRemoteDecode {
         baseResp.KVParams = &openaiserverapi.KVTransferParams{}
@@ -655,9 +656,10 @@ func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, r
     if toolCalls == nil {
         logprobs = reqCtx.CompletionReq.GetLogprobs()
     }
+    requestID := reqCtx.CompletionReq.GetRequestID()

     resp := s.createCompletionResponse(logprobs, reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName,
-        reqCtx.CompletionReq.IsDoRemoteDecode())
+        reqCtx.CompletionReq.IsDoRemoteDecode(), requestID)

     // calculate how long to wait before returning the response, time is based on number of tokens
     nCachedPromptTokens := reqCtx.CompletionReq.GetNumberOfCachedPromptTokens()

pkg/llm-d-inference-sim/streaming.go

Lines changed: 9 additions & 6 deletions
@@ -60,6 +60,9 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons
     if s.namespace != "" {
         context.ctx.Response.Header.Add(namespaceHeader, s.namespace)
     }
+    if s.config.EnableRequestIDHeaders {
+        context.ctx.Response.Header.Add(requestIDHeader, context.requestID)
+    }

     context.ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
         context.creationTime = time.Now().Unix()
@@ -176,8 +179,8 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ
 // createUsageChunk creates and returns a CompletionRespChunk with usage data, a single chunk of streamed completion API response,
 // supports both modes (text and chat)
 func (s *VllmSimulator) createUsageChunk(context *streamingContext, usageData *openaiserverapi.Usage) openaiserverapi.CompletionRespChunk {
-    baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
-        context.creationTime, context.model, usageData)
+    baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+context.requestID,
+        context.creationTime, context.model, usageData, context.requestID)

     if context.isChatCompletion {
         baseChunk.Object = chatCompletionChunkObject
@@ -191,8 +194,8 @@ func (s *VllmSimulator) createUsageChunk(context *streamingContext, usageData *o
 // createTextCompletionChunk creates and returns a CompletionRespChunk, a single chunk of streamed completion API response,
 // for text completion.
 func (s *VllmSimulator) createTextCompletionChunk(context *streamingContext, token string, finishReason *string) openaiserverapi.CompletionRespChunk {
-    baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
-        context.creationTime, context.model, nil)
+    baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+context.requestID,
+        context.creationTime, context.model, nil, context.requestID)
     baseChunk.Object = textCompletionObject

     choice := openaiserverapi.CreateTextRespChoice(openaiserverapi.CreateBaseResponseChoice(0, finishReason), token)
@@ -214,8 +217,8 @@ func (s *VllmSimulator) createTextCompletionChunk(context *streamingContext, tok
 // API response, for chat completion. It sets either role, or token, or tool call info in the message.
 func (s *VllmSimulator) createChatCompletionChunk(context *streamingContext, token string, tool *openaiserverapi.ToolCall,
     role string, finishReason *string) openaiserverapi.CompletionRespChunk {
-    baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
-        context.creationTime, context.model, nil)
+    baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+context.requestID,
+        context.creationTime, context.model, nil, context.requestID)
     baseChunk.Object = chatCompletionChunkObject
     chunk := openaiserverapi.CreateChatCompletionRespChunk(baseChunk,
         []openaiserverapi.ChatRespChunkChoice{

pkg/llm-d-inference-sim/test_utils.go

Lines changed: 4 additions & 0 deletions
@@ -541,3 +541,7 @@ func checkSimSleeping(client *http.Client, expectedToSleep bool) {
     expect := fmt.Sprintf("{\"is_sleeping\":%t}", expectedToSleep)
     gomega.Expect(string(body)).To(gomega.Equal(expect))
 }
+
+func ptr[T any](v T) *T {
+    return &v
+}

pkg/openai-server-api/response.go

Lines changed: 12 additions & 3 deletions
@@ -26,7 +26,9 @@ import (
 )

 // CompletionResponse interface representing both completion response types (text and chat)
-type CompletionResponse interface{}
+type CompletionResponse interface {
+    GetRequestID() string
+}

 // baseCompletionResponse contains base completion response related information
 type baseCompletionResponse struct {
@@ -42,6 +44,8 @@ type baseCompletionResponse struct {
     Object string `json:"object"`
     // KVParams kv transfer related fields
     KVParams *KVTransferParams `json:"kv_transfer_params"`
+    // RequestID is the unique request ID for tracking
+    RequestID string `json:"-"`
 }

 // Usage contains token Usage statistics
@@ -303,8 +307,13 @@ func CreateTextRespChoice(base baseResponseChoice, text string) TextRespChoice {
     return TextRespChoice{baseResponseChoice: base, Text: text, Logprobs: nil}
 }

-func CreateBaseCompletionResponse(id string, created int64, model string, usage *Usage) baseCompletionResponse {
-    return baseCompletionResponse{ID: id, Created: created, Model: model, Usage: usage}
+func CreateBaseCompletionResponse(id string, created int64, model string, usage *Usage, requestID string) baseCompletionResponse {
+    return baseCompletionResponse{ID: id, Created: created, Model: model, Usage: usage, RequestID: requestID}
+}
+
+// GetRequestID returns the request ID from the response
+func (b baseCompletionResponse) GetRequestID() string {
+    return b.RequestID
 }

 func CreateChatCompletionResponse(base baseCompletionResponse, choices []ChatRespChoice) *ChatCompletionResponse {
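
For illustration, a minimal sketch of using the extended constructor and the new interface method from outside the package. The import path is assumed from the repository layout, and the "my-model" / "my-request-id-123" values are illustrative; "chatcmpl-" mirrors the chatComplIDPrefix used in the simulator:

    package main

    import (
        "fmt"
        "time"

        openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
    )

    func main() {
        requestID := "my-request-id-123"

        // The request ID is now threaded into the base response alongside the
        // "chatcmpl-" prefixed ID that appears in the response body.
        base := openaiserverapi.CreateBaseCompletionResponse(
            "chatcmpl-"+requestID, time.Now().Unix(), "my-model", nil, requestID)

        // Every CompletionResponse exposes the ID later written to X-Request-Id.
        var resp openaiserverapi.CompletionResponse = base
        fmt.Println(resp.GetRequestID()) // prints "my-request-id-123"
    }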

0 commit comments
