@@ -25,11 +25,13 @@ import (
2525 "strings"
2626 "time"
2727
28+ . "github.com/onsi/ginkgo/v2"
29+ . "github.com/onsi/gomega"
30+ "github.com/valyala/fasthttp"
31+
2832 "github.com/llm-d/llm-d-inference-sim/pkg/common"
2933 kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
3034 vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api"
31- . "github.com/onsi/ginkgo/v2"
32- . "github.com/onsi/gomega"
3335)
3436
3537const tmpDir = "./tests-tmp/"
@@ -212,6 +214,111 @@ var _ = Describe("Server", func() {
212214
213215 })
214216
217+ Context ("request ID headers" , func () {
218+ testRequestIDHeader := func (enableRequestID bool , endpoint , reqBody , inputRequestID string , expectRequestID * string , validateBody func ([]byte )) {
219+ ctx := context .TODO ()
220+ args := []string {"cmd" , "--model" , testModel , "--mode" , common .ModeEcho }
221+ if enableRequestID {
222+ args = append (args , "--enable-request-id-headers" )
223+ }
224+ client , err := startServerWithArgs (ctx , args )
225+ Expect (err ).NotTo (HaveOccurred ())
226+
227+ req , err := http .NewRequest ("POST" , "http://localhost" + endpoint , strings .NewReader (reqBody ))
228+ Expect (err ).NotTo (HaveOccurred ())
229+ req .Header .Set (fasthttp .HeaderContentType , "application/json" )
230+ if inputRequestID != "" {
231+ req .Header .Set (requestIDHeader , inputRequestID )
232+ }
233+
234+ resp , err := client .Do (req )
235+ Expect (err ).NotTo (HaveOccurred ())
236+ defer func () {
237+ err := resp .Body .Close ()
238+ Expect (err ).NotTo (HaveOccurred ())
239+ }()
240+
241+ Expect (resp .StatusCode ).To (Equal (http .StatusOK ))
242+
243+ if expectRequestID != nil {
244+ actualRequestID := resp .Header .Get (requestIDHeader )
245+ if * expectRequestID != "" {
246+ // When a request ID is provided, it should be echoed back
247+ Expect (actualRequestID ).To (Equal (* expectRequestID ))
248+ } else {
249+ // When no request ID is provided, a UUID should be generated
250+ Expect (actualRequestID ).NotTo (BeEmpty ())
251+ Expect (len (actualRequestID )).To (BeNumerically (">" , 30 ))
252+ }
253+ } else {
254+ // When request ID headers are disabled, the header should be empty
255+ Expect (resp .Header .Get (requestIDHeader )).To (BeEmpty ())
256+ }
257+
258+ if validateBody != nil {
259+ body , err := io .ReadAll (resp .Body )
260+ Expect (err ).NotTo (HaveOccurred ())
261+ validateBody (body )
262+ }
263+ }
264+
265+ DescribeTable ("request ID behavior" ,
266+ testRequestIDHeader ,
267+ Entry ("includes X-Request-Id when enabled" ,
268+ true ,
269+ "/v1/chat/completions" ,
270+ `{"messages": [{"role": "user", "content": "Hello"}], "model": "` + testModel + `", "max_tokens": 5}` ,
271+ "test-request-id-123" ,
272+ ptr ("test-request-id-123" ),
273+ nil ,
274+ ),
275+ Entry ("excludes X-Request-Id when disabled" ,
276+ false ,
277+ "/v1/chat/completions" ,
278+ `{"messages": [{"role": "user", "content": "Hello"}], "model": "` + testModel + `", "max_tokens": 5}` ,
279+ "test-request-id-456" ,
280+ nil ,
281+ nil ,
282+ ),
283+ Entry ("includes X-Request-Id in streaming response" ,
284+ true ,
285+ "/v1/chat/completions" ,
286+ `{"messages": [{"role": "user", "content": "Hello"}], "model": "` + testModel + `", "max_tokens": 5, "stream": true}` ,
287+ "test-streaming-789" ,
288+ ptr ("test-streaming-789" ),
289+ nil ,
290+ ),
291+ Entry ("works with text completions endpoint" ,
292+ true ,
293+ "/v1/completions" ,
294+ `{"prompt": "Hello world", "model": "` + testModel + `", "max_tokens": 5}` ,
295+ "text-request-111" ,
296+ ptr ("text-request-111" ),
297+ nil ,
298+ ),
299+ Entry ("generates UUID when no request ID provided" ,
300+ true ,
301+ "/v1/chat/completions" ,
302+ `{"messages": [{"role": "user", "content": "Hello"}], "model": "` + testModel + `", "max_tokens": 5}` ,
303+ "" ,
304+ ptr ("" ),
305+ nil ,
306+ ),
307+ Entry ("uses request ID in response body ID field" ,
308+ true ,
309+ "/v1/chat/completions" ,
310+ `{"messages": [{"role": "user", "content": "Hello"}], "model": "` + testModel + `", "max_tokens": 5}` ,
311+ "body-test-999" ,
312+ ptr ("body-test-999" ),
313+ func (body []byte ) {
314+ var resp map [string ]any
315+ Expect (json .Unmarshal (body , & resp )).To (Succeed ())
316+ Expect (resp ["id" ]).To (Equal ("chatcmpl-body-test-999" ))
317+ },
318+ ),
319+ )
320+ })
321+
215322 Context ("sleep mode" , Ordered , func () {
216323 AfterAll (func () {
217324 err := os .RemoveAll (tmpDir )
0 commit comments