Skip to content

Commit f8ed045

Browse files
committed
Add golang shim for comms with EPP
Problem: In order for NGINX to get the endpoint of the AI workload from the EndpointPicker, it needs to send a gRPC request using the proper protobuf protocol. Solution: A simple Go server is injected as an additional container when the inference extension feature is enabled, that will listen for a request from our (upcoming) NJS module, and forward to the configured EPP to get a response in a header.
1 parent e9a3568 commit f8ed045

File tree

10 files changed

+559
-0
lines changed

10 files changed

+559
-0
lines changed

cmd/gateway/commands.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,20 @@ func createSleepCommand() *cobra.Command {
728728
return cmd
729729
}
730730

731+
func createEndpointPickerCommand() *cobra.Command {
732+
cmd := &cobra.Command{
733+
Use: "endpoint-picker",
734+
Short: "Shim server for communication between NGINX and the Gateway API Inference Extension Endpoint Picker",
735+
RunE: func(_ *cobra.Command, _ []string) error {
736+
logger := ctlrZap.New().WithName("endpoint-picker-shim")
737+
handler := createEndpointPickerHandler(realExtProcClientFactory(), logger)
738+
return endpointPickerServer(handler)
739+
},
740+
}
741+
742+
return cmd
743+
}
744+
731745
func parseFlags(flags *pflag.FlagSet) ([]string, []string) {
732746
var flagKeys, flagValues []string
733747

cmd/gateway/endpoint_picker.go

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
package main
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"io"
7+
"net"
8+
"net/http"
9+
"time"
10+
11+
corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
12+
extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
13+
"github.com/go-logr/logr"
14+
"google.golang.org/grpc"
15+
"google.golang.org/grpc/credentials/insecure"
16+
eppMetadata "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
17+
)
18+
19+
const (
	// defaultPort is the port this shim server listens on by default.
	// Why 54800? Sum the ASCII codes of "nginx" (110+103+105+110+120 = 548) and multiply by 100.
	// If collisions become a problem, this can be made configurable via the NginxProxy resource.
	defaultPort = 54800

	// eppEndpointHostHeader is the HTTP request header, set by the NJS module caller,
	// that carries the host of the EPP endpoint to contact.
	eppEndpointHostHeader = "X-EPP-Host"

	// eppEndpointPortHeader is the HTTP request header, set by the NJS module caller,
	// that carries the port of the EPP endpoint to contact.
	eppEndpointPortHeader = "X-EPP-Port"
)
28+
29+
// extProcClientFactory creates a new ExternalProcessorClient and returns a close function.
30+
type extProcClientFactory func(target string) (extprocv3.ExternalProcessorClient, func() error, error)
31+
32+
// endpointPickerServer starts an HTTP server on the given port with the provided handler.
33+
func endpointPickerServer(handler http.Handler) error {
34+
server := &http.Server{
35+
Addr: fmt.Sprintf(":%d", defaultPort),
36+
Handler: handler,
37+
ReadHeaderTimeout: 10 * time.Second,
38+
}
39+
return server.ListenAndServe()
40+
}
41+
42+
// realExtProcClientFactory returns a factory that creates a new gRPC connection and client per request.
43+
func realExtProcClientFactory() extProcClientFactory {
44+
return func(target string) (extprocv3.ExternalProcessorClient, func() error, error) {
45+
conn, err := grpc.NewClient(target, grpc.WithTransportCredentials(insecure.NewCredentials()))
46+
if err != nil {
47+
return nil, nil, err
48+
}
49+
client := extprocv3.NewExternalProcessorClient(conn)
50+
return client, conn.Close, nil
51+
}
52+
}
53+
54+
// createEndpointPickerHandler returns an http.Handler that forwards requests to the EndpointPicker.
55+
func createEndpointPickerHandler(factory extProcClientFactory, logger logr.Logger) http.Handler {
56+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
57+
host := r.Header.Get(eppEndpointHostHeader)
58+
port := r.Header.Get(eppEndpointPortHeader)
59+
if host == "" || port == "" {
60+
msg := fmt.Sprintf(
61+
"missing at least one of required headers: %s and %s",
62+
eppEndpointHostHeader,
63+
eppEndpointPortHeader,
64+
)
65+
logger.Error(errors.New(msg), "error contacting EndpointPicker")
66+
http.Error(w, msg, http.StatusBadRequest)
67+
return
68+
}
69+
70+
target := net.JoinHostPort(host, port)
71+
logger.Info("Getting inference workload endpoint from EndpointPicker", "endpointPicker", target)
72+
73+
client, closeConn, err := factory(target)
74+
if err != nil {
75+
logger.Error(err, "error creating gRPC client")
76+
http.Error(w, fmt.Sprintf("error creating gRPC client: %v", err), http.StatusInternalServerError)
77+
return
78+
}
79+
defer func() {
80+
if err := closeConn(); err != nil {
81+
logger.Error(err, "error closing gRPC connection")
82+
}
83+
}()
84+
85+
stream, err := client.Process(r.Context())
86+
if err != nil {
87+
logger.Error(err, "error opening ext_proc stream")
88+
http.Error(w, fmt.Sprintf("error opening ext_proc stream: %v", err), http.StatusBadGateway)
89+
return
90+
}
91+
92+
if code, err := sendRequest(stream, r); err != nil {
93+
logger.Error(err, "error sending request")
94+
http.Error(w, err.Error(), code)
95+
return
96+
}
97+
98+
// Receive response and extract header
99+
for {
100+
resp, err := stream.Recv()
101+
if errors.Is(err, io.EOF) {
102+
break // End of stream
103+
} else if err != nil {
104+
logger.Error(err, "error receiving from ext_proc")
105+
http.Error(w, fmt.Sprintf("error receiving from ext_proc: %v", err), http.StatusBadGateway)
106+
return
107+
}
108+
109+
if ir := resp.GetImmediateResponse(); ir != nil {
110+
code := int(ir.GetStatus().GetCode())
111+
body := ir.GetBody()
112+
logger.Error(fmt.Errorf("code: %d, body: %s", code, body), "received immediate response")
113+
http.Error(w, string(body), code)
114+
return
115+
}
116+
117+
headers := resp.GetRequestHeaders().GetResponse().GetHeaderMutation().GetSetHeaders()
118+
for _, h := range headers {
119+
if h.GetHeader().GetKey() == eppMetadata.DestinationEndpointKey {
120+
endpoint := string(h.GetHeader().GetRawValue())
121+
w.Header().Set(h.GetHeader().GetKey(), endpoint)
122+
logger.Info("Found endpoint", "endpoint", endpoint)
123+
}
124+
}
125+
}
126+
w.WriteHeader(http.StatusOK)
127+
})
128+
}
129+
130+
func sendRequest(stream extprocv3.ExternalProcessor_ProcessClient, r *http.Request) (int, error) {
131+
if err := stream.Send(buildHeaderRequest(r)); err != nil {
132+
return http.StatusBadGateway, fmt.Errorf("error sending headers: %w", err)
133+
}
134+
135+
bodyReq, err := buildBodyRequest(r)
136+
if err != nil {
137+
return http.StatusInternalServerError, fmt.Errorf("error building body request: %w", err)
138+
}
139+
140+
if err := stream.Send(bodyReq); err != nil {
141+
return http.StatusBadGateway, fmt.Errorf("error sending body: %w", err)
142+
}
143+
144+
if err := stream.CloseSend(); err != nil {
145+
return http.StatusInternalServerError, fmt.Errorf("error closing stream: %w", err)
146+
}
147+
148+
return 0, nil
149+
}
150+
151+
func buildHeaderRequest(r *http.Request) *extprocv3.ProcessingRequest {
152+
headerList := make([]*corev3.HeaderValue, 0, len(r.Header))
153+
headerMap := &corev3.HeaderMap{
154+
Headers: headerList,
155+
}
156+
157+
for key, values := range r.Header {
158+
for _, value := range values {
159+
headerMap.Headers = append(headerMap.Headers, &corev3.HeaderValue{
160+
Key: key,
161+
Value: value,
162+
})
163+
}
164+
}
165+
166+
return &extprocv3.ProcessingRequest{
167+
Request: &extprocv3.ProcessingRequest_RequestHeaders{
168+
RequestHeaders: &extprocv3.HttpHeaders{
169+
Headers: headerMap,
170+
EndOfStream: false,
171+
},
172+
},
173+
}
174+
}
175+
176+
func buildBodyRequest(r *http.Request) (*extprocv3.ProcessingRequest, error) {
177+
body, err := io.ReadAll(r.Body)
178+
if err != nil {
179+
return nil, fmt.Errorf("error reading request body: %w", err)
180+
}
181+
182+
return &extprocv3.ProcessingRequest{
183+
Request: &extprocv3.ProcessingRequest_RequestBody{
184+
RequestBody: &extprocv3.HttpBody{
185+
Body: body,
186+
EndOfStream: true,
187+
},
188+
},
189+
}, nil
190+
}

0 commit comments

Comments
 (0)