Commit e0e36e2

feat: add support for X-Request-Id header in responses and logs
Signed-off-by: rudeigerc <[email protected]>
1 parent: cecbf25

6 files changed: +221 -10 lines changed


go.mod

Lines changed: 1 addition & 1 deletion
@@ -21,6 +21,7 @@ require (
 	golang.org/x/sync v0.12.0
 	gopkg.in/yaml.v3 v3.0.1
 	k8s.io/klog/v2 v2.130.1
+	sigs.k8s.io/controller-runtime v0.21.0
 )
 
 require (
@@ -77,7 +78,6 @@ require (
 	k8s.io/client-go v0.33.0 // indirect
 	k8s.io/kube-openapi v0.0.0-20250318190949-c8a335a9a2ff // indirect
 	k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 // indirect
-	sigs.k8s.io/controller-runtime v0.21.0 // indirect
 	sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 // indirect
 	sigs.k8s.io/randfill v1.0.0 // indirect
 	sigs.k8s.io/structured-merge-diff/v4 v4.6.0 // indirect

pkg/common/config.go

Lines changed: 4 additions & 0 deletions
@@ -223,6 +223,9 @@ type Configuration struct {
 
 	// EnableSleepMode enables sleep mode
 	EnableSleepMode bool `yaml:"enable-sleep-mode" json:"enable-sleep-mode"`
+
+	// EnableRequestIDHeaders enables including X-Request-Id header in responses
+	EnableRequestIDHeaders bool `yaml:"enable-request-id-headers" json:"enable-request-id-headers"`
 }
 
 type Metrics struct {
@@ -749,6 +752,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
 	f.BoolVar(&config.DatasetInMemory, "dataset-in-memory", config.DatasetInMemory, "Load the entire dataset into memory for faster access")
 
 	f.BoolVar(&config.EnableSleepMode, "enable-sleep-mode", config.EnableSleepMode, "Enable sleep mode")
+	f.BoolVar(&config.EnableRequestIDHeaders, "enable-request-id-headers", config.EnableRequestIDHeaders, "Enable including X-Request-Id header in responses")
 
 	f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures")
 	failureTypes := getParamValueFromArgs("failure-types")
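
Note: besides the new --enable-request-id-headers flag, the yaml/json tags on EnableRequestIDHeaders suggest the option can also be set through the simulator's configuration file. A minimal, hypothetical sketch of how such a key unmarshals with gopkg.in/yaml.v3 (already a dependency in go.mod); the trimmed cfg struct below is illustrative and not the simulator's real Configuration type:

package main

import (
	"fmt"

	"gopkg.in/yaml.v3"
)

// cfg mirrors only the field added in this commit; the real Configuration
// struct in pkg/common/config.go has many more fields.
type cfg struct {
	EnableRequestIDHeaders bool `yaml:"enable-request-id-headers" json:"enable-request-id-headers"`
}

func main() {
	raw := []byte("enable-request-id-headers: true\n")

	var c cfg
	if err := yaml.Unmarshal(raw, &c); err != nil {
		panic(err)
	}
	fmt.Println(c.EnableRequestIDHeaders) // prints: true
}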

pkg/llm-d-inference-sim/server.go

Lines changed: 18 additions & 2 deletions
@@ -109,9 +109,22 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener)
 	}
 }
 
+// getRequestID retrieves the request ID from the X-Request-Id header or generates a new one if not present
+func (s *VllmSimulator) getRequestID(ctx *fasthttp.RequestCtx) string {
+	requestID := s.random.GenerateUUIDString()
+
+	if s.config.EnableRequestIDHeaders {
+		rid := string(ctx.Request.Header.Peek("X-Request-Id"))
+		if rid != "" {
+			requestID = rid
+		}
+	}
+	return requestID
+}
+
 // readRequest reads and parses data from the body of the given request according the type defined by isChatCompletion
 func (s *VllmSimulator) readRequest(ctx *fasthttp.RequestCtx, isChatCompletion bool) (openaiserverapi.CompletionRequest, error) {
-	requestID := s.random.GenerateUUIDString()
+	requestID := s.getRequestID(ctx)
 
 	if isChatCompletion {
 		var req openaiserverapi.ChatCompletionRequest
@@ -250,7 +263,7 @@ func (s *VllmSimulator) validateRequest(req openaiserverapi.CompletionRequest) (
 }
 
 // sendCompletionResponse sends a completion response
-func (s *VllmSimulator) sendCompletionResponse(ctx *fasthttp.RequestCtx, resp openaiserverapi.CompletionResponse) {
+func (s *VllmSimulator) sendCompletionResponse(ctx *fasthttp.RequestCtx, resp openaiserverapi.CompletionResponse, requestID string) {
 	data, err := json.Marshal(resp)
 	if err != nil {
 		ctx.Error("Response body creation failed, "+err.Error(), fasthttp.StatusInternalServerError)
@@ -266,6 +279,9 @@ func (s *VllmSimulator) sendCompletionResponse(ctx *fasthttp.RequestCtx, resp op
 	if s.namespace != "" {
 		ctx.Response.Header.Add(namespaceHeader, s.namespace)
 	}
+	if s.config.EnableRequestIDHeaders {
+		ctx.Response.Header.Add(requestIDHeader, requestID)
+	}
 	ctx.Response.SetBody(data)
 }
 
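
For reference, a minimal client-side sketch of the behaviour getRequestID and sendCompletionResponse implement above. The base URL, port, and model name are placeholders (they depend on how the simulator is started); the point is that with --enable-request-id-headers a supplied X-Request-Id is echoed back on the response, while omitting the header lets the simulator fall back to a generated UUID:

package main

import (
	"fmt"
	"net/http"
	"strings"
)

func main() {
	body := `{"model": "my-model", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 5}`

	req, err := http.NewRequest(http.MethodPost, "http://localhost:8000/v1/chat/completions", strings.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	// Propagate an existing trace/request ID; drop this line and the simulator
	// generates a UUID instead (see getRequestID above).
	req.Header.Set("X-Request-Id", "trace-abc-123")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// With --enable-request-id-headers the same ID comes back on the response.
	fmt.Println("X-Request-Id:", resp.Header.Get("X-Request-Id"))
}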

pkg/llm-d-inference-sim/server_test.go

Lines changed: 186 additions & 0 deletions
@@ -212,6 +212,192 @@ var _ = Describe("Server", func() {
 
 	})
 
+	Context("request ID headers", func() {
+		It("Should include X-Request-Id in response when enabled", func() {
+			ctx := context.TODO()
+			args := []string{"cmd", "--model", testModel, "--mode", common.ModeEcho,
+				"--enable-request-id-headers"}
+			client, err := startServerWithArgs(ctx, args)
+			Expect(err).NotTo(HaveOccurred())
+
+			reqBody := `{
+				"messages": [{"role": "user", "content": "Hello"}],
+				"model": "` + testModel + `",
+				"max_tokens": 5
+			}`
+
+			req, err := http.NewRequest("POST", "http://localhost/v1/chat/completions", strings.NewReader(reqBody))
+			Expect(err).NotTo(HaveOccurred())
+			req.Header.Set("Content-Type", "application/json")
+			req.Header.Set("X-Request-Id", "test-request-id-123")
+
+			resp, err := client.Do(req)
+			Expect(err).NotTo(HaveOccurred())
+			defer func() {
+				err := resp.Body.Close()
+				Expect(err).NotTo(HaveOccurred())
+			}()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+			Expect(resp.Header.Get("X-Request-Id")).To(Equal("test-request-id-123"))
+		})
+
+		It("Should not include X-Request-Id in response when disabled", func() {
+			ctx := context.TODO()
+			args := []string{"cmd", "--model", testModel, "--mode", common.ModeEcho}
+			client, err := startServerWithArgs(ctx, args)
+			Expect(err).NotTo(HaveOccurred())
+
+			reqBody := `{
+				"messages": [{"role": "user", "content": "Hello"}],
+				"model": "` + testModel + `",
+				"max_tokens": 5
+			}`
+
+			req, err := http.NewRequest("POST", "http://localhost/v1/chat/completions", strings.NewReader(reqBody))
+			Expect(err).NotTo(HaveOccurred())
+			req.Header.Set("Content-Type", "application/json")
+			req.Header.Set("X-Request-Id", "test-request-id-456")
+
+			resp, err := client.Do(req)
+			Expect(err).NotTo(HaveOccurred())
+			defer func() {
+				err := resp.Body.Close()
+				Expect(err).NotTo(HaveOccurred())
+			}()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+			Expect(resp.Header.Get("X-Request-Id")).To(BeEmpty())
+		})
+
+		It("Should include X-Request-Id in streaming response when enabled", func() {
+			ctx := context.TODO()
+			args := []string{"cmd", "--model", testModel, "--mode", common.ModeEcho,
+				"--enable-request-id-headers"}
+			client, err := startServerWithArgs(ctx, args)
+			Expect(err).NotTo(HaveOccurred())
+
+			reqBody := `{
+				"messages": [{"role": "user", "content": "Hello"}],
+				"model": "` + testModel + `",
+				"max_tokens": 5,
+				"stream": true
+			}`
+
+			req, err := http.NewRequest("POST", "http://localhost/v1/chat/completions", strings.NewReader(reqBody))
+			Expect(err).NotTo(HaveOccurred())
+			req.Header.Set("Content-Type", "application/json")
+			req.Header.Set("X-Request-Id", "test-streaming-request-789")
+
+			resp, err := client.Do(req)
+			Expect(err).NotTo(HaveOccurred())
+			defer func() {
+				err := resp.Body.Close()
+				Expect(err).NotTo(HaveOccurred())
+			}()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+			Expect(resp.Header.Get("X-Request-Id")).To(Equal("test-streaming-request-789"))
+		})
+
+		It("Should use request ID in response body ID field when enabled", func() {
+			ctx := context.TODO()
+			args := []string{"cmd", "--model", testModel, "--mode", common.ModeEcho,
+				"--enable-request-id-headers"}
+			client, err := startServerWithArgs(ctx, args)
+			Expect(err).NotTo(HaveOccurred())

+			reqBody := `{
+				"messages": [{"role": "user", "content": "Hello"}],
+				"model": "` + testModel + `",
+				"max_tokens": 5
+			}`
+
+			req, err := http.NewRequest("POST", "http://localhost/v1/chat/completions", strings.NewReader(reqBody))
+			Expect(err).NotTo(HaveOccurred())
+			req.Header.Set("Content-Type", "application/json")
+			req.Header.Set("X-Request-Id", "body-test-request-999")
+
+			resp, err := client.Do(req)
+			Expect(err).NotTo(HaveOccurred())
+			defer func() {
+				err := resp.Body.Close()
+				Expect(err).NotTo(HaveOccurred())
+			}()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+			body, err := io.ReadAll(resp.Body)
+			Expect(err).NotTo(HaveOccurred())
+
+			var completionResp map[string]interface{}
+			err = json.Unmarshal(body, &completionResp)
+			Expect(err).NotTo(HaveOccurred())
+
+			// The response ID should start with "chatcmpl-" followed by the request ID
+			responseID, ok := completionResp["id"].(string)
+			Expect(ok).To(BeTrue())
+			Expect(responseID).To(Equal("chatcmpl-body-test-request-999"))
+		})
+
+		It("Should work with text completions endpoint", func() {
+			ctx := context.TODO()
+			args := []string{"cmd", "--model", testModel, "--mode", common.ModeEcho,
+				"--enable-request-id-headers"}
+			client, err := startServerWithArgs(ctx, args)
+			Expect(err).NotTo(HaveOccurred())
+
+			reqBody := `{
+				"prompt": "Hello world",
+				"model": "` + testModel + `",
+				"max_tokens": 5
+			}`
+
+			req, err := http.NewRequest("POST", "http://localhost/v1/completions", strings.NewReader(reqBody))
+			Expect(err).NotTo(HaveOccurred())
+			req.Header.Set("Content-Type", "application/json")
+			req.Header.Set("X-Request-Id", "text-completion-request-111")
+
+			resp, err := client.Do(req)
+			Expect(err).NotTo(HaveOccurred())
+			defer func() {
+				err := resp.Body.Close()
+				Expect(err).NotTo(HaveOccurred())
+			}()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+			Expect(resp.Header.Get("X-Request-Id")).To(Equal("text-completion-request-111"))
+		})
+
+		It("Should generate UUID when no X-Request-Id header provided and feature enabled", func() {
+			ctx := context.TODO()
+			args := []string{"cmd", "--model", testModel, "--mode", common.ModeEcho,
+				"--enable-request-id-headers"}
+			client, err := startServerWithArgs(ctx, args)
+			Expect(err).NotTo(HaveOccurred())
+
+			reqBody := `{
+				"messages": [{"role": "user", "content": "Hello"}],
+				"model": "` + testModel + `",
+				"max_tokens": 5
+			}`
+
+			resp, err := client.Post("http://localhost/v1/chat/completions", "application/json", strings.NewReader(reqBody))
+			Expect(err).NotTo(HaveOccurred())
+			defer func() {
+				err := resp.Body.Close()
+				Expect(err).NotTo(HaveOccurred())
+			}()
+
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+			// Should have a generated UUID in the response header
+			requestID := resp.Header.Get("X-Request-Id")
+			Expect(requestID).NotTo(BeEmpty())
+			// UUID format check (basic validation)
+			Expect(len(requestID)).To(BeNumerically(">", 30))
+		})
+	})
+
 	Context("sleep mode", Ordered, func() {
 		AfterAll(func() {
 			err := os.RemoveAll(tmpDir)

pkg/llm-d-inference-sim/simulator.go

Lines changed: 6 additions & 4 deletions
@@ -51,6 +51,7 @@ const (
 	podHeader = "x-inference-pod"
 	portHeader = "x-inference-port"
 	namespaceHeader = "x-inference-namespace"
+	requestIDHeader = "X-Request-Id"
 	podNameEnv = "POD_NAME"
 	podNsEnv = "POD_NAMESPACE"
 )
@@ -573,8 +574,8 @@ func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool
 // modelName - display name returned to the client and used in metrics. It is either the first alias
 // from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request).
 func (s *VllmSimulator) createCompletionResponse(logprobs *int, isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall,
-	finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool) openaiserverapi.CompletionResponse {
-	baseResp := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
+	finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool, requestID string) openaiserverapi.CompletionResponse {
+	baseResp := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+requestID,
 		time.Now().Unix(), modelName, usageData)
 
 	if doRemoteDecode {
@@ -655,9 +656,10 @@ func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, r
 	if toolCalls == nil {
 		logprobs = reqCtx.CompletionReq.GetLogprobs()
 	}
+	requestID := reqCtx.CompletionReq.GetRequestID()
 
 	resp := s.createCompletionResponse(logprobs, reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName,
-		reqCtx.CompletionReq.IsDoRemoteDecode())
+		reqCtx.CompletionReq.IsDoRemoteDecode(), requestID)
 
 	// calculate how long to wait before returning the response, time is based on number of tokens
 	nCachedPromptTokens := reqCtx.CompletionReq.GetNumberOfCachedPromptTokens()
@@ -679,7 +681,7 @@
 	}
 	common.WriteToChannel(s.metrics.reqDecodeTimeChan, time.Since(startDecode).Seconds(), s.logger, "metrics.reqDecodeTimeChan")
 
-	s.sendCompletionResponse(reqCtx.HTTPReqCtx, resp)
+	s.sendCompletionResponse(reqCtx.HTTPReqCtx, resp, requestID)
 	s.responseSentCallback(modelName, reqCtx.IsChatCompletion, reqCtx.CompletionReq.GetRequestID())
 }
 
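
One consequence of threading the request ID through createCompletionResponse: the id field of the response body is now derived from the same value as the X-Request-Id header instead of a fresh UUID per response. A tiny illustrative sketch; the "chatcmpl-" prefix value matches what the tests above assert:

package main

import "fmt"

func main() {
	const chatComplIDPrefix = "chatcmpl-" // prefix asserted by the tests above

	// requestID is whatever getRequestID resolved: the incoming X-Request-Id
	// header when the feature is enabled, otherwise a generated UUID.
	requestID := "body-test-request-999"

	responseID := chatComplIDPrefix + requestID
	fmt.Println(responseID) // chatcmpl-body-test-request-999
}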

pkg/llm-d-inference-sim/streaming.go

Lines changed: 6 additions & 3 deletions
@@ -60,6 +60,9 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons
 	if s.namespace != "" {
 		context.ctx.Response.Header.Add(namespaceHeader, s.namespace)
 	}
+	if s.config.EnableRequestIDHeaders {
+		context.ctx.Response.Header.Add(requestIDHeader, context.requestID)
+	}
 
 	context.ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
 		context.creationTime = time.Now().Unix()
@@ -176,7 +179,7 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ
 // createUsageChunk creates and returns a CompletionRespChunk with usage data, a single chunk of streamed completion API response,
 // supports both modes (text and chat)
 func (s *VllmSimulator) createUsageChunk(context *streamingContext, usageData *openaiserverapi.Usage) openaiserverapi.CompletionRespChunk {
-	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
+	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+context.requestID,
 		context.creationTime, context.model, usageData)
 
 	if context.isChatCompletion {
@@ -191,7 +194,7 @@ func (s *VllmSimulator) createUsageChunk(context *streamingContext, usageData *o
 // createTextCompletionChunk creates and returns a CompletionRespChunk, a single chunk of streamed completion API response,
 // for text completion.
 func (s *VllmSimulator) createTextCompletionChunk(context *streamingContext, token string, finishReason *string) openaiserverapi.CompletionRespChunk {
-	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
+	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+context.requestID,
 		context.creationTime, context.model, nil)
 	baseChunk.Object = textCompletionObject
 
@@ -214,7 +217,7 @@
 // API response, for chat completion. It sets either role, or token, or tool call info in the message.
 func (s *VllmSimulator) createChatCompletionChunk(context *streamingContext, token string, tool *openaiserverapi.ToolCall,
 	role string, finishReason *string) openaiserverapi.CompletionRespChunk {
-	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
+	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+context.requestID,
 		context.creationTime, context.model, nil)
 	baseChunk.Object = chatCompletionChunkObject
 	chunk := openaiserverapi.CreateChatCompletionRespChunk(baseChunk,
