Skip to content

Commit 5ff1e27

Browse files
authored
Server unit test and utility to help with such tests (#820)
Signed-off-by: Ira <[email protected]>
1 parent 7ca36bf commit 5ff1e27

File tree

2 files changed

+382
-0
lines changed

2 files changed

+382
-0
lines changed

pkg/epp/server/server_test.go

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package server
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"testing"
23+
24+
pb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
25+
v1 "k8s.io/api/core/v1"
26+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
27+
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
28+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
29+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
30+
testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
31+
"sigs.k8s.io/gateway-api-inference-extension/test/utils"
32+
)
33+
34+
const (
35+
bufSize = 1024 * 1024
36+
podName = "pod1"
37+
podAddress = "1.2.3.4"
38+
poolPort = int32(5678)
39+
destinationEndpointHintKey = "test-target"
40+
namespace = "ns1"
41+
)
42+
43+
func TestServer(t *testing.T) {
44+
theHeaderValue := "body"
45+
requestHeader := "x-test"
46+
47+
expectedRequestHeaders := map[string]string{destinationEndpointHintKey: fmt.Sprintf("%s:%d", podAddress, poolPort),
48+
"Content-Length": "42", ":method": "POST", requestHeader: theHeaderValue}
49+
expectedResponseHeaders := map[string]string{"x-went-into-resp-headers": "true", ":method": "POST", requestHeader: theHeaderValue}
50+
expectedSchedulerHeaders := map[string]string{":method": "POST", requestHeader: theHeaderValue}
51+
52+
t.Run("server", func(t *testing.T) {
53+
tsModel := "food-review"
54+
model := testutil.MakeInferenceModel("v1").
55+
CreationTimestamp(metav1.Unix(1000, 0)).
56+
ModelName(tsModel).ObjRef()
57+
58+
director := &testDirector{}
59+
ctx, cancel, ds, _ := utils.PrepareForTestStreamingServer([]*v1alpha2.InferenceModel{model},
60+
[]*v1.Pod{{ObjectMeta: metav1.ObjectMeta{Name: podName}}}, "test-pool1", namespace, poolPort)
61+
62+
streamingServer := handlers.NewStreamingServer(namespace, destinationEndpointHintKey, ds, director)
63+
64+
testListener, errChan := utils.SetupTestStreamingServer(t, ctx, ds, streamingServer)
65+
process, conn := utils.GetStreamingServerClient(ctx, t)
66+
defer conn.Close()
67+
68+
// Send request headers - no response expected
69+
headers := utils.BuildEnvoyGRPCHeaders(map[string]string{requestHeader: theHeaderValue, ":method": "POST"}, true)
70+
request := &pb.ProcessingRequest{
71+
Request: &pb.ProcessingRequest_RequestHeaders{
72+
RequestHeaders: headers,
73+
},
74+
}
75+
err := process.Send(request)
76+
if err != nil {
77+
t.Error("Error sending request headers", err)
78+
}
79+
80+
// Send request body
81+
requestBody := "{\"model\":\"food-review\",\"prompt\":\"Is banana tasty?\"}"
82+
expectedBody := "{\"model\":\"v1\",\"prompt\":\"Is banana tasty?\"}"
83+
request = &pb.ProcessingRequest{
84+
Request: &pb.ProcessingRequest_RequestBody{
85+
RequestBody: &pb.HttpBody{
86+
Body: []byte(requestBody),
87+
EndOfStream: true,
88+
},
89+
},
90+
}
91+
err = process.Send(request)
92+
if err != nil {
93+
t.Error("Error sending request body", err)
94+
}
95+
96+
// Receive request headers and check
97+
responseReqHeaders, err := process.Recv()
98+
if err != nil {
99+
t.Error("Error receiving response", err)
100+
} else {
101+
if responseReqHeaders == nil || responseReqHeaders.GetRequestHeaders() == nil ||
102+
responseReqHeaders.GetRequestHeaders().Response == nil ||
103+
responseReqHeaders.GetRequestHeaders().Response.HeaderMutation == nil ||
104+
responseReqHeaders.GetRequestHeaders().Response.HeaderMutation.SetHeaders == nil {
105+
t.Error("Invalid request headers response")
106+
} else if !utils.CheckEnvoyGRPCHeaders(t, responseReqHeaders.GetRequestHeaders().Response, expectedRequestHeaders) {
107+
t.Error("Incorrect request headers")
108+
}
109+
}
110+
111+
// Receive request body and check
112+
responseReqBody, err := process.Recv()
113+
if err != nil {
114+
t.Error("Error receiving response", err)
115+
} else {
116+
if responseReqBody == nil || responseReqBody.GetRequestBody() == nil ||
117+
responseReqBody.GetRequestBody().Response == nil ||
118+
responseReqBody.GetRequestBody().Response.BodyMutation == nil ||
119+
responseReqBody.GetRequestBody().Response.BodyMutation.GetStreamedResponse() == nil {
120+
t.Error("Invalid request body response")
121+
} else {
122+
body := responseReqBody.GetRequestBody().Response.BodyMutation.GetStreamedResponse().Body
123+
if string(body) != expectedBody {
124+
t.Errorf("Incorrect body %s expected %s", string(body), expectedBody)
125+
}
126+
}
127+
}
128+
129+
// Check headers passed to the scheduler
130+
if len(director.requestHeaders) != 2 {
131+
t.Errorf("Incorrect number of request headers %d instead of 2", len(director.requestHeaders))
132+
}
133+
for expectedKey, expectedValue := range expectedSchedulerHeaders {
134+
got, ok := director.requestHeaders[expectedKey]
135+
if !ok {
136+
t.Errorf("Missing header %s", expectedKey)
137+
} else if got != expectedValue {
138+
t.Errorf("Incorrect value for header %s, want %s got %s", expectedKey, expectedValue, got)
139+
}
140+
}
141+
142+
// Send response headers
143+
headers = utils.BuildEnvoyGRPCHeaders(map[string]string{requestHeader: theHeaderValue, ":method": "POST"}, false)
144+
request = &pb.ProcessingRequest{
145+
Request: &pb.ProcessingRequest_ResponseHeaders{
146+
ResponseHeaders: headers,
147+
},
148+
}
149+
err = process.Send(request)
150+
if err != nil {
151+
t.Error("Error sending response", err)
152+
}
153+
154+
// Receive response headers and check
155+
response, err := process.Recv()
156+
if err != nil {
157+
t.Error("Error receiving response", err)
158+
} else {
159+
if response == nil || response.GetResponseHeaders() == nil || response.GetResponseHeaders().Response == nil ||
160+
response.GetResponseHeaders().Response.HeaderMutation == nil ||
161+
response.GetResponseHeaders().Response.HeaderMutation.SetHeaders == nil {
162+
t.Error("Invalid response")
163+
} else if !utils.CheckEnvoyGRPCHeaders(t, response.GetResponseHeaders().Response, expectedResponseHeaders) {
164+
t.Error("Incorrect response headers")
165+
}
166+
}
167+
168+
cancel()
169+
<-errChan
170+
testListener.Close()
171+
})
172+
}
173+
174+
type testDirector struct {
175+
requestHeaders map[string]string
176+
}
177+
178+
func (ts *testDirector) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
179+
ts.requestHeaders = reqCtx.Request.Headers
180+
181+
reqCtx.Request.Body["model"] = "v1"
182+
reqCtx.TargetEndpoint = fmt.Sprintf("%s:%d", podAddress, poolPort)
183+
return reqCtx, nil
184+
}
185+
186+
func (ts *testDirector) HandleResponse(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
187+
return reqCtx, nil
188+
}
189+
190+
func (ts *testDirector) GetRandomPod() *backend.Pod {
191+
return nil
192+
}

test/utils/server.go

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package utils
18+
19+
import (
20+
"context"
21+
"net"
22+
"testing"
23+
"time"
24+
25+
corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
26+
pb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
27+
"google.golang.org/grpc"
28+
"google.golang.org/grpc/credentials/insecure"
29+
"google.golang.org/grpc/test/bufconn"
30+
v1 "k8s.io/api/core/v1"
31+
"k8s.io/apimachinery/pkg/runtime"
32+
clientgoscheme "k8s.io/client-go/kubernetes/scheme"
33+
"sigs.k8s.io/controller-runtime/pkg/client"
34+
"sigs.k8s.io/controller-runtime/pkg/client/fake"
35+
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
36+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
37+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
38+
testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
39+
)
40+
41+
const bufSize = 1024 * 1024
42+
43+
var testListener *bufconn.Listener
44+
45+
func PrepareForTestStreamingServer(models []*v1alpha2.InferenceModel, pods []*v1.Pod, poolName string, namespace string,
46+
poolPort int32) (context.Context, context.CancelFunc, datastore.Datastore, *metrics.FakePodMetricsClient) {
47+
ctx, cancel := context.WithCancel(context.Background())
48+
49+
pmc := &metrics.FakePodMetricsClient{}
50+
pmf := metrics.NewPodMetricsFactory(pmc, time.Second)
51+
ds := datastore.NewDatastore(ctx, pmf)
52+
53+
initObjs := []client.Object{}
54+
for _, model := range models {
55+
initObjs = append(initObjs, model)
56+
ds.ModelSetIfOlder(model)
57+
}
58+
for _, pod := range pods {
59+
initObjs = append(initObjs, pod)
60+
ds.PodUpdateOrAddIfNotExist(pod)
61+
}
62+
63+
scheme := runtime.NewScheme()
64+
_ = clientgoscheme.AddToScheme(scheme)
65+
_ = v1alpha2.Install(scheme)
66+
fakeClient := fake.NewClientBuilder().
67+
WithScheme(scheme).
68+
WithObjects(initObjs...).
69+
Build()
70+
pool := testutil.MakeInferencePool(poolName).Namespace(namespace).ObjRef()
71+
pool.Spec.TargetPortNumber = poolPort
72+
_ = ds.PoolSet(context.Background(), fakeClient, pool)
73+
74+
return ctx, cancel, ds, pmc
75+
}
76+
77+
func SetupTestStreamingServer(t *testing.T, ctx context.Context, ds datastore.Datastore,
78+
streamingServer pb.ExternalProcessorServer) (*bufconn.Listener, chan error) {
79+
testListener = bufconn.Listen(bufSize)
80+
81+
errChan := make(chan error)
82+
go func() {
83+
err := LaunchTestGRPCServer(streamingServer, ctx, testListener)
84+
if err != nil {
85+
t.Error("Error launching listener", err)
86+
}
87+
errChan <- err
88+
}()
89+
90+
time.Sleep(2 * time.Second)
91+
return testListener, errChan
92+
}
93+
94+
func testDialer(context.Context, string) (net.Conn, error) {
95+
return testListener.Dial()
96+
}
97+
98+
func GetStreamingServerClient(ctx context.Context, t *testing.T) (pb.ExternalProcessor_ProcessClient, *grpc.ClientConn) {
99+
opts := []grpc.DialOption{
100+
grpc.WithTransportCredentials(insecure.NewCredentials()),
101+
grpc.WithContextDialer(testDialer),
102+
}
103+
conn, err := grpc.NewClient("passthrough://bufconn", opts...)
104+
if err != nil {
105+
t.Error(err)
106+
return nil, nil
107+
}
108+
109+
extProcClient := pb.NewExternalProcessorClient(conn)
110+
process, err := extProcClient.Process(ctx)
111+
if err != nil {
112+
t.Error(err)
113+
return nil, nil
114+
}
115+
116+
return process, conn
117+
}
118+
119+
// LaunchTestGRPCServer actually starts the server (enables testing)
120+
func LaunchTestGRPCServer(s pb.ExternalProcessorServer, ctx context.Context, listener net.Listener) error {
121+
grpcServer := grpc.NewServer()
122+
123+
pb.RegisterExternalProcessorServer(grpcServer, s)
124+
125+
// Shutdown on context closed.
126+
// Terminate the server on context closed.
127+
go func() {
128+
<-ctx.Done()
129+
grpcServer.GracefulStop()
130+
}()
131+
132+
if err := grpcServer.Serve(listener); err != nil {
133+
return err
134+
}
135+
136+
return nil
137+
}
138+
139+
func CheckEnvoyGRPCHeaders(t *testing.T, response *pb.CommonResponse, expectedHeaders map[string]string) bool {
140+
headers := response.HeaderMutation.SetHeaders
141+
for expectedKey, expectedValue := range expectedHeaders {
142+
found := false
143+
for _, header := range headers {
144+
if header.Header.Key == expectedKey {
145+
if expectedValue != string(header.Header.RawValue) {
146+
t.Errorf("Incorrect value for header %s, want %s got %s", expectedKey, expectedValue,
147+
string(header.Header.RawValue))
148+
return false
149+
}
150+
found = true
151+
break
152+
}
153+
}
154+
if !found {
155+
t.Errorf("Missing header %s", expectedKey)
156+
return false
157+
}
158+
}
159+
160+
for _, header := range headers {
161+
expectedValue, ok := expectedHeaders[header.Header.Key]
162+
if !ok {
163+
t.Errorf("Unexpected header %s", header.Header.Key)
164+
return false
165+
} else if expectedValue != string(header.Header.RawValue) {
166+
t.Errorf("Incorrect value for header %s, want %s got %s", header.Header.Key, expectedValue,
167+
string(header.Header.RawValue))
168+
return false
169+
}
170+
}
171+
return true
172+
}
173+
174+
func BuildEnvoyGRPCHeaders(headers map[string]string, rawValue bool) *pb.HttpHeaders {
175+
headerValues := make([]*corev3.HeaderValue, 0)
176+
for key, value := range headers {
177+
header := &corev3.HeaderValue{Key: key}
178+
if rawValue {
179+
header.RawValue = []byte(value)
180+
} else {
181+
header.Value = value
182+
}
183+
headerValues = append(headerValues, header)
184+
}
185+
return &pb.HttpHeaders{
186+
Headers: &corev3.HeaderMap{
187+
Headers: headerValues,
188+
},
189+
}
190+
}

0 commit comments

Comments
 (0)