@@ -25,11 +25,13 @@ import (
2525 "strings"
2626 "time"
2727
28+ . "github.com/onsi/ginkgo/v2"
29+ . "github.com/onsi/gomega"
30+ "github.com/valyala/fasthttp"
31+
2832 "github.com/llm-d/llm-d-inference-sim/pkg/common"
2933 kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
3034 vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api"
31- . "github.com/onsi/ginkgo/v2"
32- . "github.com/onsi/gomega"
3335)
3436
3537const tmpDir = "./tests-tmp/"
@@ -212,6 +214,108 @@ var _ = Describe("Server", func() {
212214
213215 })
214216
217+ Context ("request ID headers" , func () {
218+ testRequestIDHeader := func (enableRequestID bool , endpoint , reqBody , inputRequestID string , expectRequestID * string , validateBody func ([]byte )) {
219+ ctx := context .TODO ()
220+ args := []string {"cmd" , "--model" , testModel , "--mode" , common .ModeEcho }
221+ if enableRequestID {
222+ args = append (args , "--enable-request-id-headers" )
223+ }
224+ client , err := startServerWithArgs (ctx , args )
225+ Expect (err ).NotTo (HaveOccurred ())
226+
227+ req , err := http .NewRequest ("POST" , "http://localhost" + endpoint , strings .NewReader (reqBody ))
228+ Expect (err ).NotTo (HaveOccurred ())
229+ req .Header .Set (fasthttp .HeaderContentType , "application/json" )
230+ if inputRequestID != "" {
231+ req .Header .Set (requestIDHeader , inputRequestID )
232+ }
233+
234+ resp , err := client .Do (req )
235+ Expect (err ).NotTo (HaveOccurred ())
236+ defer resp .Body .Close ()
237+
238+ Expect (resp .StatusCode ).To (Equal (http .StatusOK ))
239+
240+ if expectRequestID != nil {
241+ actualRequestID := resp .Header .Get (requestIDHeader )
242+ if * expectRequestID != "" {
243+ // When a request ID is provided, it should be echoed back
244+ Expect (actualRequestID ).To (Equal (* expectRequestID ))
245+ } else {
246+ // When no request ID is provided, a UUID should be generated
247+ Expect (actualRequestID ).NotTo (BeEmpty ())
248+ Expect (len (actualRequestID )).To (BeNumerically (">" , 30 ))
249+ }
250+ } else {
251+ // When request ID headers are disabled, the header should be empty
252+ Expect (resp .Header .Get (requestIDHeader )).To (BeEmpty ())
253+ }
254+
255+ if validateBody != nil {
256+ body , err := io .ReadAll (resp .Body )
257+ Expect (err ).NotTo (HaveOccurred ())
258+ validateBody (body )
259+ }
260+ }
261+
262+ DescribeTable ("request ID behavior" ,
263+ testRequestIDHeader ,
264+ Entry ("includes X-Request-Id when enabled" ,
265+ true ,
266+ "/v1/chat/completions" ,
267+ `{"messages": [{"role": "user", "content": "Hello"}], "model": "` + testModel + `", "max_tokens": 5}` ,
268+ "test-request-id-123" ,
269+ ptr ("test-request-id-123" ),
270+ nil ,
271+ ),
272+ Entry ("excludes X-Request-Id when disabled" ,
273+ false ,
274+ "/v1/chat/completions" ,
275+ `{"messages": [{"role": "user", "content": "Hello"}], "model": "` + testModel + `", "max_tokens": 5}` ,
276+ "test-request-id-456" ,
277+ nil ,
278+ nil ,
279+ ),
280+ Entry ("includes X-Request-Id in streaming response" ,
281+ true ,
282+ "/v1/chat/completions" ,
283+ `{"messages": [{"role": "user", "content": "Hello"}], "model": "` + testModel + `", "max_tokens": 5, "stream": true}` ,
284+ "test-streaming-789" ,
285+ ptr ("test-streaming-789" ),
286+ nil ,
287+ ),
288+ Entry ("works with text completions endpoint" ,
289+ true ,
290+ "/v1/completions" ,
291+ `{"prompt": "Hello world", "model": "` + testModel + `", "max_tokens": 5}` ,
292+ "text-request-111" ,
293+ ptr ("text-request-111" ),
294+ nil ,
295+ ),
296+ Entry ("generates UUID when no request ID provided" ,
297+ true ,
298+ "/v1/chat/completions" ,
299+ `{"messages": [{"role": "user", "content": "Hello"}], "model": "` + testModel + `", "max_tokens": 5}` ,
300+ "" ,
301+ ptr ("" ),
302+ nil ,
303+ ),
304+ Entry ("uses request ID in response body ID field" ,
305+ true ,
306+ "/v1/chat/completions" ,
307+ `{"messages": [{"role": "user", "content": "Hello"}], "model": "` + testModel + `", "max_tokens": 5}` ,
308+ "body-test-999" ,
309+ ptr ("body-test-999" ),
310+ func (body []byte ) {
311+ var resp map [string ]any
312+ Expect (json .Unmarshal (body , & resp )).To (Succeed ())
313+ Expect (resp ["id" ]).To (Equal ("chatcmpl-body-test-999" ))
314+ },
315+ ),
316+ )
317+ })
318+
215319 Context ("sleep mode" , Ordered , func () {
216320 AfterAll (func () {
217321 err := os .RemoveAll (tmpDir )
0 commit comments