Skip to content

Commit 6995f2f

Browse files
authored
Query EPP and proxy AI traffic (#3942)
Problem: We need to connect NGINX to the Go shim that talks to the EndpointPicker, and then pass client traffic to the proper inference workload. Solution: Write an NJS module that queries the local Go server to get the AI endpoint to route traffic to, then redirects the original client request to an internal location that proxies the traffic to the chosen endpoint. Building the locations gets a bit complicated, especially when using both HTTP matching conditions and inference workloads: that case requires two layers of internal redirects. I added lots of comments to clarify how we build these locations to perform all the routing steps.
1 parent 183dc72 commit 6995f2f

File tree

21 files changed

+1079
-227
lines changed

21 files changed

+1079
-227
lines changed

cmd/gateway/endpoint_picker.go

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,8 @@ import (
1414
"google.golang.org/grpc"
1515
"google.golang.org/grpc/credentials/insecure"
1616
eppMetadata "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
17-
)
1817

19-
const (
20-
// defaultPort is the default port for this server to listen on. If collisions become a problem,
21-
// we can make this configurable via the NginxProxy resource.
22-
defaultPort = 54800 // why 54800? Sum "nginx" in ASCII and multiply by 100.
23-
// eppEndpointHostHeader is the HTTP header used to specify the EPP endpoint host, set by the NJS module caller.
24-
eppEndpointHostHeader = "X-EPP-Host"
25-
// eppEndpointPortHeader is the HTTP header used to specify the EPP endpoint port, set by the NJS module caller.
26-
eppEndpointPortHeader = "X-EPP-Port"
18+
"github.com/nginx/nginx-gateway-fabric/v2/internal/framework/types"
2719
)
2820

2921
// extProcClientFactory creates a new ExternalProcessorClient and returns a close function.
@@ -32,7 +24,7 @@ type extProcClientFactory func(target string) (extprocv3.ExternalProcessorClient
3224
// endpointPickerServer starts an HTTP server on the given port with the provided handler.
3325
func endpointPickerServer(handler http.Handler) error {
3426
server := &http.Server{
35-
Addr: fmt.Sprintf("127.0.0.1:%d", defaultPort),
27+
Addr: fmt.Sprintf("127.0.0.1:%d", types.GoShimPort),
3628
Handler: handler,
3729
ReadHeaderTimeout: 10 * time.Second,
3830
}
@@ -54,13 +46,13 @@ func realExtProcClientFactory() extProcClientFactory {
5446
// createEndpointPickerHandler returns an http.Handler that forwards requests to the EndpointPicker.
5547
func createEndpointPickerHandler(factory extProcClientFactory, logger logr.Logger) http.Handler {
5648
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
57-
host := r.Header.Get(eppEndpointHostHeader)
58-
port := r.Header.Get(eppEndpointPortHeader)
49+
host := r.Header.Get(types.EPPEndpointHostHeader)
50+
port := r.Header.Get(types.EPPEndpointPortHeader)
5951
if host == "" || port == "" {
6052
msg := fmt.Sprintf(
6153
"missing at least one of required headers: %s and %s",
62-
eppEndpointHostHeader,
63-
eppEndpointPortHeader,
54+
types.EPPEndpointHostHeader,
55+
types.EPPEndpointPortHeader,
6456
)
6557
logger.Error(errors.New(msg), "error contacting EndpointPicker")
6658
http.Error(w, msg, http.StatusBadRequest)
@@ -174,6 +166,10 @@ func buildHeaderRequest(r *http.Request) *extprocv3.ProcessingRequest {
174166
}
175167

176168
func buildBodyRequest(r *http.Request) (*extprocv3.ProcessingRequest, error) {
169+
if r.ContentLength == 0 {
170+
return nil, errors.New("request body is empty")
171+
}
172+
177173
body, err := io.ReadAll(r.Body)
178174
if err != nil {
179175
return nil, fmt.Errorf("error reading request body: %w", err)

cmd/gateway/endpoint_picker_test.go

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import (
1717
"google.golang.org/grpc"
1818
"google.golang.org/grpc/metadata"
1919
eppMetadata "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
20+
21+
"github.com/nginx/nginx-gateway-fabric/v2/internal/framework/types"
2022
)
2123

2224
type mockExtProcClient struct {
@@ -122,8 +124,8 @@ func TestEndpointPickerHandler_Success(t *testing.T) {
122124

123125
h := createEndpointPickerHandler(factory, logr.Discard())
124126
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test body"))
125-
req.Header.Set(eppEndpointHostHeader, "test-host")
126-
req.Header.Set(eppEndpointPortHeader, "1234")
127+
req.Header.Set(types.EPPEndpointHostHeader, "test-host")
128+
req.Header.Set(types.EPPEndpointPortHeader, "1234")
127129
req.Header.Set("Content-Type", "application/json")
128130
w := httptest.NewRecorder()
129131

@@ -165,8 +167,8 @@ func TestEndpointPickerHandler_ImmediateResponse(t *testing.T) {
165167

166168
h := createEndpointPickerHandler(factory, logr.Discard())
167169
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test body"))
168-
req.Header.Set(eppEndpointHostHeader, "test-host")
169-
req.Header.Set(eppEndpointPortHeader, "1234")
170+
req.Header.Set(types.EPPEndpointHostHeader, "test-host")
171+
req.Header.Set(types.EPPEndpointPortHeader, "1234")
170172
w := httptest.NewRecorder()
171173

172174
h.ServeHTTP(w, req)
@@ -190,8 +192,8 @@ func TestEndpointPickerHandler_Errors(t *testing.T) {
190192
h := createEndpointPickerHandler(factory, logr.Discard())
191193
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test body"))
192194
if setHeaders {
193-
req.Header.Set(eppEndpointHostHeader, "test-host")
194-
req.Header.Set(eppEndpointPortHeader, "1234")
195+
req.Header.Set(types.EPPEndpointHostHeader, "test-host")
196+
req.Header.Set(types.EPPEndpointPortHeader, "1234")
195197
}
196198
w := httptest.NewRecorder()
197199
h.ServeHTTP(w, req)
@@ -236,7 +238,33 @@ func TestEndpointPickerHandler_Errors(t *testing.T) {
236238
}
237239
runErrorTestCase(factory, true, http.StatusBadGateway, "error sending headers")
238240

239-
// 4. Error sending body
241+
// 4a. Error building body request (content length 0)
242+
client = &mockProcessClient{
243+
SendFunc: func(*extprocv3.ProcessingRequest) error {
244+
return nil
245+
},
246+
RecvFunc: func() (*extprocv3.ProcessingResponse, error) { return nil, io.EOF },
247+
}
248+
extProcClient = &mockExtProcClient{
249+
ProcessFunc: func(context.Context, ...grpc.CallOption) (extprocv3.ExternalProcessor_ProcessClient, error) {
250+
return client, nil
251+
},
252+
}
253+
factory = func(string) (extprocv3.ExternalProcessorClient, func() error, error) {
254+
return extProcClient, func() error { return nil }, nil
255+
}
256+
h := createEndpointPickerHandler(factory, logr.Discard())
257+
req := httptest.NewRequest(http.MethodPost, "/", nil) // nil body, ContentLength = 0
258+
req.Header.Set(types.EPPEndpointHostHeader, "test-host")
259+
req.Header.Set(types.EPPEndpointPortHeader, "1234")
260+
w := httptest.NewRecorder()
261+
h.ServeHTTP(w, req)
262+
resp := w.Result()
263+
g.Expect(resp.StatusCode).To(Equal(http.StatusInternalServerError))
264+
body, _ := io.ReadAll(resp.Body)
265+
g.Expect(string(body)).To(ContainSubstring("request body is empty"))
266+
267+
// 4b. Error sending body
240268
client = &mockProcessClient{
241269
SendFunc: func(req *extprocv3.ProcessingRequest) error {
242270
if req.GetRequestBody() != nil {

deploy/inference-nginx-plus/deploy.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ spec:
281281
- --nginx-docker-secret=nginx-plus-registry-secret
282282
- --nginx-plus
283283
- --usage-report-secret=nplus-license
284+
- --usage-report-enforce-initial-report=true
284285
- --metrics-port=9113
285286
- --health-port=8081
286287
- --leader-election-lock-name=nginx-gateway-leader-election

internal/controller/nginx/config/http/config.go

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,58 @@ type Server struct {
2626
type LocationType string
2727

2828
const (
29+
// InternalLocationType defines an internal location that is only accessible within NGINX.
2930
InternalLocationType LocationType = "internal"
31+
// ExternalLocationType defines a normal external location that is accessible by clients.
3032
ExternalLocationType LocationType = "external"
33+
// RedirectLocationType defines an external location that redirects to an internal location
34+
// based on HTTP matching conditions.
3135
RedirectLocationType LocationType = "redirect"
36+
// InferenceExternalLocationType defines an external location that is used for calling NJS
37+
// to get the inference workload endpoint and redirects to the internal location that will proxy_pass
38+
// to that endpoint.
39+
InferenceExternalLocationType LocationType = "inference-external"
40+
// InferenceInternalLocationType defines an internal location that is used for calling NJS
41+
// to get the inference workload endpoint and redirects to the internal location that will proxy_pass
42+
// to that endpoint. This is used when an HTTP redirect location is also defined that redirects
43+
// to this internal inference location.
44+
InferenceInternalLocationType LocationType = "inference-internal"
3245
)
3346

3447
// Location holds all configuration for an HTTP location.
3548
type Location struct {
36-
Path string
37-
ProxyPass string
38-
HTTPMatchKey string
49+
// Return specifies a return directive (e.g., HTTP status or redirect) for this location block.
50+
Return *Return
51+
// ProxySSLVerify controls SSL verification for upstreams when proxying requests.
52+
ProxySSLVerify *ProxySSLVerify
53+
// ProxyPass is the upstream backend (URL or name) to which requests are proxied.
54+
ProxyPass string
55+
// HTTPMatchKey is the key for associating HTTP match rules, used for routing and NJS module logic.
56+
HTTPMatchKey string
57+
// MirrorSplitClientsVariableName is the variable name for split_clients, used in traffic mirroring scenarios.
3958
MirrorSplitClientsVariableName string
40-
Type LocationType
41-
ProxySetHeaders []Header
42-
ProxySSLVerify *ProxySSLVerify
43-
Return *Return
44-
ResponseHeaders ResponseHeaders
45-
Rewrites []string
46-
MirrorPaths []string
47-
Includes []shared.Include
48-
GRPC bool
59+
// EPPInternalPath is the internal path for the inference NJS module to redirect to.
60+
EPPInternalPath string
61+
// EPPHost is the host for the EndpointPicker, used for inference routing.
62+
EPPHost string
63+
// Type indicates the type of location (external, internal, redirect, etc).
64+
Type LocationType
65+
// Path is the NGINX location path.
66+
Path string
67+
// ResponseHeaders are custom response headers to be sent.
68+
ResponseHeaders ResponseHeaders
69+
// ProxySetHeaders are headers to set when proxying requests upstream.
70+
ProxySetHeaders []Header
71+
// Rewrites are rewrite rules for modifying request paths.
72+
Rewrites []string
73+
// MirrorPaths are paths to which requests are mirrored.
74+
MirrorPaths []string
75+
// Includes are additional NGINX config snippets or policies to include in this location.
76+
Includes []shared.Include
77+
// EPPPort is the port for the EndpointPicker, used for inference routing.
78+
EPPPort int
79+
// GRPC indicates if this location proxies gRPC traffic.
80+
GRPC bool
4981
}
5082

5183
// Header defines an HTTP header to be passed to the proxied server.

internal/controller/nginx/config/maps.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package config
22

33
import (
4+
"fmt"
45
"strings"
56
gotemplate "text/template"
67

8+
inference "sigs.k8s.io/gateway-api-inference-extension/api/v1"
9+
710
"github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/shared"
811
"github.com/nginx/nginx-gateway-fabric/v2/internal/controller/state/dataplane"
912
"github.com/nginx/nginx-gateway-fabric/v2/internal/framework/helpers"
@@ -26,6 +29,8 @@ const (
2629

2730
func executeMaps(conf dataplane.Configuration) []executeResult {
2831
maps := buildAddHeaderMaps(append(conf.HTTPServers, conf.SSLServers...))
32+
maps = append(maps, buildInferenceMaps(conf.BackendGroups)...)
33+
2934
result := executeResult{
3035
dest: httpConfigFile,
3136
data: helpers.MustExecuteTemplate(mapsTemplate, maps),
@@ -177,3 +182,42 @@ func createAddHeadersMap(name string) shared.Map {
177182
Parameters: params,
178183
}
179184
}
185+
186+
// buildInferenceMaps creates maps for InferencePool Backends.
187+
func buildInferenceMaps(groups []dataplane.BackendGroup) []shared.Map {
188+
inferenceMaps := make([]shared.Map, 0, len(groups))
189+
for _, group := range groups {
190+
for _, backend := range group.Backends {
191+
if backend.EndpointPickerConfig != nil {
192+
var defaultResult string
193+
switch backend.EndpointPickerConfig.FailureMode {
194+
// in FailClose mode, if the EPP is unavailable or returns an error,
195+
// we return an invalid backend to ensure the request fails
196+
case inference.EndpointPickerFailClose:
197+
defaultResult = invalidBackendRef
198+
// in FailOpen mode, if the EPP is unavailable or returns an error,
199+
// we fall back to the upstream
200+
case inference.EndpointPickerFailOpen:
201+
defaultResult = backend.UpstreamName
202+
}
203+
params := []shared.MapParameter{
204+
{
205+
Value: "~.+",
206+
Result: "$inference_workload_endpoint",
207+
},
208+
{
209+
Value: "default",
210+
Result: defaultResult,
211+
},
212+
}
213+
backendVarName := strings.ReplaceAll(backend.UpstreamName, "-", "_")
214+
inferenceMaps = append(inferenceMaps, shared.Map{
215+
Source: "$inference_workload_endpoint",
216+
Variable: fmt.Sprintf("$inference_backend_%s", backendVarName),
217+
Parameters: params,
218+
})
219+
}
220+
}
221+
}
222+
return inferenceMaps
223+
}

internal/controller/nginx/config/maps_test.go

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"testing"
66

77
. "github.com/onsi/gomega"
8+
inference "sigs.k8s.io/gateway-api-inference-extension/api/v1"
89

910
"github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/shared"
1011
"github.com/nginx/nginx-gateway-fabric/v2/internal/controller/state/dataplane"
@@ -59,22 +60,24 @@ func TestExecuteMaps(t *testing.T) {
5960

6061
conf := dataplane.Configuration{
6162
HTTPServers: []dataplane.VirtualServer{
62-
{
63-
PathRules: pathRules,
64-
},
65-
{
66-
PathRules: pathRules,
67-
},
68-
{
69-
IsDefault: true,
70-
},
63+
{PathRules: pathRules},
64+
{PathRules: pathRules},
65+
{IsDefault: true},
7166
},
7267
SSLServers: []dataplane.VirtualServer{
68+
{PathRules: pathRules},
69+
{IsDefault: true},
70+
},
71+
BackendGroups: []dataplane.BackendGroup{
7372
{
74-
PathRules: pathRules,
75-
},
76-
{
77-
IsDefault: true,
73+
Backends: []dataplane.Backend{
74+
{
75+
UpstreamName: "upstream1",
76+
EndpointPickerConfig: &inference.EndpointPickerRef{
77+
FailureMode: inference.EndpointPickerFailClose,
78+
},
79+
},
80+
},
7881
},
7982
},
8083
}
@@ -86,6 +89,9 @@ func TestExecuteMaps(t *testing.T) {
8689
"map ${http_my_second_add_header} $my_second_add_header_header_var {": 1,
8790
"~.* ${http_my_second_add_header},;": 1,
8891
"map ${http_my_set_header} $my_set_header_header_var {": 0,
92+
"$inference_workload_endpoint": 2,
93+
"$inference_backend": 1,
94+
"invalid-backend-ref": 1,
8995
}
9096

9197
mapResult := executeMaps(conf)
@@ -385,3 +391,36 @@ func TestCreateStreamMapsWithEmpty(t *testing.T) {
385391

386392
g.Expect(maps).To(BeNil())
387393
}
394+
395+
func TestBuildInferenceMaps(t *testing.T) {
396+
t.Parallel()
397+
g := NewWithT(t)
398+
399+
group := dataplane.BackendGroup{
400+
Backends: []dataplane.Backend{
401+
{
402+
UpstreamName: "upstream1",
403+
EndpointPickerConfig: &inference.EndpointPickerRef{
404+
FailureMode: inference.EndpointPickerFailClose,
405+
},
406+
},
407+
{
408+
UpstreamName: "upstream2",
409+
EndpointPickerConfig: &inference.EndpointPickerRef{
410+
FailureMode: inference.EndpointPickerFailOpen,
411+
},
412+
},
413+
{
414+
UpstreamName: "upstream3",
415+
EndpointPickerConfig: nil,
416+
},
417+
},
418+
}
419+
420+
maps := buildInferenceMaps([]dataplane.BackendGroup{group})
421+
g.Expect(maps).To(HaveLen(2))
422+
g.Expect(maps[0].Source).To(Equal("$inference_workload_endpoint"))
423+
g.Expect(maps[0].Variable).To(Equal("$inference_backend_upstream1"))
424+
g.Expect(maps[0].Parameters[1].Result).To(Equal("invalid-backend-ref"))
425+
g.Expect(maps[1].Parameters[1].Result).To(Equal("upstream2"))
426+
}

0 commit comments

Comments
 (0)