Skip to content

Commit 86b0ddd

Browse files
authored
[bugfix] Honor stream=true in the request body when calling /v1/completions (vllm-project#1850)
The /v1/completions API now honors stream=true from the request body. Signed-off-by: DengHom <dengh20220925@163.com>
1 parent eef3dc0 commit 86b0ddd

File tree

2 files changed

+52
-12
lines changed

2 files changed

+52
-12
lines changed

pkg/plugins/gateway/gateway_req_body_test.go

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,13 @@ import (
2828
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
2929
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
3030
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
31+
v1 "k8s.io/api/core/v1"
32+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
33+
3134
"github.com/vllm-project/aibrix/pkg/cache"
3235
routingalgorithms "github.com/vllm-project/aibrix/pkg/plugins/gateway/algorithms"
3336
"github.com/vllm-project/aibrix/pkg/types"
3437
"github.com/vllm-project/aibrix/pkg/utils"
35-
v1 "k8s.io/api/core/v1"
36-
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
3738
)
3839

3940
// TestRouterAlgorithm is a dedicated routing algorithm for testing
@@ -59,12 +60,12 @@ func Test_handleRequestBody(t *testing.T) {
5960
type testCase struct {
6061
name string
6162
requestBody string
63+
reqPath string
6264
user utils.User
6365
routingAlgo types.RoutingAlgorithm
6466
mockSetup func(*MockCache, *mockRouter)
6567
expected testResponse
6668
validate func(*testing.T, *testCase, *extProcPb.ProcessingResponse, string, *types.RoutingContext, bool, int64)
67-
checkStream bool
6869
}
6970

7071
// Define test cases for different routing and error scenarios
@@ -121,7 +122,6 @@ func Test_handleRequestBody(t *testing.T) {
121122
assert.NotEqual(t, HeaderTargetPod, header.Header.Key)
122123
}
123124
},
124-
checkStream: false,
125125
},
126126
{
127127
name: "model not in cache - should return error",
@@ -161,7 +161,6 @@ func Test_handleRequestBody(t *testing.T) {
161161
assert.Equal(t, tt.expected.stream, stream)
162162
assert.Equal(t, tt.expected.term, term)
163163
},
164-
checkStream: false,
165164
},
166165
{
167166
name: "valid routing strategy - should set both routing and target pod headers",
@@ -267,7 +266,6 @@ func Test_handleRequestBody(t *testing.T) {
267266
assert.True(t, foundRoutingStrategy, "HeaderRoutingStrategy not found")
268267
assert.True(t, foundTargetPod, "HeaderTargetPod not found")
269268
},
270-
checkStream: false,
271269
},
272270
{
273271
name: "invalid routing strategy - should fallback to random router",
@@ -344,7 +342,6 @@ func Test_handleRequestBody(t *testing.T) {
344342
assert.True(t, foundRoutingStrategy, "HeaderRoutingStrategy not found")
345343
assert.True(t, foundTargetPod, "HeaderTargetPod not found")
346344
},
347-
checkStream: false,
348345
},
349346
{
350347
name: "no routable pods available - should return ServiceUnavailable",
@@ -413,7 +410,6 @@ func Test_handleRequestBody(t *testing.T) {
413410
assert.Equal(t, tt.expected.stream, stream)
414411
assert.Equal(t, tt.expected.term, term)
415412
},
416-
checkStream: false,
417413
},
418414
{
419415
name: "empty pods list - should return ServiceUnavailable",
@@ -459,7 +455,6 @@ func Test_handleRequestBody(t *testing.T) {
459455
assert.Equal(t, tt.expected.stream, stream)
460456
assert.Equal(t, tt.expected.term, term)
461457
},
462-
checkStream: false,
463458
},
464459
{
465460
name: "single pod in termination - should return ServiceUnavailable",
@@ -520,7 +515,6 @@ func Test_handleRequestBody(t *testing.T) {
520515
assert.Equal(t, tt.expected.stream, stream)
521516
assert.Equal(t, tt.expected.term, term)
522517
},
523-
checkStream: false,
524518
},
525519
{
526520
name: "routable pod without IP - should return ServiceUnavailable",
@@ -578,7 +572,47 @@ func Test_handleRequestBody(t *testing.T) {
578572
assert.Equal(t, tt.expected.stream, stream)
579573
assert.Equal(t, tt.expected.term, term)
580574
},
581-
checkStream: false,
575+
},
576+
{
577+
name: "request /v1/completions with stream header - should get the true value of stream",
578+
reqPath: "/v1/completions",
579+
requestBody: `{"model": "test-model", "prompt": "test", "stream": true}`,
580+
user: utils.User{
581+
Name: "test-user",
582+
},
583+
routingAlgo: "",
584+
mockSetup: func(mockCache *MockCache, _ *mockRouter) {
585+
mockCache.On("HasModel", "test-model").Return(true)
586+
podList := &utils.PodArray{
587+
Pods: []*v1.Pod{
588+
{
589+
Status: v1.PodStatus{
590+
PodIP: "1.2.3.4",
591+
Conditions: []v1.PodCondition{{Type: v1.PodReady, Status: v1.ConditionTrue}},
592+
},
593+
},
594+
{
595+
Status: v1.PodStatus{
596+
PodIP: "4.5.6.7",
597+
Conditions: []v1.PodCondition{{Type: v1.PodReady, Status: v1.ConditionTrue}},
598+
},
599+
},
600+
},
601+
}
602+
mockCache.On("ListPodsByModel", "test-model").Return(podList, nil)
603+
mockCache.On("AddRequestCount", mock.Anything, mock.Anything, "test-model").Return(int64(1))
604+
},
605+
expected: testResponse{
606+
statusCode: envoyTypePb.StatusCode_OK,
607+
model: "test-model",
608+
stream: true,
609+
routingCtx: &types.RoutingContext{RequestID: "test-request-id"},
610+
},
611+
validate: func(t *testing.T, tt *testCase, resp *extProcPb.ProcessingResponse, model string, routingCtx *types.RoutingContext, stream bool, term int64) {
612+
assert.Equal(t, tt.expected.statusCode, envoyTypePb.StatusCode_OK)
613+
assert.Equal(t, tt.expected.stream, stream)
614+
assert.NotNil(t, routingCtx)
615+
},
582616
},
583617
}
584618

@@ -645,6 +679,9 @@ func Test_handleRequestBody(t *testing.T) {
645679
// Call HandleRequestBody and validate the response
646680
routingCtx := types.NewRoutingContext(context.Background(), tt.routingAlgo, tt.expected.model, "", "test-request-id", tt.user.Name)
647681
routingCtx.ReqPath = "/v1/chat/completions"
682+
if tt.reqPath != "" {
683+
routingCtx.ReqPath = tt.reqPath
684+
}
648685
resp, model, routingCtx, stream, term := server.HandleRequestBody(
649686
routingCtx,
650687
"test-request-id",

pkg/plugins/gateway/util.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ import (
2727
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
2828
"github.com/openai/openai-go"
2929
"github.com/openai/openai-go/packages/param"
30-
"github.com/vllm-project/aibrix/pkg/utils"
3130
"k8s.io/klog/v2"
31+
32+
"github.com/vllm-project/aibrix/pkg/utils"
3233
)
3334

3435
const (
@@ -83,6 +84,7 @@ func validateRequestBody(requestID, requestPath string, requestBody []byte, user
8384
type Completion struct {
8485
Prompt string `json:"prompt"`
8586
Model string `json:"model"`
87+
Stream bool `json:"stream"`
8688
}
8789
completionObj := Completion{}
8890
err := json.Unmarshal(requestBody, &completionObj)
@@ -93,6 +95,7 @@ func validateRequestBody(requestID, requestPath string, requestBody []byte, user
9395
}
9496
model = completionObj.Model
9597
message = completionObj.Prompt
98+
stream = completionObj.Stream
9699
case "/v1/embeddings":
97100
embeddingObj := openai.EmbeddingNewParams{}
98101
if err := json.Unmarshal(requestBody, &embeddingObj); err != nil {

0 commit comments

Comments
 (0)