diff --git a/cmd/gateway/commands.go b/cmd/gateway/commands.go index ce2eff9df6..834355b270 100644 --- a/cmd/gateway/commands.go +++ b/cmd/gateway/commands.go @@ -728,6 +728,20 @@ func createSleepCommand() *cobra.Command { return cmd } +func createEndpointPickerCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "endpoint-picker", + Short: "Shim server for communication between NGINX and the Gateway API Inference Extension Endpoint Picker", + RunE: func(_ *cobra.Command, _ []string) error { + logger := ctlrZap.New().WithName("endpoint-picker-shim") + handler := createEndpointPickerHandler(realExtProcClientFactory(), logger) + return endpointPickerServer(handler) + }, + } + + return cmd +} + func parseFlags(flags *pflag.FlagSet) ([]string, []string) { var flagKeys, flagValues []string diff --git a/cmd/gateway/endpoint_picker.go b/cmd/gateway/endpoint_picker.go new file mode 100644 index 0000000000..7c67a83671 --- /dev/null +++ b/cmd/gateway/endpoint_picker.go @@ -0,0 +1,190 @@ +package main + +import ( + "errors" + "fmt" + "io" + "net" + "net/http" + "time" + + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "github.com/go-logr/logr" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + eppMetadata "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata" +) + +const ( + // defaultPort is the default port for this server to listen on. If collisions become a problem, + // we can make this configurable via the NginxProxy resource. + defaultPort = 54800 // why 54800? Sum "nginx" in ASCII and multiply by 100. + // eppEndpointHostHeader is the HTTP header used to specify the EPP endpoint host, set by the NJS module caller. + eppEndpointHostHeader = "X-EPP-Host" + // eppEndpointPortHeader is the HTTP header used to specify the EPP endpoint port, set by the NJS module caller. + eppEndpointPortHeader = "X-EPP-Port" +) + +// extProcClientFactory creates a new ExternalProcessorClient and returns a close function. +type extProcClientFactory func(target string) (extprocv3.ExternalProcessorClient, func() error, error) + +// endpointPickerServer starts an HTTP server on the given port with the provided handler. +func endpointPickerServer(handler http.Handler) error { + server := &http.Server{ + Addr: fmt.Sprintf("127.0.0.1:%d", defaultPort), + Handler: handler, + ReadHeaderTimeout: 10 * time.Second, + } + return server.ListenAndServe() +} + +// realExtProcClientFactory returns a factory that creates a new gRPC connection and client per request. +func realExtProcClientFactory() extProcClientFactory { + return func(target string) (extprocv3.ExternalProcessorClient, func() error, error) { + conn, err := grpc.NewClient(target, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, nil, err + } + client := extprocv3.NewExternalProcessorClient(conn) + return client, conn.Close, nil + } +} + +// createEndpointPickerHandler returns an http.Handler that forwards requests to the EndpointPicker. +func createEndpointPickerHandler(factory extProcClientFactory, logger logr.Logger) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host := r.Header.Get(eppEndpointHostHeader) + port := r.Header.Get(eppEndpointPortHeader) + if host == "" || port == "" { + msg := fmt.Sprintf( + "missing at least one of required headers: %s and %s", + eppEndpointHostHeader, + eppEndpointPortHeader, + ) + logger.Error(errors.New(msg), "error contacting EndpointPicker") + http.Error(w, msg, http.StatusBadRequest) + return + } + + target := net.JoinHostPort(host, port) + logger.Info("Getting inference workload endpoint from EndpointPicker", "endpointPicker", target) + + client, closeConn, err := factory(target) + if err != nil { + logger.Error(err, "error creating gRPC client") + http.Error(w, fmt.Sprintf("error creating gRPC client: %v", err), http.StatusInternalServerError) + return + } + defer func() { + if err := closeConn(); err != nil { + logger.Error(err, "error closing gRPC connection") + } + }() + + stream, err := client.Process(r.Context()) + if err != nil { + logger.Error(err, "error opening ext_proc stream") + http.Error(w, fmt.Sprintf("error opening ext_proc stream: %v", err), http.StatusBadGateway) + return + } + + if code, err := sendRequest(stream, r); err != nil { + logger.Error(err, "error sending request") + http.Error(w, err.Error(), code) + return + } + + // Receive response and extract header + for { + resp, err := stream.Recv() + if errors.Is(err, io.EOF) { + break // End of stream + } else if err != nil { + logger.Error(err, "error receiving from ext_proc") + http.Error(w, fmt.Sprintf("error receiving from ext_proc: %v", err), http.StatusBadGateway) + return + } + + if ir := resp.GetImmediateResponse(); ir != nil { + code := int(ir.GetStatus().GetCode()) + body := ir.GetBody() + logger.Error(fmt.Errorf("code: %d, body: %s", code, body), "received immediate response") + http.Error(w, string(body), code) + return + } + + headers := resp.GetRequestHeaders().GetResponse().GetHeaderMutation().GetSetHeaders() + for _, h := range headers { + if h.GetHeader().GetKey() == eppMetadata.DestinationEndpointKey { + endpoint := string(h.GetHeader().GetRawValue()) + w.Header().Set(h.GetHeader().GetKey(), endpoint) + logger.Info("Found endpoint", "endpoint", endpoint) + } + } + } + w.WriteHeader(http.StatusOK) + }) +} + +func sendRequest(stream extprocv3.ExternalProcessor_ProcessClient, r *http.Request) (int, error) { + if err := stream.Send(buildHeaderRequest(r)); err != nil { + return http.StatusBadGateway, fmt.Errorf("error sending headers: %w", err) + } + + bodyReq, err := buildBodyRequest(r) + if err != nil { + return http.StatusInternalServerError, fmt.Errorf("error building body request: %w", err) + } + + if err := stream.Send(bodyReq); err != nil { + return http.StatusBadGateway, fmt.Errorf("error sending body: %w", err) + } + + if err := stream.CloseSend(); err != nil { + return http.StatusInternalServerError, fmt.Errorf("error closing stream: %w", err) + } + + return 0, nil +} + +func buildHeaderRequest(r *http.Request) *extprocv3.ProcessingRequest { + headerList := make([]*corev3.HeaderValue, 0, len(r.Header)) + headerMap := &corev3.HeaderMap{ + Headers: headerList, + } + + for key, values := range r.Header { + for _, value := range values { + headerMap.Headers = append(headerMap.Headers, &corev3.HeaderValue{ + Key: key, + Value: value, + }) + } + } + + return &extprocv3.ProcessingRequest{ + Request: &extprocv3.ProcessingRequest_RequestHeaders{ + RequestHeaders: &extprocv3.HttpHeaders{ + Headers: headerMap, + EndOfStream: false, + }, + }, + } +} + +func buildBodyRequest(r *http.Request) (*extprocv3.ProcessingRequest, error) { + body, err := io.ReadAll(r.Body) + if err != nil { + return nil, fmt.Errorf("error reading request body: %w", err) + } + + return &extprocv3.ProcessingRequest{ + Request: &extprocv3.ProcessingRequest_RequestBody{ + RequestBody: &extprocv3.HttpBody{ + Body: body, + EndOfStream: true, + }, + }, + }, nil +} diff --git a/cmd/gateway/endpoint_picker_test.go b/cmd/gateway/endpoint_picker_test.go new file mode 100644 index 0000000000..99808348fc --- /dev/null +++ b/cmd/gateway/endpoint_picker_test.go @@ -0,0 +1,261 @@ +package main + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3" + "github.com/go-logr/logr" + . "github.com/onsi/gomega" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + eppMetadata "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata" +) + +type mockExtProcClient struct { + ProcessFunc func( + context.Context, + ...grpc.CallOption, + ) (extprocv3.ExternalProcessor_ProcessClient, error) +} + +func (m *mockExtProcClient) Process( + ctx context.Context, + opts ...grpc.CallOption, +) (extprocv3.ExternalProcessor_ProcessClient, error) { + if m.ProcessFunc != nil { + return m.ProcessFunc(ctx, opts...) + } + return nil, errors.New("not implemented") +} + +type mockProcessClient struct { + SendFunc func(*extprocv3.ProcessingRequest) error + RecvFunc func() (*extprocv3.ProcessingResponse, error) + CloseSendFunc func() error + Ctx context.Context +} + +func (m *mockProcessClient) Send(req *extprocv3.ProcessingRequest) error { + if m.SendFunc != nil { + return m.SendFunc(req) + } + return nil +} + +func (m *mockProcessClient) Recv() (*extprocv3.ProcessingResponse, error) { + if m.RecvFunc != nil { + return m.RecvFunc() + } + return nil, io.EOF +} + +func (*mockProcessClient) RecvMsg(any) error { return nil } +func (*mockProcessClient) SendMsg(any) error { return nil } + +func (m *mockProcessClient) CloseSend() error { + if m.CloseSendFunc != nil { + return m.CloseSendFunc() + } + return nil +} + +func (m *mockProcessClient) Context() context.Context { + if m.Ctx != nil { + return m.Ctx + } + return context.Background() +} + +func (*mockProcessClient) Header() (metadata.MD, error) { return nil, nil } //nolint:nilnil // interface satisfier +func (*mockProcessClient) Trailer() metadata.MD { return nil } + +func TestEndpointPickerHandler_Success(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + // Prepare mock client to simulate gRPC responses + callCount := 0 + client := &mockProcessClient{ + SendFunc: func(*extprocv3.ProcessingRequest) error { return nil }, + RecvFunc: func() (*extprocv3.ProcessingResponse, error) { + if callCount == 0 { + callCount++ + resp := &extprocv3.ProcessingResponse{ + Response: &extprocv3.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extprocv3.HeadersResponse{ + Response: &extprocv3.CommonResponse{ + HeaderMutation: &extprocv3.HeaderMutation{ + SetHeaders: []*corev3.HeaderValueOption{{ + Header: &corev3.HeaderValue{ + Key: eppMetadata.DestinationEndpointKey, + RawValue: []byte("test-value"), + }, + }}, + }, + }, + }, + }, + } + return resp, nil + } + return nil, io.EOF + }, + } + + extProcClient := &mockExtProcClient{ + ProcessFunc: func(context.Context, ...grpc.CallOption) (extprocv3.ExternalProcessor_ProcessClient, error) { + return client, nil + }, + } + + factory := func(string) (extprocv3.ExternalProcessorClient, func() error, error) { + return extProcClient, func() error { return nil }, nil + } + + h := createEndpointPickerHandler(factory, logr.Discard()) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test body")) + req.Header.Set(eppEndpointHostHeader, "test-host") + req.Header.Set(eppEndpointPortHeader, "1234") + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + resp := w.Result() + g.Expect(resp.StatusCode).To(Equal(http.StatusOK)) + g.Expect(resp.Header.Get(eppMetadata.DestinationEndpointKey)).To(Equal("test-value")) +} + +func TestEndpointPickerHandler_ImmediateResponse(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + client := &mockProcessClient{ + SendFunc: func(*extprocv3.ProcessingRequest) error { return nil }, + RecvFunc: func() (*extprocv3.ProcessingResponse, error) { + resp := &extprocv3.ProcessingResponse{ + Response: &extprocv3.ProcessingResponse_ImmediateResponse{ + ImmediateResponse: &extprocv3.ImmediateResponse{ + Status: &typev3.HttpStatus{Code: http.StatusInternalServerError}, + Body: []byte("some error"), + }, + }, + } + return resp, nil + }, + } + + extClient := &mockExtProcClient{ + ProcessFunc: func(context.Context, ...grpc.CallOption) (extprocv3.ExternalProcessor_ProcessClient, error) { + return client, nil + }, + } + + factory := func(string) (extprocv3.ExternalProcessorClient, func() error, error) { + return extClient, func() error { return nil }, nil + } + + h := createEndpointPickerHandler(factory, logr.Discard()) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test body")) + req.Header.Set(eppEndpointHostHeader, "test-host") + req.Header.Set(eppEndpointPortHeader, "1234") + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + resp := w.Result() + + g.Expect(resp.StatusCode).To(Equal(http.StatusInternalServerError)) + body, _ := io.ReadAll(resp.Body) + g.Expect(string(body)).To(ContainSubstring("some error")) +} + +func TestEndpointPickerHandler_Errors(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + runErrorTestCase := func(factory func(string) (extprocv3.ExternalProcessorClient, func() error, error), + setHeaders bool, + expectedStatus int, + expectedBodySubstring string, + ) { + h := createEndpointPickerHandler(factory, logr.Discard()) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test body")) + if setHeaders { + req.Header.Set(eppEndpointHostHeader, "test-host") + req.Header.Set(eppEndpointPortHeader, "1234") + } + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + resp := w.Result() + g.Expect(resp.StatusCode).To(Equal(expectedStatus)) + body, _ := io.ReadAll(resp.Body) + g.Expect(string(body)).To(ContainSubstring(expectedBodySubstring)) + } + + // 1. Error creating gRPC client + factoryErr := errors.New("factory error") + factory := func(string) (extprocv3.ExternalProcessorClient, func() error, error) { + return nil, nil, factoryErr + } + runErrorTestCase(factory, true, http.StatusInternalServerError, "error creating gRPC client") + + // 2. Error opening ext_proc stream + extProcClient := &mockExtProcClient{ + ProcessFunc: func(context.Context, ...grpc.CallOption) (extprocv3.ExternalProcessor_ProcessClient, error) { + return nil, errors.New("process error") + }, + } + factory = func(string) (extprocv3.ExternalProcessorClient, func() error, error) { + return extProcClient, func() error { return nil }, nil + } + runErrorTestCase(factory, true, http.StatusBadGateway, "error opening ext_proc stream") + + // 3. Error sending headers + client := &mockProcessClient{ + SendFunc: func(*extprocv3.ProcessingRequest) error { + return errors.New("send headers error") + }, + RecvFunc: func() (*extprocv3.ProcessingResponse, error) { return nil, io.EOF }, + } + extProcClient = &mockExtProcClient{ + ProcessFunc: func(context.Context, ...grpc.CallOption) (extprocv3.ExternalProcessor_ProcessClient, error) { + return client, nil + }, + } + factory = func(string) (extprocv3.ExternalProcessorClient, func() error, error) { + return extProcClient, func() error { return nil }, nil + } + runErrorTestCase(factory, true, http.StatusBadGateway, "error sending headers") + + // 4. Error sending body + client = &mockProcessClient{ + SendFunc: func(req *extprocv3.ProcessingRequest) error { + if req.GetRequestBody() != nil { + return errors.New("send body error") + } + return nil + }, + RecvFunc: func() (*extprocv3.ProcessingResponse, error) { return nil, io.EOF }, + } + extProcClient = &mockExtProcClient{ + ProcessFunc: func(context.Context, ...grpc.CallOption) (extprocv3.ExternalProcessor_ProcessClient, error) { + return client, nil + }, + } + factory = func(string) (extprocv3.ExternalProcessorClient, func() error, error) { + return extProcClient, func() error { return nil }, nil + } + runErrorTestCase(factory, true, http.StatusBadGateway, "error sending body") + + // 5. Error with empty headers + runErrorTestCase(factory, false, http.StatusBadRequest, "missing at least one of required headers") +} diff --git a/cmd/gateway/main.go b/cmd/gateway/main.go index 515fcc3f16..c932a4ee4c 100644 --- a/cmd/gateway/main.go +++ b/cmd/gateway/main.go @@ -25,6 +25,7 @@ func main() { createGenerateCertsCommand(), createInitializeCommand(), createSleepCommand(), + createEndpointPickerCommand(), ) if err := rootCmd.Execute(); err != nil { diff --git a/go.mod b/go.mod index d4ec8da3ba..8fd0b7ac56 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/nginx/nginx-gateway-fabric/v2 go 1.24.2 require ( + github.com/envoyproxy/go-control-plane/envoy v1.32.4 github.com/fsnotify/fsnotify v1.9.0 github.com/go-logr/logr v1.4.3 github.com/google/go-cmp v0.7.0 @@ -37,8 +38,10 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/emicklei/go-restful/v3 v3.12.2 // indirect + github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect github.com/evanphx/json-patch/v5 v5.9.11 // indirect github.com/fxamacker/cbor/v2 v2.9.0 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -60,6 +63,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.66.1 // indirect diff --git a/go.sum b/go.sum index 5325827f3f..80bc2f5ee2 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,8 @@ github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1x github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443 h1:aQ3y1lwWyqYPiWZThqv1aFbZMiM9vblcSArJRf2Irls= +github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A= @@ -39,6 +41,10 @@ github.com/ebitengine/purego v0.8.2 h1:jPPGWs2sZ1UgOSgD2bClL0MJIqu58nOmIcBuXr62z github.com/ebitengine/purego v0.8.2/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/emicklei/go-restful/v3 v3.12.2 h1:DhwDP0vY3k8ZzE0RunuJy8GhNpPL6zqLkDf9B/a0/xU= github.com/emicklei/go-restful/v3 v3.12.2/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/envoyproxy/go-control-plane/envoy v1.32.4 h1:jb83lalDRZSpPWW2Z7Mck/8kXZ5CQAFYVjQcdVIr83A= +github.com/envoyproxy/go-control-plane/envoy v1.32.4/go.mod h1:Gzjc5k8JcJswLjAx1Zm+wSYE20UrLtt7JZMWiWQXQEw= +github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfUKS7KJ7spH3d86P8= +github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU= github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8P3k= github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= github.com/evanphx/json-patch/v5 v5.9.11 h1:/8HVnzMq13/3x9TPvjG08wUGqBTmZBsCWzjTM0wiaDU= @@ -155,6 +161,8 @@ github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNH github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/internal/controller/manager.go b/internal/controller/manager.go index d02411571b..dc9c4835bd 100644 --- a/internal/controller/manager.go +++ b/internal/controller/manager.go @@ -220,6 +220,7 @@ func StartManager(cfg config.Config) error { NginxDockerSecretNames: cfg.NginxDockerSecretNames, PlusUsageConfig: &cfg.UsageReportConfig, NginxOneConsoleTelemetryConfig: cfg.NginxOneConsoleTelemetryConfig, + InferenceExtension: cfg.InferenceExtension, }, ) if err != nil { diff --git a/internal/controller/provisioner/objects.go b/internal/controller/provisioner/objects.go index 14b3cbcc72..fda6c9caab 100644 --- a/internal/controller/provisioner/objects.go +++ b/internal/controller/provisioner/objects.go @@ -888,6 +888,7 @@ func (p *NginxProvisioner) buildNginxPodTemplateSpec( {MountPath: "/etc/nginx/events-includes", Name: "nginx-events-includes"}, }, SecurityContext: &corev1.SecurityContext{ + AllowPrivilegeEscalation: helpers.GetPointer(false), Capabilities: &corev1.Capabilities{ Drop: []corev1.Capability{"ALL"}, }, @@ -1108,6 +1109,30 @@ func (p *NginxProvisioner) buildNginxPodTemplateSpec( spec.Spec.Containers[0].VolumeMounts = volumeMounts } + if p.cfg.InferenceExtension { + spec.Spec.Containers = append(spec.Spec.Containers, corev1.Container{ + Name: "endpoint-picker-shim", + Image: p.cfg.GatewayPodConfig.Image, + ImagePullPolicy: pullPolicy, + Command: []string{ + "/usr/bin/gateway", + "endpoint-picker", + }, + SecurityContext: &corev1.SecurityContext{ + AllowPrivilegeEscalation: helpers.GetPointer(false), + Capabilities: &corev1.Capabilities{ + Drop: []corev1.Capability{"ALL"}, + }, + ReadOnlyRootFilesystem: helpers.GetPointer(true), + RunAsGroup: helpers.GetPointer[int64](1001), + RunAsUser: helpers.GetPointer[int64](101), + SeccompProfile: &corev1.SeccompProfile{ + Type: corev1.SeccompProfileTypeRuntimeDefault, + }, + }, + }) + } + return spec } diff --git a/internal/controller/provisioner/objects_test.go b/internal/controller/provisioner/objects_test.go index 4af6f2deca..bab40a417a 100644 --- a/internal/controller/provisioner/objects_test.go +++ b/internal/controller/provisioner/objects_test.go @@ -1755,3 +1755,57 @@ func TestBuildNginxResourceObjects_Patches(t *testing.T) { g.Expect(svc.Labels).To(HaveKeyWithValue("app", "nginx")) g.Expect(dep.Labels).To(HaveKeyWithValue("app", "nginx")) } + +func TestBuildNginxResourceObjects_InferenceExtension(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + agentTLSSecret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: agentTLSTestSecretName, + Namespace: ngfNamespace, + }, + Data: map[string][]byte{"tls.crt": []byte("tls")}, + } + fakeClient := fake.NewFakeClient(agentTLSSecret) + + provisioner := &NginxProvisioner{ + cfg: Config{ + GatewayPodConfig: &config.GatewayPodConfig{ + Namespace: ngfNamespace, + }, + AgentTLSSecretName: agentTLSTestSecretName, + InferenceExtension: true, + }, + k8sClient: fakeClient, + baseLabelSelector: metav1.LabelSelector{ + MatchLabels: map[string]string{"app": "nginx"}, + }, + } + + gateway := &gatewayv1.Gateway{ + ObjectMeta: metav1.ObjectMeta{ + Name: "gw", + Namespace: "default", + }, + Spec: gatewayv1.GatewaySpec{ + Listeners: []gatewayv1.Listener{{Port: 80}}, + }, + } + + objects, err := provisioner.buildNginxResourceObjects("gw-nginx", gateway, &graph.EffectiveNginxProxy{}) + g.Expect(err).ToNot(HaveOccurred()) + + // Find the deployment object + var deployment *appsv1.Deployment + for _, obj := range objects { + if d, ok := obj.(*appsv1.Deployment); ok { + deployment = d + break + } + } + g.Expect(deployment).ToNot(BeNil()) + containers := deployment.Spec.Template.Spec.Containers + g.Expect(containers).To(HaveLen(2)) + g.Expect(containers[1].Name).To(Equal("endpoint-picker-shim")) +} diff --git a/internal/controller/provisioner/provisioner.go b/internal/controller/provisioner/provisioner.go index fe59f5be1b..8a2abffd0a 100644 --- a/internal/controller/provisioner/provisioner.go +++ b/internal/controller/provisioner/provisioner.go @@ -58,6 +58,7 @@ type Config struct { NginxDockerSecretNames []string NginxOneConsoleTelemetryConfig config.NginxOneConsoleTelemetryConfig Plus bool + InferenceExtension bool } // NginxProvisioner handles provisioning nginx kubernetes resources.