Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions pkg/sidecar/proxy/connector_sglang.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package proxy

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -77,8 +78,10 @@ func (s *Server) runSGLangProtocol(w http.ResponseWriter, r *http.Request, prefi

func (s *Server) sendSGLangConcurrentRequests(w http.ResponseWriter, r *http.Request, body []byte, prefillHost string) {
// Create separate requests for prefill and decode
prefillReq := cloneWithJSONBody(r, body)
decodeReq := cloneWithJSONBody(r, body)
// Use context.WithoutCancel for prefillReq to prevent it from being aborted
// if the main HTTP handler (which serves decodeReq) finishes first.
prefillReq := cloneWithJSONBody(context.WithoutCancel(r.Context()), r, body)
decodeReq := cloneWithJSONBody(r.Context(), r, body)

prefillHandler, err := s.prefillerProxyHandler(prefillHost)
if err != nil {
Expand All @@ -90,6 +93,11 @@ func (s *Server) sendSGLangConcurrentRequests(w http.ResponseWriter, r *http.Req

// Send prefill request asynchronously
go func() {
defer func() {
if rec := recover(); rec != nil && rec != http.ErrAbortHandler {
s.logger.Error(fmt.Errorf("panic: %v", rec), "panic in prefill request")
}
}()
pw := &bufferedResponseWriter{}
prefillHandler.ServeHTTP(pw, prefillReq)
s.logger.V(5).Info("prefill request completed", "status", pw.statusCode)
Expand All @@ -99,8 +107,8 @@ func (s *Server) sendSGLangConcurrentRequests(w http.ResponseWriter, r *http.Req
s.decoderProxy.ServeHTTP(w, decodeReq)
}

func cloneWithJSONBody(r *http.Request, body []byte) *http.Request {
req := r.Clone(r.Context())
func cloneWithJSONBody(ctx context.Context, r *http.Request, body []byte) *http.Request {
req := r.Clone(ctx)
req.Body = io.NopCloser(bytes.NewReader(body))
req.ContentLength = int64(len(body))
return req
Expand Down
177 changes: 177 additions & 0 deletions pkg/sidecar/proxy/connector_sglang_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
/*
Copyright 2025 The llm-d Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package proxy

import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"time"

"github.com/llm-d/llm-d-inference-scheduler/pkg/common"
. "github.com/onsi/ginkgo/v2" // nolint:revive
. "github.com/onsi/gomega" // nolint:revive
)

// SGLang connector specs. Each spec starts the sidecar proxy built by
// sidecarConnectionTestSetup, sends one chat-completions request carrying the
// prefill pod header, and inspects what the mock prefill and decode backends
// recorded.
var _ = Describe("SGLang Connector", func() {

	var testInfo *sidecarTestInfo

	BeforeEach(func() {
		// Mock testing setup using the SGLang connector mode
		testInfo = sidecarConnectionTestSetup(ConnectorSGLang)
	})

	It("should successfully send concurrent requests to prefill and decode with bootstrap info", func() {
		By("starting the proxy")
		go func() {
			defer GinkgoRecover()

			validator := &AllowlistValidator{enabled: false}
			err := testInfo.proxy.Start(testInfo.ctx, nil, validator)
			Expect(err).ToNot(HaveOccurred())

			testInfo.stoppedCh <- struct{}{}
		}()

		// Wait for proxy to start
		time.Sleep(1 * time.Second)
		Expect(testInfo.proxy.addr).ToNot(BeNil())
		proxyBaseAddr := "http://" + testInfo.proxy.addr.String()

		By("sending a /v1/chat/completions request with prefill header")
		body := `{
"model": "Qwen/Qwen2-0.5B",
"messages": [
{"role": "user", "content": "Hello"}
],
"max_tokens": 50
}`

		req, err := http.NewRequest(http.MethodPost, proxyBaseAddr+ChatCompletionsPath, strings.NewReader(body))
		Expect(err).ToNot(HaveOccurred())

		// The header value is host:port, so strip the scheme from the test server URL.
		prefillHostPort := strings.TrimPrefix(testInfo.prefillBackend.URL, "http://")
		req.Header.Add(common.PrefillPodHeader, prefillHostPort)

		rp, err := http.DefaultClient.Do(req)
		Expect(err).ToNot(HaveOccurred())
		// Always release the response body, including when an assertion below fails.
		defer func() { _ = rp.Body.Close() }()

		if rp.StatusCode != http.StatusOK {
			bp, _ := io.ReadAll(rp.Body) //nolint:all
			Fail(string(bp))
		}

		// Because SGLang connector sends requests concurrently (prefill in goroutine),
		// we sleep a tiny bit to ensure the prefill handler has time to finish processing.
		time.Sleep(100 * time.Millisecond)

		// Validate prefill request
		Expect(testInfo.prefillHandler.RequestCount.Load()).To(BeNumerically("==", 1))
		Expect(testInfo.prefillHandler.CompletionRequests).To(HaveLen(1))
		prq1 := testInfo.prefillHandler.CompletionRequests[0]

		// Validate decode request
		Expect(testInfo.decodeHandler.RequestCount.Load()).To(BeNumerically("==", 1))
		Expect(testInfo.decodeHandler.CompletionRequests).To(HaveLen(1))
		drq1 := testInfo.decodeHandler.CompletionRequests[0]

		// Bootstrap validations for prefill
		Expect(prq1).To(HaveKey(requestFieldBootstrapHost))
		Expect(prq1).To(HaveKey(requestFieldBootstrapPort))
		Expect(prq1).To(HaveKey(requestFieldBootstrapRoom))

		// Host part of host:port (everything before the first colon).
		expectedHost, _, _ := strings.Cut(prefillHostPort, ":")
		Expect(prq1[requestFieldBootstrapHost]).To(Equal(expectedHost))
		Expect(prq1[requestFieldBootstrapPort]).To(Equal(float64(sglangBootstrapPort)))
		Expect(prq1[requestFieldBootstrapRoom]).ToNot(BeNil())

		// Bootstrap validations for decode
		Expect(drq1).To(HaveKey(requestFieldBootstrapHost))
		Expect(drq1).To(HaveKey(requestFieldBootstrapPort))
		Expect(drq1).To(HaveKey(requestFieldBootstrapRoom))

		Expect(drq1[requestFieldBootstrapHost]).To(Equal(expectedHost))
		Expect(drq1[requestFieldBootstrapPort]).To(Equal(float64(sglangBootstrapPort)))
		Expect(drq1[requestFieldBootstrapRoom]).To(Equal(prq1[requestFieldBootstrapRoom])) // Room ID must match

		testInfo.cancelFn()
		<-testInfo.stoppedCh
	})

	It("should not panic when prefill response is slower than decode response", func() {
		// Stop previously injected servers
		testInfo.decodeBackend.Close()
		testInfo.prefillBackend.Close()

		// Signals that the slow prefill handler ran to completion. A channel is
		// used instead of a shared bool so the handler goroutine and the spec
		// goroutine are synchronized. It is buffered and sent to without
		// blocking, so an unexpected extra call cannot hang the handler.
		prefillDone := make(chan struct{}, 1)

		slowPrefill := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			testInfo.prefillHandler.ServeHTTP(w, r)
			time.Sleep(300 * time.Millisecond) // Simulated load delay on KV Cache
			select {
			case prefillDone <- struct{}{}:
			default:
			}
		})
		testInfo.prefillBackend = httptest.NewServer(slowPrefill)
		// httptest.Server.Close is idempotent, so this is safe even if other
		// cleanup also closes the backend.
		DeferCleanup(testInfo.prefillBackend.Close)

		fastDecode := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			testInfo.decodeHandler.ServeHTTP(w, r)
		})
		testInfo.decodeBackend = httptest.NewServer(fastDecode)
		DeferCleanup(testInfo.decodeBackend.Close)

		decodeURL, err := url.Parse(testInfo.decodeBackend.URL)
		Expect(err).ToNot(HaveOccurred())
		testInfo.decodeURL = decodeURL

		// Re-initialize proxy to fetch the new mock addresses
		cfg := Config{
			Connector: ConnectorSGLang,
		}
		testInfo.proxy = NewProxy("0", testInfo.decodeURL, cfg)

		go func() {
			defer GinkgoRecover()
			validator := &AllowlistValidator{enabled: false}
			err := testInfo.proxy.Start(testInfo.ctx, nil, validator)
			Expect(err).ToNot(HaveOccurred())
			testInfo.stoppedCh <- struct{}{}
		}()

		// Wait for proxy to start
		time.Sleep(1 * time.Second)
		Expect(testInfo.proxy.addr).ToNot(BeNil())
		proxyBaseAddr := "http://" + testInfo.proxy.addr.String()

		body := `{"model": "Qwen", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}`
		req, err := http.NewRequest(http.MethodPost, proxyBaseAddr+ChatCompletionsPath, strings.NewReader(body))
		Expect(err).ToNot(HaveOccurred())

		prefillHostPort := strings.TrimPrefix(testInfo.prefillBackend.URL, "http://")
		req.Header.Add(common.PrefillPodHeader, prefillHostPort)

		// Submit request. This will complete as soon as fastDecode completes.
		rp, err := http.DefaultClient.Do(req)
		Expect(err).ToNot(HaveOccurred())
		defer func() { _ = rp.Body.Close() }()
		Expect(rp.StatusCode).To(Equal(http.StatusOK))

		// The prefill handler needs roughly 300ms after the decode response.
		// Poll for its completion signal instead of sleeping a fixed amount.
		Eventually(prefillDone, 2*time.Second).Should(Receive())

		Expect(testInfo.prefillHandler.RequestCount.Load()).To(BeNumerically("==", 1))
		Expect(testInfo.decodeHandler.RequestCount.Load()).To(BeNumerically("==", 1))

		testInfo.cancelFn()
		<-testInfo.stoppedCh
	})
})