Skip to content

Commit 7c2324a

Browse files
committed
Fix panic in SGLang proxy handling of concurrent requests
1 parent 5c89823 commit 7c2324a

File tree

2 files changed

+127
-4
lines changed

2 files changed

+127
-4
lines changed

pkg/sidecar/proxy/connector_sglang.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package proxy
1818

1919
import (
2020
"bytes"
21+
"context"
2122
"encoding/json"
2223
"fmt"
2324
"io"
@@ -77,8 +78,10 @@ func (s *Server) runSGLangProtocol(w http.ResponseWriter, r *http.Request, prefi
7778

7879
func (s *Server) sendSGLangConcurrentRequests(w http.ResponseWriter, r *http.Request, body []byte, prefillHost string) {
7980
// Create separate requests for prefill and decode
80-
prefillReq := cloneWithJSONBody(r, body)
81-
decodeReq := cloneWithJSONBody(r, body)
81+
// Use context.WithoutCancel for prefillReq to prevent it from being aborted
82+
// if the main HTTP handler (which serves decodeReq) finishes first.
83+
prefillReq := cloneWithJSONBody(context.WithoutCancel(r.Context()), r, body)
84+
decodeReq := cloneWithJSONBody(r.Context(), r, body)
8285

8386
prefillHandler, err := s.prefillerProxyHandler(prefillHost)
8487
if err != nil {
@@ -90,6 +93,11 @@ func (s *Server) sendSGLangConcurrentRequests(w http.ResponseWriter, r *http.Req
9093

9194
// Send prefill request asynchronously
9295
go func() {
96+
defer func() {
97+
if rec := recover(); rec != nil && rec != http.ErrAbortHandler {
98+
s.logger.Error(fmt.Errorf("panic: %v", rec), "panic in prefill request")
99+
}
100+
}()
93101
pw := &bufferedResponseWriter{}
94102
prefillHandler.ServeHTTP(pw, prefillReq)
95103
s.logger.V(5).Info("prefill request completed", "status", pw.statusCode)
@@ -99,8 +107,8 @@ func (s *Server) sendSGLangConcurrentRequests(w http.ResponseWriter, r *http.Req
99107
s.decoderProxy.ServeHTTP(w, decodeReq)
100108
}
101109

102-
func cloneWithJSONBody(r *http.Request, body []byte) *http.Request {
103-
req := r.Clone(r.Context())
110+
func cloneWithJSONBody(ctx context.Context, r *http.Request, body []byte) *http.Request {
111+
req := r.Clone(ctx)
104112
req.Body = io.NopCloser(bytes.NewReader(body))
105113
req.ContentLength = int64(len(body))
106114
return req
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
Copyright 2025 The llm-d Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package proxy
18+
19+
import (
20+
"io"
21+
"net/http"
22+
"strings"
23+
"time"
24+
25+
"github.com/llm-d/llm-d-inference-scheduler/pkg/common"
26+
. "github.com/onsi/ginkgo/v2" // nolint:revive
27+
. "github.com/onsi/gomega" // nolint:revive
28+
)
29+
30+
var _ = Describe("SGLang Connector", func() {
31+
32+
var testInfo *sidecarTestInfo
33+
34+
BeforeEach(func() {
35+
// Mock testing setup using the SGLang connector mode
36+
testInfo = sidecarConnectionTestSetup(ConnectorSGLang)
37+
})
38+
39+
It("should successfully send concurrent requests to prefill and decode with bootstrap info", func() {
40+
By("starting the proxy")
41+
go func() {
42+
defer GinkgoRecover()
43+
44+
validator := &AllowlistValidator{enabled: false}
45+
err := testInfo.proxy.Start(testInfo.ctx, nil, validator)
46+
Expect(err).ToNot(HaveOccurred())
47+
48+
testInfo.stoppedCh <- struct{}{}
49+
}()
50+
51+
// Wait for proxy to start
52+
time.Sleep(1 * time.Second)
53+
Expect(testInfo.proxy.addr).ToNot(BeNil())
54+
proxyBaseAddr := "http://" + testInfo.proxy.addr.String()
55+
56+
By("sending a /v1/chat/completions request with prefill header")
57+
body := `{
58+
"model": "Qwen/Qwen2-0.5B",
59+
"messages": [
60+
{"role": "user", "content": "Hello"}
61+
],
62+
"max_tokens": 50
63+
}`
64+
65+
req, err := http.NewRequest(http.MethodPost, proxyBaseAddr+ChatCompletionsPath, strings.NewReader(body))
66+
Expect(err).ToNot(HaveOccurred())
67+
68+
prefillHostPort := testInfo.prefillBackend.URL[len("http://"):]
69+
req.Header.Add(common.PrefillPodHeader, prefillHostPort)
70+
71+
rp, err := http.DefaultClient.Do(req)
72+
Expect(err).ToNot(HaveOccurred())
73+
74+
if rp.StatusCode != 200 {
75+
bp, _ := io.ReadAll(rp.Body) //nolint:all
76+
Fail(string(bp))
77+
}
78+
79+
// Because the SGLang connector sends requests concurrently (prefill runs in a goroutine),
80+
// we sleep briefly to give the prefill handler time to finish processing.
81+
time.Sleep(100 * time.Millisecond)
82+
83+
// Validate prefill request
84+
Expect(testInfo.prefillHandler.RequestCount.Load()).To(BeNumerically("==", 1))
85+
Expect(testInfo.prefillHandler.CompletionRequests).To(HaveLen(1))
86+
prq1 := testInfo.prefillHandler.CompletionRequests[0]
87+
88+
// Validate decode request
89+
Expect(testInfo.decodeHandler.RequestCount.Load()).To(BeNumerically("==", 1))
90+
Expect(testInfo.decodeHandler.CompletionRequests).To(HaveLen(1))
91+
drq1 := testInfo.decodeHandler.CompletionRequests[0]
92+
93+
// Bootstrap validations for prefill
94+
Expect(prq1).To(HaveKey(requestFieldBootstrapHost))
95+
Expect(prq1).To(HaveKey(requestFieldBootstrapPort))
96+
Expect(prq1).To(HaveKey(requestFieldBootstrapRoom))
97+
98+
expectedHost := strings.Split(prefillHostPort, ":")[0]
99+
Expect(prq1[requestFieldBootstrapHost]).To(Equal(expectedHost))
100+
Expect(prq1[requestFieldBootstrapPort]).To(Equal(float64(sglangBootstrapPort)))
101+
Expect(prq1[requestFieldBootstrapRoom]).ToNot(BeNil())
102+
103+
// Bootstrap validations for decode
104+
Expect(drq1).To(HaveKey(requestFieldBootstrapHost))
105+
Expect(drq1).To(HaveKey(requestFieldBootstrapPort))
106+
Expect(drq1).To(HaveKey(requestFieldBootstrapRoom))
107+
108+
Expect(drq1[requestFieldBootstrapHost]).To(Equal(expectedHost))
109+
Expect(drq1[requestFieldBootstrapPort]).To(Equal(float64(sglangBootstrapPort)))
110+
Expect(drq1[requestFieldBootstrapRoom]).To(Equal(prq1[requestFieldBootstrapRoom])) // Room ID must match
111+
112+
testInfo.cancelFn()
113+
<-testInfo.stoppedCh
114+
})
115+
})

0 commit comments

Comments
 (0)