diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go
index c9c5f604f..0ece16dcd 100644
--- a/cmd/epp/runner/runner.go
+++ b/cmd/epp/runner/runner.go
@@ -78,15 +78,6 @@ var (
 		"enable-pprof",
 		runserver.DefaultEnablePprof,
 		"Enables pprof handlers. Defaults to true. Set to false to disable pprof handlers.")
-	destinationEndpointHintKey = flag.String(
-		"destination-endpoint-hint-key",
-		runserver.DefaultDestinationEndpointHintKey,
-		"Header and response metadata key used by Envoy to route to the appropriate pod. This must match Envoy configuration.")
-	destinationEndpointHintMetadataNamespace = flag.String(
-		"destination-endpoint-hint-metadata-namespace",
-		runserver.DefaultDestinationEndpointHintMetadataNamespace,
-		"The key for the outer namespace struct in the metadata field of the extproc response that is used to wrap the"+
-			"target endpoint. If not set, then an outer namespace struct should not be created.")
 	poolName = flag.String(
 		"pool-name",
 		runserver.DefaultPoolName,
@@ -113,6 +104,20 @@ var (
 		"The path to the certificate for secure serving. The certificate and private key files "+
 			"are assumed to be named tls.crt and tls.key, respectively. If not set, and secureServing is enabled, "+
 			"then a self-signed certificate is used.")
+	// header/metadata flags: request header and ext-proc response metadata keys
+	destinationEndpointHintKey = flag.String(
+		"destination-endpoint-hint-key",
+		runserver.DefaultDestinationEndpointHintKey,
+		"Header and response metadata key used by Envoy to route to the appropriate pod. This must match Envoy configuration.")
+	destinationEndpointHintMetadataNamespace = flag.String(
+		"destination-endpoint-hint-metadata-namespace",
+		runserver.DefaultDestinationEndpointHintMetadataNamespace,
+		"The key for the outer namespace struct in the metadata field of the extproc response that is used to wrap the "+
+			"target endpoint. 
If not set, then an outer namespace struct should not be created.") + fairnessIDHeaderKey = flag.String( + "fairness-id-header-key", + runserver.DefaultFairnessIDHeaderKey, + "The header key used to pass the fairness ID to be used in Flow Control.") // metric flags totalQueuedRequestsMetric = flag.String( "total-queued-requests-metric", @@ -337,6 +342,7 @@ func (r *Runner) Run(ctx context.Context) error { GrpcPort: *grpcPort, DestinationEndpointHintMetadataNamespace: *destinationEndpointHintMetadataNamespace, DestinationEndpointHintKey: *destinationEndpointHintKey, + FairnessIDHeaderKey: *fairnessIDHeaderKey, PoolNamespacedName: poolNamespacedName, Datastore: datastore, SecureServing: *secureServing, diff --git a/config/charts/inferencepool/templates/epp-deployment.yaml b/config/charts/inferencepool/templates/epp-deployment.yaml index 5ab9e2d4f..23957a3da 100644 --- a/config/charts/inferencepool/templates/epp-deployment.yaml +++ b/config/charts/inferencepool/templates/epp-deployment.yaml @@ -33,6 +33,8 @@ spec: - "9002" - --grpc-health-port - "9003" + - --zap-encoder + - "json" - --metrics-port - "9090" - --config-file diff --git a/config/charts/inferencepool/templates/rbac.yaml b/config/charts/inferencepool/templates/rbac.yaml index a8d4d0ef3..a8d891c32 100644 --- a/config/charts/inferencepool/templates/rbac.yaml +++ b/config/charts/inferencepool/templates/rbac.yaml @@ -42,6 +42,9 @@ rules: - apiGroups: ["inference.networking.x-k8s.io"] resources: ["inferenceobjectives", "inferencepools"] verbs: ["get", "watch", "list"] +- apiGroups: ["inference.networking.k8s.io"] + resources: ["inferencepools"] + verbs: ["get", "watch", "list"] - apiGroups: [""] resources: ["pods"] verbs: ["get", "watch", "list"] diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 5789b3cc9..b528ed70d 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -57,6 +57,13 @@ func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx 
*Requ } else { reqCtx.Request.Headers[header.Key] = header.Value } + if header.Key == s.fairnessIDHeaderKey { + reqCtx.FairnessID = reqCtx.Request.Headers[header.Key] + // remove the fairness ID header from the request headers, + // this is not data that should be manipulated or sent to the backend. + // It is only used for flow control. + delete(reqCtx.Request.Headers, header.Key) + } } return nil } diff --git a/pkg/epp/handlers/request_test.go b/pkg/epp/handlers/request_test.go new file mode 100644 index 000000000..796082c00 --- /dev/null +++ b/pkg/epp/handlers/request_test.go @@ -0,0 +1,70 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+package handlers
+
+import (
+	"context"
+	"testing"
+
+	configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
+	extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
+)
+
+// TestHandleRequestHeaders checks that the configured fairness ID header is copied into
+// RequestContext.FairnessID and removed from the stored headers, while other headers are kept.
+func TestHandleRequestHeaders(t *testing.T) {
+	t.Parallel()
+
+	// Set up a server that only has the fairness header key configured, and an empty request context.
+	server := &StreamingServer{
+		fairnessIDHeaderKey: "test-fairness-id",
+	}
+
+	reqCtx := &RequestContext{
+		Request: &Request{
+			Headers: make(map[string]string),
+		},
+	}
+
+	req := &extProcPb.ProcessingRequest_RequestHeaders{
+		RequestHeaders: &extProcPb.HttpHeaders{
+			Headers: &configPb.HeaderMap{
+				Headers: []*configPb.HeaderValue{
+					{Key: "x-test-header", Value: "test-value"},
+					{Key: "test-fairness-id", Value: "test-fairness-id-value"},
+				},
+			},
+			EndOfStream: false,
+		},
+	}
+
+	err := server.HandleRequestHeaders(context.Background(), reqCtx, req)
+	if err != nil {
+		t.Fatalf("expected no error, got %v", err)
+	}
+
+	if reqCtx.FairnessID != "test-fairness-id-value" {
+		t.Errorf("expected fairness ID to be 'test-fairness-id-value', got %s", reqCtx.FairnessID)
+	}
+	// Check key presence rather than value: a header left behind with any value is still a failure.
+	if leftover, ok := reqCtx.Request.Headers["test-fairness-id"]; ok {
+		t.Errorf("expected fairness ID header to be removed from request headers, but found %q", leftover)
+	}
+	if got := reqCtx.Request.Headers["x-test-header"]; got != "test-value" {
+		t.Errorf("expected header x-test-header to be kept with value 'test-value', got %q", got)
+	}
+}
diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go
index 30596606e..62ffc8c0b 100644
--- a/pkg/epp/handlers/server.go
+++ b/pkg/epp/handlers/server.go
@@ -44,10 +44,11 @@ const (
 	bodyByteLimit = 62000
 )
 
-func NewStreamingServer(destinationEndpointHintMetadataNamespace, destinationEndpointHintKey string, datastore Datastore, director Director) *StreamingServer {
+func NewStreamingServer(destinationEndpointHintMetadataNamespace, destinationEndpointHintKey, fairnessIDHeaderKey string, datastore Datastore, director Director) *StreamingServer {
 	return &StreamingServer{
 		destinationEndpointHintMetadataNamespace: destinationEndpointHintMetadataNamespace,
destinationEndpointHintKey: destinationEndpointHintKey, + fairnessIDHeaderKey: fairnessIDHeaderKey, director: director, datastore: datastore, } @@ -72,6 +73,7 @@ type StreamingServer struct { // The key acting as the outer namespace struct in the metadata extproc response to communicate // back the picked endpoints. destinationEndpointHintMetadataNamespace string + fairnessIDHeaderKey string datastore Datastore director Director } @@ -85,6 +87,7 @@ type RequestContext struct { TargetEndpoint string Model string ResolvedTargetModel string + FairnessID string RequestReceivedTimestamp time.Time ResponseCompleteTimestamp time.Time RequestSize int @@ -192,8 +195,8 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) switch v := req.Request.(type) { case *extProcPb.ProcessingRequest_RequestHeaders: - if requestId := requtil.ExtractHeaderValue(v, requtil.RequestIdHeaderKey); len(requestId) > 0 { - logger = logger.WithValues(requtil.RequestIdHeaderKey, requestId) + if requestID := requtil.ExtractHeaderValue(v, requtil.RequestIdHeaderKey); len(requestID) > 0 { + logger = logger.WithValues(requtil.RequestIdHeaderKey, requestID) loggerTrace = logger.V(logutil.TRACE) ctx = log.IntoContext(ctx, logger) } diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index 65232bed8..d2073eb65 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -139,7 +139,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo logger.V(logutil.DEBUG).Info("LLM request assembled") // --- 2. 
Admission Control check -- - if err := d.admitRequest(ctx, requestCriticality); err != nil { + if err := d.admitRequest(ctx, requestCriticality, reqCtx.FairnessID); err != nil { return reqCtx, err } @@ -166,9 +166,11 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo // admitRequest handles admission control to decide whether or not to accept the request // based on the request criticality and system saturation state. -func (d *Director) admitRequest(ctx context.Context, requestCriticality v1alpha2.Criticality) error { +func (d *Director) admitRequest(ctx context.Context, requestCriticality v1alpha2.Criticality, fairnessID string) error { logger := log.FromContext(ctx) + logger.V(logutil.TRACE).Info("Entering Flow Control", "criticality", requestCriticality, "fairnessID", fairnessID) + if requestCriticality == v1alpha2.Critical { logger.V(logutil.DEBUG).Info("Critical request bypassing saturation check.") return nil diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go index 50053231f..2cbcbec0e 100644 --- a/pkg/epp/server/runserver.go +++ b/pkg/epp/server/runserver.go @@ -46,6 +46,7 @@ type ExtProcServerRunner struct { GrpcPort int DestinationEndpointHintMetadataNamespace string DestinationEndpointHintKey string + FairnessIDHeaderKey string PoolNamespacedName types.NamespacedName Datastore datastore.Datastore SecureServing bool @@ -63,24 +64,25 @@ type ExtProcServerRunner struct { // Default values for CLI flags in main const ( - DefaultGrpcPort = 9002 // default for --grpc-port - DefaultGrpcHealthPort = 9003 // default for --grpc-health-port - DefaultMetricsPort = 9090 // default for --metrics-port - DefaultDestinationEndpointHintMetadataNamespace = "envoy.lb" // default for --destinationEndpointHintMetadataNamespace - DefaultDestinationEndpointHintKey = "x-gateway-destination-endpoint" // default for --destination-endpoint-hint-key - DefaultPoolName = "" // required but no default - DefaultPoolNamespace = 
"default" // default for --pool-namespace - DefaultRefreshMetricsInterval = 50 * time.Millisecond // default for --refresh-metrics-interval - DefaultRefreshPrometheusMetricsInterval = 5 * time.Second // default for --refresh-prometheus-metrics-interval - DefaultSecureServing = true // default for --secure-serving - DefaultHealthChecking = false // default for --health-checking - DefaultEnablePprof = true // default for --enable-pprof - DefaultTotalQueuedRequestsMetric = "vllm:num_requests_waiting" // default for --total-queued-requests-metric - DefaultKvCacheUsagePercentageMetric = "vllm:gpu_cache_usage_perc" // default for --kv-cache-usage-percentage-metric - DefaultLoraInfoMetric = "vllm:lora_requests_info" // default for --lora-info-metric - DefaultCertPath = "" // default for --cert-path - DefaultConfigFile = "" // default for --config-file - DefaultConfigText = "" // default for --config-text + DefaultGrpcPort = 9002 // default for --grpc-port + DefaultGrpcHealthPort = 9003 // default for --grpc-health-port + DefaultMetricsPort = 9090 // default for --metrics-port + DefaultDestinationEndpointHintMetadataNamespace = "envoy.lb" // default for --destination-endpoint-hint-metadata-namespace + DefaultDestinationEndpointHintKey = "x-gateway-destination-endpoint" // default for --destination-endpoint-hint-key + DefaultFairnessIDHeaderKey = "x-gateway-inference-fairness-id" // default for --fairness-id-header-key + DefaultPoolName = "" // required but no default + DefaultPoolNamespace = "default" // default for --pool-namespace + DefaultRefreshMetricsInterval = 50 * time.Millisecond // default for --refresh-metrics-interval + DefaultRefreshPrometheusMetricsInterval = 5 * time.Second // default for --refresh-prometheus-metrics-interval + DefaultSecureServing = true // default for --secure-serving + DefaultHealthChecking = false // default for --health-checking + DefaultEnablePprof = true // default for --enable-pprof + DefaultTotalQueuedRequestsMetric = 
"vllm:num_requests_waiting" // default for --total-queued-requests-metric + DefaultKvCacheUsagePercentageMetric = "vllm:gpu_cache_usage_perc" // default for --kv-cache-usage-percentage-metric + DefaultLoraInfoMetric = "vllm:lora_requests_info" // default for --lora-info-metric + DefaultCertPath = "" // default for --cert-path + DefaultConfigFile = "" // default for --config-file + DefaultConfigText = "" // default for --config-text DefaultMetricsStalenessThreshold = 2 * time.Second ) @@ -91,6 +93,7 @@ func NewDefaultExtProcServerRunner() *ExtProcServerRunner { GrpcPort: DefaultGrpcPort, DestinationEndpointHintKey: DefaultDestinationEndpointHintKey, DestinationEndpointHintMetadataNamespace: DefaultDestinationEndpointHintMetadataNamespace, + FairnessIDHeaderKey: DefaultFairnessIDHeaderKey, PoolNamespacedName: types.NamespacedName{Name: DefaultPoolName, Namespace: DefaultPoolNamespace}, SecureServing: DefaultSecureServing, HealthChecking: DefaultHealthChecking, @@ -159,6 +162,7 @@ func (r *ExtProcServerRunner) AsRunnable(logger logr.Logger) manager.Runnable { extProcServer := handlers.NewStreamingServer( r.DestinationEndpointHintMetadataNamespace, r.DestinationEndpointHintKey, + r.FairnessIDHeaderKey, r.Datastore, r.Director, ) diff --git a/pkg/epp/server/server_test.go b/pkg/epp/server/server_test.go index de934301d..fc075de38 100644 --- a/pkg/epp/server/server_test.go +++ b/pkg/epp/server/server_test.go @@ -38,6 +38,7 @@ const ( podAddress = "1.2.3.4" poolPort = int32(5678) destinationEndpointHintKey = "test-target" + fairnessIDHeaderKey = "x-fairness-id" namespace = "ns1" ) @@ -60,14 +61,18 @@ func TestServer(t *testing.T) { ctx, cancel, ds, _ := utils.PrepareForTestStreamingServer([]*v1alpha2.InferenceObjective{model}, []*v1.Pod{{ObjectMeta: metav1.ObjectMeta{Name: podName}}}, "test-pool1", namespace, poolPort) - streamingServer := handlers.NewStreamingServer(namespace, destinationEndpointHintKey, ds, director) + streamingServer := 
handlers.NewStreamingServer(namespace, destinationEndpointHintKey, fairnessIDHeaderKey, ds, director) testListener, errChan := utils.SetupTestStreamingServer(t, ctx, ds, streamingServer) process, conn := utils.GetStreamingServerClient(ctx, t) defer conn.Close() // Send request headers - no response expected - headers := utils.BuildEnvoyGRPCHeaders(map[string]string{requestHeader: theHeaderValue, ":method": "POST"}, true) + headers := utils.BuildEnvoyGRPCHeaders(map[string]string{ + requestHeader: theHeaderValue, + ":method": "POST", + fairnessIDHeaderKey: "a-very-interesting-fairness-id", + }, true) request := &pb.ProcessingRequest{ Request: &pb.ProcessingRequest_RequestHeaders{ RequestHeaders: headers,