From 5fd9820ce7405fdc7c5713d1dc08b229797e46ab Mon Sep 17 00:00:00 2001 From: Luke Van Drie Date: Sat, 20 Sep 2025 02:43:14 +0000 Subject: [PATCH] fix: streaming metrics and header parsing bugs This commit addresses several critical bugs discovered in the External Processing Proxy (EPP) that impact reliability, observability, and correctness, particularly for streaming use cases. **Bug Fixes:** 1. **Correct Streaming Token Metrics:** - **Problem:** For streaming responses (e.g., `text/event-stream`), token usage metrics were recorded incorrectly. The logic only inspected the final `[DONE]` message for a `usage` block, failing to accumulate token counts from earlier messages in the stream. Additionally, the `IncomingModelName` was being overwritten by a blank value from the request body, causing the `model_name` label in Prometheus to be empty. - **Fix:** The response handler now correctly accumulates token counts from all streaming chunks into the `RequestContext`. The final, accurate count is recorded only when the `[DONE]` message is received. The request handler logic was reordered to ensure headers (containing the model name) are always processed before the body, preventing the context from being corrupted. 2. **Robust Header Parsing:** - **Problem:** Multiple locations in the codebase exclusively checked the `RawValue` field of an Envoy `HeaderValue` message, ignoring the valid `Value` field. This caused failures in detecting the `content-type` for streaming and loss of the `x-request-id` for tracing if a client sent them in the `Value` field. - **Fix:** All header parsing logic has been updated to check both `RawValue` and `Value`, making it robust and compliant with the Envoy API. **Refactoring:** - **Hermetic Test Overhaul:** The integration test suite in `test/integration/epp/hermetic_test.go` has been completely refactored for reliability and clarity. - The old, monolithic, table-driven test has been replaced with a `testHarness` structure that provides each test case with its own isolated server instance, Kubernetes resources (scoped by a unique label), and gRPC client. - This eliminates test interference and makes the suite significantly more stable and maintainable. While true parallelism is still blocked by a global metrics registry in controller-runtime, this change achieves full resource and state isolation. - Test cases are now grouped by functionality (`RequestRouting`, `ResponseHandling`, etc.) with clear, descriptive names. - All associated test utilities and documentation have been polished to improve readability and maintainability. --- pkg/epp/handlers/request.go | 56 +- pkg/epp/handlers/response.go | 28 +- pkg/epp/handlers/server.go | 11 +- pkg/epp/util/request/headers.go | 6 +- test/integration/epp/hermetic_test.go | 1664 +++++++++++-------------- test/integration/util.go | 192 ++- 6 files changed, 920 insertions(+), 1037 deletions(-) diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 7f8122195..eabba753f 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -32,6 +32,37 @@ import ( func (s *StreamingServer) HandleRequestHeaders(reqCtx *RequestContext, req *extProcPb.ProcessingRequest_RequestHeaders) error { reqCtx.RequestReceivedTimestamp = time.Now() + // Headers must be processed first to populate the request context as subsequent logic (like body processing) may + // depend upon it. 
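+	// In particular, the incoming model name comes from the Objective header; processing the body first is what
+	// previously let a blank value overwrite IncomingModelName (and empty the model_name metric label).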
+ for _, header := range req.RequestHeaders.Headers.Headers { + key := header.Key + // Per the Envoy API, a header's value can be in either the `RawValue` (bytes) or `Value` (string) field. + if header.RawValue != nil { + reqCtx.Request.Headers[key] = string(header.RawValue) + } else { + reqCtx.Request.Headers[key] = header.Value + } + switch key { + case metadata.FlowFairnessIDKey: + reqCtx.FairnessID = reqCtx.Request.Headers[key] + // remove the fairness ID header from the request headers, + // this is not data that should be manipulated or sent to the backend. + // It is only used for flow control. + delete(reqCtx.Request.Headers, key) + case metadata.ObjectiveKey: + reqCtx.ObjectiveKey = reqCtx.Request.Headers[key] + reqCtx.IncomingModelName = reqCtx.ObjectiveKey + // remove the objective header from the request headers, + // this is not data that should be manipulated or sent to the backend. + delete(reqCtx.Request.Headers, key) + case metadata.ModelNameRewriteKey: + reqCtx.TargetModelName = reqCtx.Request.Headers[key] + // remove the rewrite header from the request headers, + // this is not data that should be manipulated or sent to the backend. + delete(reqCtx.Request.Headers, key) + } + } + // an EoS in the request headers means this request has no body or trailers. if req.RequestHeaders.EndOfStream { // We will route this request to a random pod as this is assumed to just be a GET @@ -55,31 +86,6 @@ func (s *StreamingServer) HandleRequestHeaders(reqCtx *RequestContext, req *extP return nil } - for _, header := range req.RequestHeaders.Headers.Headers { - if header.RawValue != nil { - reqCtx.Request.Headers[header.Key] = string(header.RawValue) - } else { - reqCtx.Request.Headers[header.Key] = header.Value - } - switch header.Key { - case metadata.FlowFairnessIDKey: - reqCtx.FairnessID = reqCtx.Request.Headers[header.Key] - // remove the fairness ID header from the request headers, - // this is not data that should be manipulated or sent to the backend. - // It is only used for flow control. - delete(reqCtx.Request.Headers, header.Key) - case metadata.ObjectiveKey: - reqCtx.ObjectiveKey = reqCtx.Request.Headers[header.Key] - // remove the objective header from the request headers, - // this is not data that should be manipulated or sent to the backend. - delete(reqCtx.Request.Headers, header.Key) - case metadata.ModelNameRewriteKey: - reqCtx.TargetModelName = reqCtx.Request.Headers[header.Key] - // remove the rewrite header from the request headers, - // this is not data that should be manipulated or sent to the backend. - delete(reqCtx.Request.Headers, header.Key) - } - } return nil } diff --git a/pkg/epp/handlers/response.go b/pkg/epp/handlers/response.go index 7dfaf3b2e..785e3e310 100644 --- a/pkg/epp/handlers/response.go +++ b/pkg/epp/handlers/response.go @@ -64,13 +64,31 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques return reqCtx, nil } -// The function is to handle streaming response if the modelServer is streaming. +// HandleResponseBodyModelStreaming processes a single chunk of a streaming response. +// It accumulates token usage over the entire stream and records the final metrics only when the end-of-stream message +// is detected. func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) { + // Parse the current chunk for a 'usage' block. + resp := parseRespForUsage(ctx, responseText) + + // Accumulate token counts. 
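+	// A chunk carrying usage looks like, for example:
+	//   data: {"usage":{"prompt_tokens":7,"completion_tokens":10,"total_tokens":17}}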
+	// The 'usage' block typically appears in one of the last messages of a stream.
+	// We continuously update the context with the latest non-zero values we've seen.
+	if resp.Usage.PromptTokens > 0 {
+		reqCtx.Usage.PromptTokens = resp.Usage.PromptTokens
+	}
+	if resp.Usage.CompletionTokens > 0 {
+		reqCtx.Usage.CompletionTokens = resp.Usage.CompletionTokens
+	}
+	if resp.Usage.TotalTokens > 0 {
+		reqCtx.Usage.TotalTokens = resp.Usage.TotalTokens
+	}
+
+	// Record metrics at the end of the stream.
+	// When we see the final "[DONE]" message, we record the total accumulated token counts from the context.
 	if strings.Contains(responseText, streamingEndMsg) {
-		resp := parseRespForUsage(ctx, responseText)
-		reqCtx.Usage = resp.Usage
-		metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.PromptTokens)
-		metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.CompletionTokens)
+		metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.Usage.PromptTokens)
+		metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.Usage.CompletionTokens)
 	}
 }
diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go
index ddfb3316c..0c1604c5e 100644
--- a/pkg/epp/handlers/server.go
+++ b/pkg/epp/handlers/server.go
@@ -190,8 +190,10 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
 		if len(requestID) == 0 {
 			requestID = uuid.NewString()
 			loggerTrace.Info("RequestID header is not found in the request, generated a request id")
-			reqCtx.Request.Headers[requtil.RequestIdHeaderKey] = requestID // update in headers so director can consume it
 		}
+		// Ensure the request ID, whether pre-existing or newly generated, is in the context's header map.
+		// This makes it available to all downstream logic (e.g., the director).
+		reqCtx.Request.Headers[requtil.RequestIdHeaderKey] = requestID
 		logger = logger.WithValues(requtil.RequestIdHeaderKey, requestID)
 		loggerTrace = logger.V(logutil.TRACE)
 		ctx = log.IntoContext(ctx, logger)
@@ -241,7 +243,12 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
 		// This is currently unused.
 		case *extProcPb.ProcessingRequest_ResponseHeaders:
 			for _, header := range v.ResponseHeaders.Headers.GetHeaders() {
-				value := string(header.RawValue)
+				var value string
+				if len(header.RawValue) > 0 {
+					value = string(header.RawValue)
+				} else {
+					value = header.Value
+				}
 
 				loggerTrace.Info("header", "key", header.Key, "value", value)
 				if header.Key == "status" && value != "200" {
diff --git a/pkg/epp/util/request/headers.go b/pkg/epp/util/request/headers.go
index b1936d31a..fb5915127 100644
--- a/pkg/epp/util/request/headers.go
+++ b/pkg/epp/util/request/headers.go
@@ -32,7 +32,11 @@ func ExtractHeaderValue(req *extProcPb.ProcessingRequest_RequestHeaders, headerK
 	if req != nil && req.RequestHeaders != nil && req.RequestHeaders.Headers != nil {
 		for _, headerKv := range req.RequestHeaders.Headers.Headers {
 			if strings.ToLower(headerKv.Key) == headerKeyInLower {
				// Per the Envoy API, a header's value can be in either the `RawValue` (bytes) or `Value` (string) field. 
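+				// Prefer RawValue when populated and fall back to Value, mirroring the handling in server.go above.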
+ if len(headerKv.RawValue) > 0 { + return string(headerKv.RawValue) + } + return headerKv.Value } } } diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go index 3dc42f8ba..5ae7b8034 100644 --- a/test/integration/epp/hermetic_test.go +++ b/test/integration/epp/hermetic_test.go @@ -14,13 +14,15 @@ See the License for the specific language governing permissions and limitations under the License. */ -// Package epp contains integration tests for the ext proc while faking the backend pods. +// Package epp contains hermetic integration tests for the External Processing Proxy (EPP), faking the backend pods to +// allow for precise control over their metrics and state. package epp import ( "bufio" "bytes" "context" + "encoding/json" "errors" "fmt" "io" @@ -34,13 +36,16 @@ import ( extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3" "github.com/google/go-cmp/cmp" + "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/testing/protocmp" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/fields" + "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" @@ -54,7 +59,6 @@ import ( crconfig "sigs.k8s.io/controller-runtime/pkg/config" "sigs.k8s.io/controller-runtime/pkg/envtest" crmetrics "sigs.k8s.io/controller-runtime/pkg/metrics" - "sigs.k8s.io/controller-runtime/pkg/metrics/filters" metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" "sigs.k8s.io/yaml" @@ -82,12 +86,9 @@ import ( ) const ( - // Test Infrastructure - testPoolName = "vllm-llama3-8b-instruct-pool" - testNamespace = "default" - testMetricsPort = 8889 + testPoolName = "vllm-llama3-8b-instruct-pool" + testNamespace = "default" - // Model Names modelMyModel = "my-model" modelMyModelTarget = "my-model-12345" modelSQLLora = "sql-lora" @@ -99,7 +100,6 @@ const ( var ( testGRPCAddress = fmt.Sprintf("localhost:%d", server.DefaultGrpcPort) - serverRunner *server.ExtProcServerRunner k8sClient k8sclient.Client testEnv *envtest.Environment scheme = runtime.NewScheme() @@ -114,19 +114,16 @@ func TestMain(m *testing.M) { } type label struct { - name, - value string + name, value string } func labelsToString(labels []label) string { var sb strings.Builder - i := 0 - for _, l := range labels { + for i, l := range labels { if i > 0 { sb.WriteString(",") } sb.WriteString(fmt.Sprintf("%s=%q", l.name, l.value)) - i++ } return sb.String() } @@ -147,937 +144,759 @@ func inferencePoolReadyPods(v int, labels []label) string { `, labelsToString(labels), v) } -func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { - tests := []struct { - name string - requests []*extProcPb.ProcessingRequest - pods map[*backend.Pod]*backendmetrics.MetricsState - wantResponses []*extProcPb.ProcessingResponse - wantMetrics map[string]string - wantErr bool - immediateResponse *extProcPb.ImmediateResponse - }{ - // Request flow tests - { - name: "select lower queue and kv cache, no active lora", - requests: integrationutils.GenerateStreamedRequestSet(logger, "test1", modelMyModel, modelMyModelTarget, nil), - // Pod 1 will be picked because it has relatively low queue size and low KV cache. 
- pods: newPodStates( - podState{index: 0, queueSize: 3, kvCacheUsage: 0.2}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.1}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.2}, - ), - wantMetrics: map[string]string{ - "inference_objective_request_total": inferenceObjectiveRequestTotal([]label{ - {"model_name", modelMyModel}, - {"target_model_name", modelMyModelTarget}, - }), - "inference_pool_ready_pods": inferencePoolReadyPods(3, []label{ - {"name", testPoolName}, - }), - }, - wantErr: false, - wantResponses: integrationutils.NewRequestBufferedResponse( - "192.168.1.2:8000", - fmt.Sprintf(`{"max_tokens":100,"model":%q,"prompt":"test1","temperature":0}`, modelMyModelTarget), - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "hi", - RawValue: []byte("mom"), - }, - }, - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: requtil.RequestIdHeaderKey, - RawValue: []byte("test-request-id"), - }, - }, - ), - }, - { - name: "invalid json; return body", - requests: []*extProcPb.ProcessingRequest{ - { - Request: &extProcPb.ProcessingRequest_RequestHeaders{ - RequestHeaders: &extProcPb.HttpHeaders{ - Headers: &configPb.HeaderMap{ - Headers: []*configPb.HeaderValue{ - { - Key: "hi", - Value: "mom", - }, - }, - }, - }, - }, - }, - { - Request: &extProcPb.ProcessingRequest_RequestBody{ - RequestBody: &extProcPb.HttpBody{Body: []byte("no healthy upstream"), EndOfStream: true}, - }, - }, - }, - // Pod 1 will be picked because it has relatively low queue size, the requested model active, and low KV cache. - pods: newPodStates( - podState{index: 0, queueSize: 0, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", modelSQLLoraTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, - ), - wantErr: false, - wantResponses: integrationutils.NewImmediateErrorResponse( - envoyTypePb.StatusCode_BadRequest, - "inference gateway: BadRequest - Error unmarshaling request body", - ), +// newRequestTotal creates the expected Prometheus metric string for the `inference_objective_request_total` counter. +func newRequestTotal(model, target string) string { + return inferenceObjectiveRequestTotal([]label{{"model_name", model}, {"target_model_name", target}}) +} + +// newReadyPods creates the expected Prometheus metric string for the `inference_pool_ready_pods` gauge. +func newReadyPods(count int) string { + return inferencePoolReadyPods(count, []label{{"name", testPoolName}}) +} + +// expectRouteTo constructs the full set of expected gRPC responses for a successful request that is routed to a +// specific backend endpoint. +func expectRouteTo(endpoint, targetModel, prompt string) []*extProcPb.ProcessingResponse { + bodyJSON := map[string]interface{}{ + "max_tokens": 100, + "model": targetModel, + "prompt": prompt, + "temperature": 0, + } + bodyBytes, _ := json.Marshal(bodyJSON) + + return integrationutils.NewRequestBufferedResponse( + endpoint, + string(bodyBytes), + &configPb.HeaderValueOption{Header: &configPb.HeaderValue{Key: "hi", RawValue: []byte("mom")}}, + &configPb.HeaderValueOption{Header: &configPb.HeaderValue{ + Key: requtil.RequestIdHeaderKey, + RawValue: []byte("test-request-id"), + }}, + ) +} + +// testHarness encapsulates the setup and teardown for a single hermetic test run. 
+// It ensures that each test case runs with its own isolated environment, including a dedicated server runner, gRPC +// client, and Kubernetes resources. +type testHarness struct { + t *testing.T + runner *server.ExtProcServerRunner + client extProcPb.ExternalProcessor_ProcessClient + cancel func() +} + +// newTestHarness creates and initializes all components for a test. +// It's the factory for creating a fully isolated test environment (sans controller-runtime Manager and its global +// metrics registry). +func newTestHarness( + t *testing.T, + podAndMetrics map[*backend.Pod]*backendmetrics.MetricsState, + sdConfig *saturationdetector.Config, + + uniqueSuffix string, +) *testHarness { + runner, client, serverCancel, conn, clientCancel := setupTestInfrastructure(t, podAndMetrics, sdConfig, uniqueSuffix) + return &testHarness{ + t: t, + runner: runner, + client: client, + cancel: func() { + clientCancel() + conn.Close() + serverCancel() + + for pod := range podAndMetrics { + podObj := epptestutil.MakePod(pod.NamespacedName.Name). + Namespace(pod.NamespacedName.Namespace).Complete().ObjRef() + if err := k8sClient.Delete(context.Background(), podObj); err != nil { + t.Logf("Failed to delete pod %s: %v", podObj.GetName(), err) + } + } }, - { - name: "select active lora, low queue", - requests: integrationutils.GenerateStreamedRequestSet(logger, "test2", modelSQLLora, modelSQLLoraTarget, nil), - // Pod 1 will be picked because it has relatively low queue size, the requested model active, and low KV cache. - pods: newPodStates( - podState{index: 0, queueSize: 0, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", modelSQLLoraTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, - ), + } +} - wantMetrics: map[string]string{ - "inference_objective_request_total": inferenceObjectiveRequestTotal([]label{ - {"model_name", modelSQLLora}, - {"target_model_name", modelSQLLoraTarget}, - }), - }, - wantErr: false, - wantResponses: integrationutils.NewRequestBufferedResponse( - "192.168.1.2:8000", - fmt.Sprintf(`{"max_tokens":100,"model":%q,"prompt":"test2","temperature":0}`, modelSQLLoraTarget), - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "hi", - RawValue: []byte("mom"), - }, +func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { + defaultPods := []podState{{index: 0}} + type testCase struct { + name string + requests []*extProcPb.ProcessingRequest + pods []podState + wantResponses []*extProcPb.ProcessingResponse + wantMetrics map[string]string + wantErr bool + } + + // runTestCases is a generic engine for executing a slice of test cases. + // It handles the boilerplate of setting up the test harness, running the test, validating the response, and checking + // metrics for each case. + runTestCases := func(t *testing.T, testCases []testCase, sdConfig *saturationdetector.Config) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // These integration tests are run serially. + // The controller-runtime Manager registers metrics to a global registry. + // We ensure correctness by resetting the global metrics for each test run. 
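+				// True parallelism remains blocked by that global registry, so the reset keeps serial runs correct.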
+ metrics.Reset() + t.Cleanup(metrics.Reset) + + uniqueSuffix := uuid.NewString()[:8] + h := newTestHarness(t, newPodStates(uniqueSuffix, tc.pods...), sdConfig, uniqueSuffix) + t.Cleanup(h.cancel) + + responses, err := integrationutils.StreamedRequest(t, h.client, tc.requests, len(tc.wantResponses)) + validateResponse(t, err, tc.wantErr, tc.wantResponses, responses) + + errs := make(map[string]error) + assert.Eventually(t, func() bool { + for metricName, value := range tc.wantMetrics { + if err := metricsutils.GatherAndCompare( + crmetrics.Registry, + strings.NewReader(value), + metricName, + ); err != nil { + errs[metricName] = err + return false + } + } + return true + }, 5*time.Second, 100*time.Millisecond, "failed to match all expected metrics") + if len(errs) > 0 { + for metricName, err := range errs { + t.Logf("Metric comparison failed for %s: %v", metricName, err) + } + } + }) + } + } + + t.Run("RequestRouting", func(t *testing.T) { + testCases := []testCase{ + { + name: "selects pod with lower queue and kv cache", + requests: integrationutils.GenerateStreamedRequestSetWithHeaders(logger, "test1", modelMyModel, + modelMyModelTarget, nil, nil), + pods: []podState{ + {index: 0, queueSize: 3, kvCacheUsage: 0.2}, + {index: 1, queueSize: 0, kvCacheUsage: 0.1}, + {index: 2, queueSize: 10, kvCacheUsage: 0.2}, }, - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: requtil.RequestIdHeaderKey, - RawValue: []byte("test-request-id"), - }, + wantResponses: expectRouteTo("192.168.1.2:8000", modelMyModelTarget, "test1"), + wantMetrics: map[string]string{ + "inference_objective_request_total": newRequestTotal(modelMyModel, modelMyModelTarget), + "inference_pool_ready_pods": newReadyPods(3), }, - ), - }, - { - name: "select lora despite higher kv cache usage", - requests: integrationutils.GenerateStreamedRequestSet(logger, "test3", modelSQLLora, modelSQLLoraTarget, nil), - // Pod 2 will be picked despite NOT having the requested model active as it is above the affinity for queue size. - // Also it is critical, so we should still admit the request despite all queue sizes being greater than the queue - // size threshold. 
- pods: newPodStates( - podState{index: 0, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, - podState{index: 1, queueSize: 10, kvCacheUsage: 0.4, activeModels: []string{"foo", modelSQLLoraTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.3, activeModels: []string{"foo"}}, - ), - wantMetrics: map[string]string{ - "inference_objective_request_total": inferenceObjectiveRequestTotal([]label{ - {"model_name", modelSQLLora}, - {"target_model_name", modelSQLLoraTarget}, - }), }, - wantErr: false, - wantResponses: integrationutils.NewRequestBufferedResponse( - "192.168.1.2:8000", - fmt.Sprintf(`{"max_tokens":100,"model":%q,"prompt":"test3","temperature":0}`, modelSQLLoraTarget), - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "hi", - RawValue: []byte("mom"), - }, + { + name: "selects pod with active lora and low queue", + requests: integrationutils.GenerateStreamedRequestSetWithHeaders(logger, "test2", modelSQLLora, + modelSQLLoraTarget, nil, nil), + pods: []podState{ + {index: 0, queueSize: 0, kvCacheUsage: 0.2, activeModels: []string{"foo"}}, + {index: 1, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", modelSQLLoraTarget}}, + {index: 2, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo"}}, }, - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: requtil.RequestIdHeaderKey, - RawValue: []byte("test-request-id"), - }, + wantResponses: expectRouteTo("192.168.1.2:8000", modelSQLLoraTarget, "test2"), + wantMetrics: map[string]string{ + "inference_objective_request_total": newRequestTotal(modelSQLLora, modelSQLLoraTarget), }, - ), - }, - { - name: "don't shed requests by default", - requests: integrationutils.GenerateStreamedRequestSet(logger, "test4", modelSQLLora, modelSQLLoraTarget, nil), - // pod 0: excluded; above queue size threshold - // pod 1: excluded; above KV cache threshold - // pod 2: excluded; above queue size threshold - pods: newPodStates( - podState{index: 0, queueSize: 6, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSQLLoraTarget}}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.85, activeModels: []string{"foo"}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo"}}, - ), - wantMetrics: map[string]string{ - "inference_objective_request_total": inferenceObjectiveRequestTotal([]label{ - {"model_name", modelSQLLora}, - {"target_model_name", modelSQLLoraTarget}, - }), }, - wantErr: false, - wantResponses: integrationutils.NewRequestBufferedResponse( - "192.168.1.1:8000", - fmt.Sprintf(`{"max_tokens":100,"model":%q,"prompt":"test4","temperature":0}`, modelSQLLoraTarget), - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "hi", - RawValue: []byte("mom"), - }, - }, - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: requtil.RequestIdHeaderKey, - RawValue: []byte("test-request-id"), - }, - }, - ), - }, - { - name: "body sent over multiple requests, noncritical, but one server has capacity, do not shed", - requests: []*extProcPb.ProcessingRequest{ - { - Request: &extProcPb.ProcessingRequest_RequestHeaders{ - RequestHeaders: &extProcPb.HttpHeaders{ - Headers: &configPb.HeaderMap{ - Headers: []*configPb.HeaderValue{ - { - Key: "hi", - Value: "mom", - }, - { - Key: metadata.ObjectiveKey, - Value: modelSheddable, - }, - { - Key: metadata.ModelNameRewriteKey, - Value: modelSheddableTarget, - }, - { - Key: requtil.RequestIdHeaderKey, - Value: "test-request-id", - }, - }, - }, - }, - }, - }, { - Request: 
&extProcPb.ProcessingRequest_RequestBody{ - RequestBody: &extProcPb.HttpBody{Body: []byte("{\"max_tokens\":100,\"model\":\"sql-lo"), EndOfStream: false}, - }, + { + name: "selects pod with lora affinity despite higher kv cache", + requests: integrationutils.GenerateStreamedRequestSetWithHeaders(logger, "test3", modelSQLLora, + modelSQLLoraTarget, nil, nil), + pods: []podState{ + {index: 0, queueSize: 10, kvCacheUsage: 0.2}, + {index: 1, queueSize: 10, kvCacheUsage: 0.4, activeModels: []string{modelSQLLoraTarget}}, + {index: 2, queueSize: 10, kvCacheUsage: 0.3}, }, - { - Request: &extProcPb.ProcessingRequest_RequestBody{ - RequestBody: &extProcPb.HttpBody{Body: []byte("ra-sheddable\",\"prompt\":\"test6\",\"temperature\":0}"), EndOfStream: true}, - }, - }, - }, - // Pod 1 will be picked because it has relatively low queue size and low KV cache. - pods: newPodStates( - podState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, - podState{index: 1, queueSize: 4, kvCacheUsage: 0.85, activeModels: []string{"foo", modelSheddableTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo", modelSheddableTarget}}, - ), - wantMetrics: map[string]string{ - "inference_objective_request_total": inferenceObjectiveRequestTotal([]label{ - {"model_name", modelSheddable}, - {"target_model_name", modelSheddableTarget}, - }), + wantResponses: expectRouteTo("192.168.1.2:8000", modelSQLLoraTarget, "test3"), + wantMetrics: map[string]string{"inference_objective_request_total": newRequestTotal(modelSQLLora, + modelSQLLoraTarget)}, }, - wantErr: false, - wantResponses: integrationutils.NewRequestBufferedResponse( - "192.168.1.1:8000", - fmt.Sprintf(`{"max_tokens":100,"model":%q,"prompt":"test6","temperature":0}`, modelSheddableTarget), - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "hi", - RawValue: []byte("mom"), - }, + { + name: "routes to least-saturated pod when all pods are under high load", + requests: integrationutils.GenerateStreamedRequestSetWithHeaders(logger, "test4", modelSQLLora, + modelSQLLoraTarget, nil, nil), + pods: []podState{ + {index: 0, queueSize: 6, kvCacheUsage: 0.2, activeModels: []string{modelSQLLoraTarget}}, + {index: 1, queueSize: 0, kvCacheUsage: 0.85}, + {index: 2, queueSize: 10, kvCacheUsage: 0.9}, }, - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: requtil.RequestIdHeaderKey, - RawValue: []byte("test-request-id"), - }, + wantResponses: expectRouteTo("192.168.1.1:8000", modelSQLLoraTarget, "test4"), + wantMetrics: map[string]string{ + "inference_objective_request_total": newRequestTotal(modelSQLLora, modelSQLLoraTarget), }, - ), - }, - { - name: "inferenceobjective's modelName is not translated, passthrough", - requests: []*extProcPb.ProcessingRequest{ - { - Request: &extProcPb.ProcessingRequest_RequestHeaders{ - RequestHeaders: &extProcPb.HttpHeaders{ - Headers: &configPb.HeaderMap{ - Headers: []*configPb.HeaderValue{ - { - Key: "hi", - Value: "mom", - }, - { - Key: metadata.ObjectiveKey, - Value: modelDirect, - }, - { - Key: metadata.ModelNameRewriteKey, - Value: modelDirect, - }, - { - Key: metadata.ModelNameRewriteKey, - Value: modelDirect, - }, - { - Key: requtil.RequestIdHeaderKey, - Value: "test-request-id", - }, - }, - }, - }, - }, - }, - { - Request: &extProcPb.ProcessingRequest_RequestBody{ - RequestBody: &extProcPb.HttpBody{Body: []byte("{\"max_tokens\":100,\"model\":\"direct-"), EndOfStream: false}, - }, + }, + { + name: "passthrough for 
models not defined in objectives", + requests: integrationutils.GenerateStreamedRequestSetWithHeaders(logger, "test6", modelDirect, modelDirect, nil, + map[string]string{metadata.ModelNameRewriteKey: modelDirect}), + pods: []podState{ + {index: 0, queueSize: 0, kvCacheUsage: 0.1}, + {index: 1, queueSize: 5, kvCacheUsage: 0.85}, + {index: 2, queueSize: 10, kvCacheUsage: 0.9}, }, - { - Request: &extProcPb.ProcessingRequest_RequestBody{ - RequestBody: &extProcPb.HttpBody{Body: []byte("model\",\"prompt\":\"test6\",\"temperature\":0}"), EndOfStream: true}, - }, + wantResponses: expectRouteTo("192.168.1.1:8000", modelDirect, "test6"), + wantMetrics: map[string]string{ + "inference_objective_request_total": newRequestTotal(modelDirect, modelDirect), }, }, - // pod 0: selected due to low queue size and kv cache usage - pods: newPodStates( - podState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.85, activeModels: []string{"foo", modelSheddableTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo", modelSheddableTarget}}, - ), - wantMetrics: map[string]string{ - "inference_objective_request_total": inferenceObjectiveRequestTotal([]label{ - {"model_name", modelDirect}, - {"target_model_name", modelDirect}, - }), - }, - wantErr: false, - wantResponses: integrationutils.NewRequestBufferedResponse( - "192.168.1.1:8000", - fmt.Sprintf(`{"max_tokens":100,"model":%q,"prompt":"test6","temperature":0}`, modelDirect), - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "hi", - RawValue: []byte("mom"), - }, - }, - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ + { + name: "routes request with multi-chunk body", + requests: integrationutils.GenerateStreamedRequestSetWithHeaders(logger, "test", modelMyModel, + modelMyModelTarget, nil, nil), + pods: defaultPods, + wantResponses: integrationutils.NewRequestBufferedResponse( + "192.168.1.1:8000", + `{"max_tokens":100,"model":"my-model-12345","prompt":"test","temperature":0}`, + &configPb.HeaderValueOption{Header: &configPb.HeaderValue{Key: "hi", RawValue: []byte("mom")}}, + &configPb.HeaderValueOption{Header: &configPb.HeaderValue{ Key: requtil.RequestIdHeaderKey, RawValue: []byte("test-request-id"), - }, - }, - ), - }, - // Response flow tests - { - name: "responsebody sent over multiple requests, content-type is json, buffer", - requests: []*extProcPb.ProcessingRequest{ - { - Request: &extProcPb.ProcessingRequest_ResponseHeaders{ - ResponseHeaders: &extProcPb.HttpHeaders{ - Headers: &configPb.HeaderMap{ - Headers: []*configPb.HeaderValue{ - { - Key: "content-type", - Value: "application/json", - }, - }, - }, - }, - }, - }, - { - Request: &extProcPb.ProcessingRequest_ResponseBody{ - ResponseBody: &extProcPb.HttpBody{Body: []byte("{\"max_tokens\":100,\"model\":\"sql-lo"), EndOfStream: false}, - }, - }, - { - Request: &extProcPb.ProcessingRequest_ResponseBody{ - ResponseBody: &extProcPb.HttpBody{Body: []byte("ra-sheddable\",\"prompt\":\"test6\",\"temperature\":0}"), EndOfStream: true}, - }, + }}, + ), + wantMetrics: map[string]string{ + "inference_objective_request_total": newRequestTotal(modelMyModel, modelMyModelTarget), }, }, - // pod 0: selected - // pod 1: excluded; above KV cache threshold - // pod 2: excluded; above queue size threshold - pods: newPodStates( - podState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, - 
podState{index: 1, queueSize: 0, kvCacheUsage: 0.85, activeModels: []string{"foo", modelSheddableTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo", modelSheddableTarget}}, - ), - wantErr: false, - wantResponses: integrationutils.NewResponseBufferedResponse( - fmt.Sprintf(`{"max_tokens":100,"model":%q,"prompt":"test6","temperature":0}`, modelSheddable), - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "x-went-into-resp-headers", - RawValue: []byte("true"), - }, - }, - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "content-type", - RawValue: []uint8("application/json"), - }, - }, - ), - }, - { - name: "Response is invalid json; return body", - requests: []*extProcPb.ProcessingRequest{ - { - Request: &extProcPb.ProcessingRequest_ResponseHeaders{ - ResponseHeaders: &extProcPb.HttpHeaders{ - Headers: &configPb.HeaderMap{ - Headers: []*configPb.HeaderValue{ - { - Key: "content-type", - Value: "application/json", - }, - }, - }, + } + runTestCases(t, testCases, nil) + }) + + t.Run("ResponseHandling", func(t *testing.T) { + testCases := []testCase{ + { + name: "buffers and rewrites multi-chunk json response", + requests: []*extProcPb.ProcessingRequest{ + {Request: &extProcPb.ProcessingRequest_RequestHeaders{RequestHeaders: &extProcPb.HttpHeaders{ + Headers: &configPb.HeaderMap{Headers: []*configPb.HeaderValue{ + {Key: metadata.ObjectiveKey, Value: modelSheddable}, + {Key: metadata.ModelNameRewriteKey, Value: modelSheddableTarget}, + {Key: requtil.RequestIdHeaderKey, Value: "test-static-id-1"}, }, - }, + }, EndOfStream: true}}}, + {Request: &extProcPb.ProcessingRequest_ResponseHeaders{ResponseHeaders: &extProcPb.HttpHeaders{ + Headers: &configPb.HeaderMap{Headers: []*configPb.HeaderValue{ + {Key: "content-type", Value: "application/json"}, + }, + }}}}, + {Request: &extProcPb.ProcessingRequest_ResponseBody{ResponseBody: &extProcPb.HttpBody{ + Body: []byte(`{"model":"` + modelSheddable + `", "prompt": "test"}`), EndOfStream: false}, + }}, + {Request: &extProcPb.ProcessingRequest_ResponseBody{ResponseBody: &extProcPb.HttpBody{ + Body: []byte(`}`), EndOfStream: true}, + }}, }, - { - Request: &extProcPb.ProcessingRequest_ResponseBody{ - ResponseBody: &extProcPb.HttpBody{Body: []byte("no healthy upstream"), EndOfStream: true}, + pods: defaultPods, + wantResponses: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestHeaders{RequestHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{SetHeaders: []*configPb.HeaderValueOption{ + {Header: &configPb.HeaderValue{ + Key: metadata.DestinationEndpointKey, + RawValue: []byte("192.168.1.1:8000"), + }}, + {Header: &configPb.HeaderValue{ + Key: requtil.RequestIdHeaderKey, + RawValue: []byte("test-static-id-1"), + }}, + }}, + }, + }}, + DynamicMetadata: integrationutils.MakeMetadata("192.168.1.1:8000"), }, + integrationutils.NewResponseHeaders( + &configPb.HeaderValueOption{Header: &configPb.HeaderValue{ + Key: "content-type", + RawValue: []byte("application/json"), + }}, + &configPb.HeaderValueOption{Header: &configPb.HeaderValue{ + Key: "x-went-into-resp-headers", + RawValue: []byte("true"), + }}, + ), + integrationutils.NewResponseStreamChunk(`{"model":"`+modelSheddable+`","prompt":"test"}`, true), }, }, - // pod 0: selected - // pod 1: excluded; above KV cache threshold - // pod 2: excluded; above queue size threshold - pods: newPodStates( - 
podState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.85, activeModels: []string{"foo", modelSheddableTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo", modelSheddableTarget}}, - ), - wantErr: false, - wantResponses: integrationutils.NewResponseBufferedResponse( - "no healthy upstream", - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "x-went-into-resp-headers", - RawValue: []byte("true"), - }, - }, - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "content-type", - RawValue: []uint8("application/json"), - }, - }, - ), - }, - { - name: "responsebody sent over a single request, but empty body with EndOfStream in the second request(this is how envoy operates); content-type is json, buffer", - requests: []*extProcPb.ProcessingRequest{ - { - Request: &extProcPb.ProcessingRequest_ResponseHeaders{ - ResponseHeaders: &extProcPb.HttpHeaders{ - Headers: &configPb.HeaderMap{ - Headers: []*configPb.HeaderValue{ - { - Key: "content-type", - Value: "application/json", - }, - }, - }, + { + name: "handles invalid json in response body", + requests: []*extProcPb.ProcessingRequest{ + {Request: &extProcPb.ProcessingRequest_RequestHeaders{RequestHeaders: &extProcPb.HttpHeaders{ + Headers: &configPb.HeaderMap{Headers: []*configPb.HeaderValue{ + {Key: metadata.ObjectiveKey, Value: modelSheddable}, + {Key: metadata.ModelNameRewriteKey, Value: modelSheddableTarget}, + {Key: requtil.RequestIdHeaderKey, Value: "test-static-id-2"}, }, - }, + }, EndOfStream: true}}}, + {Request: &extProcPb.ProcessingRequest_ResponseHeaders{ResponseHeaders: &extProcPb.HttpHeaders{ + Headers: &configPb.HeaderMap{Headers: []*configPb.HeaderValue{ + {Key: "content-type", Value: "application/json"}, + }}, + }}}, + {Request: &extProcPb.ProcessingRequest_ResponseBody{ResponseBody: &extProcPb.HttpBody{ + Body: []byte(`not valid json`), EndOfStream: true, + }}}, }, - { - Request: &extProcPb.ProcessingRequest_ResponseBody{ - ResponseBody: &extProcPb.HttpBody{Body: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-sheddable\",\"prompt\":\"test6\",\"temperature\":0}"), EndOfStream: false}, - }, - }, - { - Request: &extProcPb.ProcessingRequest_ResponseBody{ - ResponseBody: &extProcPb.HttpBody{Body: []byte(""), EndOfStream: true}, + pods: defaultPods, + wantResponses: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestHeaders{RequestHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{SetHeaders: []*configPb.HeaderValueOption{ + {Header: &configPb.HeaderValue{ + Key: metadata.DestinationEndpointKey, + RawValue: []byte("192.168.1.1:8000"), + }}, + {Header: &configPb.HeaderValue{ + Key: requtil.RequestIdHeaderKey, + RawValue: []byte("test-static-id-2"), + }}, + }}, + }, + }}, + DynamicMetadata: integrationutils.MakeMetadata("192.168.1.1:8000"), }, + integrationutils.NewResponseHeaders( + &configPb.HeaderValueOption{Header: &configPb.HeaderValue{ + Key: "content-type", + RawValue: []byte("application/json"), + }}, + &configPb.HeaderValueOption{Header: &configPb.HeaderValue{ + Key: "x-went-into-resp-headers", + RawValue: []byte("true"), + }}, + ), + integrationutils.NewResponseStreamChunk(`not valid json`, true), }, }, - // pod 0: selected - // pod 1: excluded; above KV cache threshold - // pod 2: excluded; above queue size threshold 
- pods: newPodStates( - podState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.85, activeModels: []string{"foo", modelSheddableTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo", modelSheddableTarget}}, - ), - wantErr: false, - wantResponses: integrationutils.NewResponseBufferedResponse( - fmt.Sprintf(`{"max_tokens":100,"model":%q,"prompt":"test6","temperature":0}`, modelSheddable), - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "x-went-into-resp-headers", - RawValue: []byte("true"), - }, + { + name: "handles single chunk response followed by empty EOS chunk", + requests: []*extProcPb.ProcessingRequest{ + {Request: &extProcPb.ProcessingRequest_RequestHeaders{RequestHeaders: &extProcPb.HttpHeaders{ + Headers: &configPb.HeaderMap{Headers: []*configPb.HeaderValue{ + {Key: metadata.ObjectiveKey, Value: modelSheddable}, + {Key: metadata.ModelNameRewriteKey, Value: modelSheddableTarget}, + {Key: requtil.RequestIdHeaderKey, Value: "test-static-id-3"}}, + }, EndOfStream: true}}}, + {Request: &extProcPb.ProcessingRequest_ResponseHeaders{ResponseHeaders: &extProcPb.HttpHeaders{ + Headers: &configPb.HeaderMap{Headers: []*configPb.HeaderValue{ + {Key: "content-type", Value: "application/json"}, + }}, + }}}, + {Request: &extProcPb.ProcessingRequest_ResponseBody{ResponseBody: &extProcPb.HttpBody{ + Body: []byte(`{"model":"` + modelSheddableTarget + `"}`), EndOfStream: false, + }}}, + {Request: &extProcPb.ProcessingRequest_ResponseBody{ResponseBody: &extProcPb.HttpBody{ + Body: []byte(""), EndOfStream: true, + }}}, }, - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "content-type", - RawValue: []uint8("application/json"), - }, - }, - ), - }, - { - name: "responsebody sent over a single request, but empty body with EndOfStream in the second request(this is how envoy operates); content-type is json, buffer", - requests: []*extProcPb.ProcessingRequest{ - { - Request: &extProcPb.ProcessingRequest_ResponseHeaders{ - ResponseHeaders: &extProcPb.HttpHeaders{ - Headers: &configPb.HeaderMap{ - Headers: []*configPb.HeaderValue{ - { - Key: "content-type", - RawValue: []byte("text/event-stream"), - }, - { - Key: "status", - RawValue: []byte("200"), - }, - }, + pods: defaultPods, + wantResponses: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestHeaders{RequestHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{SetHeaders: []*configPb.HeaderValueOption{ + {Header: &configPb.HeaderValue{ + Key: metadata.DestinationEndpointKey, RawValue: []byte("192.168.1.1:8000"), + }}, + {Header: &configPb.HeaderValue{ + Key: requtil.RequestIdHeaderKey, RawValue: []byte("test-static-id-3"), + }}, + }}, }, - }, - }, - }, - { - Request: &extProcPb.ProcessingRequest_ResponseBody{ - ResponseBody: &extProcPb.HttpBody{ - Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"food-review-1","choices":[{"index":0,"text":"NEVER","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`), - EndOfStream: false}, - }, - }, - { - Request: &extProcPb.ProcessingRequest_ResponseBody{ - ResponseBody: &extProcPb.HttpBody{ - Body: []byte(`data: 
{"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"food-review-1","choices":[{"index":0,"text":"GONNA","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`), - EndOfStream: false}, - }, - }, - { - Request: &extProcPb.ProcessingRequest_ResponseBody{ - ResponseBody: &extProcPb.HttpBody{ - Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"food-review-1","choices":[{"index":0,"text":"GIVE","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`), - EndOfStream: false}, - }, - }, - { - Request: &extProcPb.ProcessingRequest_ResponseBody{ - ResponseBody: &extProcPb.HttpBody{ - Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"food-review-1","choices":[{"index":0,"text":"YOU","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`), - EndOfStream: false}, - }, - }, - { - Request: &extProcPb.ProcessingRequest_ResponseBody{ - ResponseBody: &extProcPb.HttpBody{ - Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"food-review-1","choices":[{"index":0,"text":"UP","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`), - EndOfStream: false}, - }, - }, - { - Request: &extProcPb.ProcessingRequest_ResponseBody{ - ResponseBody: &extProcPb.HttpBody{ - Body: []byte("data: {\"id\":\"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9\",\"object\":\"text_completion\",\"created\":1741379018,\"model\":\"food-review-1\",\"choices\":[],\"usage\":{\"prompt_tokens\":7,\"total_tokens\":17,\"completion_tokens\":10}}\ndata: [DONE]"), - EndOfStream: false}, - }, - }, - { - Request: &extProcPb.ProcessingRequest_ResponseBody{ - ResponseBody: &extProcPb.HttpBody{ - Body: []byte(""), - EndOfStream: true}, + }}, + DynamicMetadata: integrationutils.MakeMetadata("192.168.1.1:8000"), }, + integrationutils.NewResponseHeaders( + &configPb.HeaderValueOption{Header: &configPb.HeaderValue{ + Key: "content-type", RawValue: []byte("application/json"), + }}, + &configPb.HeaderValueOption{Header: &configPb.HeaderValue{ + Key: "x-went-into-resp-headers", RawValue: []byte("true"), + }}, + ), + integrationutils.NewResponseStreamChunk(`{"model":"`+modelSheddableTarget+`"}`, true), }, }, - wantMetrics: map[string]string{`inference_objective_input_tokens`: ` + { + name: "passes through and counts tokens in event-stream response", + requests: []*extProcPb.ProcessingRequest{ + {Request: &extProcPb.ProcessingRequest_RequestHeaders{RequestHeaders: &extProcPb.HttpHeaders{ + Headers: &configPb.HeaderMap{Headers: []*configPb.HeaderValue{ + {Key: metadata.ObjectiveKey, Value: modelSheddable}, + {Key: metadata.ModelNameRewriteKey, Value: modelSheddableTarget}, + {Key: requtil.RequestIdHeaderKey, Value: "test-static-id-4"}, + }}, EndOfStream: true, + }}}, + {Request: &extProcPb.ProcessingRequest_ResponseHeaders{ResponseHeaders: &extProcPb.HttpHeaders{ + Headers: &configPb.HeaderMap{Headers: []*configPb.HeaderValue{ + {Key: "content-type", Value: "text/event-stream"}, + }}, + }}}, + {Request: &extProcPb.ProcessingRequest_ResponseBody{ResponseBody: &extProcPb.HttpBody{ + Body: []byte(`data: {"usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}}`), + EndOfStream: false, + }}}, + {Request: &extProcPb.ProcessingRequest_ResponseBody{ResponseBody: &extProcPb.HttpBody{ + Body: []byte("\ndata: [DONE]"), 
EndOfStream: true, + }}}, + }, + pods: defaultPods, + wantMetrics: map[string]string{`inference_objective_input_tokens`: ` # HELP inference_objective_input_tokens [ALPHA] Inference objective input token count distribution for requests in each model. # TYPE inference_objective_input_tokens histogram - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="1"} 0 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="8"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="16"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="32"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="64"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="128"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="256"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="512"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="1024"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="2048"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="4096"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="8192"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="16384"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="32778"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="65536"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="131072"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="262144"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="524288"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="1.048576e+06"} 1 - inference_objective_input_tokens_bucket{model_name="",target_model_name="",le="+Inf"} 1 - inference_objective_input_tokens_sum{model_name="",target_model_name=""} 7 - inference_objective_input_tokens_count{model_name="",target_model_name=""} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="1"} 0 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="8"} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="16"} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="32"} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="64"} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="128"} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="256"} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="512"} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="1024"} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="2048"} 1 + 
inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="4096"} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="8192"} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="16384"} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="32778"} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="65536"} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="131072"} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="262144"} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="524288"} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="1.048576e+06"} 1 + inference_objective_input_tokens_bucket{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3",le="+Inf"} 1 + inference_objective_input_tokens_sum{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3"} 7 + inference_objective_input_tokens_count{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3"} 1 `}, - wantResponses: []*extProcPb.ProcessingResponse{ - integrationutils.NewResponseHeaders( - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "x-went-into-resp-headers", - RawValue: []byte("true"), - }, + wantResponses: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestHeaders{RequestHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{SetHeaders: []*configPb.HeaderValueOption{ + {Header: &configPb.HeaderValue{ + Key: metadata.DestinationEndpointKey, + RawValue: []byte("192.168.1.1:8000"), + }}, + {Header: &configPb.HeaderValue{ + Key: requtil.RequestIdHeaderKey, + RawValue: []byte("test-static-id-4"), + }}, + }}, + }, + }}, + DynamicMetadata: integrationutils.MakeMetadata("192.168.1.1:8000"), }, - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ + integrationutils.NewResponseHeaders( + &configPb.HeaderValueOption{Header: &configPb.HeaderValue{ Key: "content-type", RawValue: []byte("text/event-stream"), - }, - }, - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "status", - RawValue: []byte("200"), - }, - }, - ), - integrationutils.NewResponseStreamChunk(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"food-review-1","choices":[{"index":0,"text":"NEVER","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`, false), - integrationutils.NewResponseStreamChunk(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"food-review-1","choices":[{"index":0,"text":"GONNA","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`, false), - integrationutils.NewResponseStreamChunk(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"food-review-1","choices":[{"index":0,"text":"GIVE","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`, false), - 
integrationutils.NewResponseStreamChunk(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"food-review-1","choices":[{"index":0,"text":"YOU","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`, false), - integrationutils.NewResponseStreamChunk(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"food-review-1","choices":[{"index":0,"text":"UP","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`, false), - integrationutils.NewResponseStreamChunk("data: {\"id\":\"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9\",\"object\":\"text_completion\",\"created\":1741379018,\"model\":\"food-review-1\",\"choices\":[],\"usage\":{\"prompt_tokens\":7,\"total_tokens\":17,\"completion_tokens\":10}}\ndata: [DONE]", false), - integrationutils.NewResponseStreamChunk("", true), - }, - }, - // Bodyless Request test - { - name: "simple GET Request", - requests: []*extProcPb.ProcessingRequest{ - { - Request: &extProcPb.ProcessingRequest_RequestHeaders{ - RequestHeaders: &extProcPb.HttpHeaders{ - Headers: &configPb.HeaderMap{ - Headers: []*configPb.HeaderValue{ - { - Key: "content-type", - RawValue: []byte("text/event-stream"), - }, - { - Key: "status", - RawValue: []byte("200"), - }, - }, - }, - EndOfStream: true, - }, - }, + }}, + &configPb.HeaderValueOption{Header: &configPb.HeaderValue{ + Key: "x-went-into-resp-headers", + RawValue: []byte("true"), + }}, + ), + integrationutils.NewResponseStreamChunk(`data: {"usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}}`, false), + integrationutils.NewResponseStreamChunk("\ndata: [DONE]", true), }, }, - wantResponses: []*extProcPb.ProcessingResponse{}, - pods: newPodStates( - podState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, - ), - wantMetrics: map[string]string{}, - }, - { - name: "select active lora with subsetting tag, all pods available", - requests: integrationutils.GenerateStreamedRequestSet( - logger, - "test2", - modelSQLLora, - modelSQLLoraTarget, - []string{"192.168.1.1:8000", "192.168.1.2:8000", "192.168.1.3:8000"}), - // Pod 1 will be picked because it has relatively low queue size, the requested model active, low KV cache, and within subset. 
- pods: newPodStates( - podState{index: 0, queueSize: 0, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", modelSQLLoraTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, - ), + } + runTestCases(t, testCases, nil) + }) - wantMetrics: map[string]string{ - "inference_objective_request_total": inferenceObjectiveRequestTotal([]label{ - {"model_name", modelSQLLora}, - {"target_model_name", modelSQLLoraTarget}, - }), - }, - wantErr: false, - wantResponses: integrationutils.NewRequestBufferedResponse( - "192.168.1.2:8000", - fmt.Sprintf(`{"max_tokens":100,"model":%q,"prompt":"test2","temperature":0}`, modelSQLLoraTarget), - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "hi", - RawValue: []byte("mom"), - }, - }, - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: requtil.RequestIdHeaderKey, - RawValue: []byte("test-request-id"), - }, + t.Run("Subsetting", func(t *testing.T) { + testCases := []testCase{ + { + name: "selects best pod from available subset", + requests: integrationutils.GenerateStreamedRequestSetWithHeaders(logger, "subset-test", modelSQLLora, + modelSQLLoraTarget, []string{"192.168.1.1:8000", "192.168.1.2:8000"}, nil), + pods: []podState{ + {index: 0, queueSize: 5, kvCacheUsage: 0.2}, + {index: 1, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{modelSQLLoraTarget}}, + {index: 2, queueSize: 0, kvCacheUsage: 0.1}, }, - ), - }, - { - name: "select active lora with subsetting tag, some pods match", - requests: integrationutils.GenerateStreamedRequestSet( - logger, - "test2", - modelSQLLora, - modelSQLLoraTarget, - []string{"192.168.1.3:8000"}), - // Pod 3 has high queue and kv cache utilization, but it will still be picked because it is the only one matching subsetting target. 
-			pods: newPodStates(
-				podState{index: 0, queueSize: 0, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}},
-				podState{index: 1, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", modelSQLLoraTarget}},
-				podState{index: 2, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}},
-			),
-
-			wantMetrics: map[string]string{
-				"inference_objective_request_total": inferenceObjectiveRequestTotal([]label{
-					{"model_name", modelSQLLora},
-					{"target_model_name", modelSQLLoraTarget},
-				}),
+				wantResponses: expectRouteTo("192.168.1.2:8000", modelSQLLoraTarget, "subset-test"),
 			},
-			wantErr: false,
-			wantResponses: integrationutils.NewRequestBufferedResponse(
-				"192.168.1.3:8000",
-				fmt.Sprintf(`{"max_tokens":100,"model":%q,"prompt":"test2","temperature":0}`, modelSQLLoraTarget),
-				&configPb.HeaderValueOption{
-					Header: &configPb.HeaderValue{
-						Key:      "hi",
-						RawValue: []byte("mom"),
-					},
-				},
-				&configPb.HeaderValueOption{
-					Header: &configPb.HeaderValue{
-						Key:      requtil.RequestIdHeaderKey,
-						RawValue: []byte("test-request-id"),
-					},
+			{
+				name: "selects only available pod in subset despite high load",
+				requests: integrationutils.GenerateStreamedRequestSetWithHeaders(logger, "subset-test", modelMyModel,
+					modelMyModelTarget, []string{"192.168.1.3:8000"}, nil),
+				pods: []podState{
+					{index: 0, queueSize: 0, kvCacheUsage: 0.1},
+					{index: 1, queueSize: 0, kvCacheUsage: 0.1},
+					{index: 2, queueSize: 10, kvCacheUsage: 0.9},
 				},
-			),
-		},
-		{
-			name: "select active lora with subsetting tag, no pods available",
-			requests: integrationutils.GenerateStreamedRequestSet(
-				logger,
-				"test2",
-				modelSQLLora,
-				modelSQLLoraTarget,
-				[]string{"192.168.1.4:8000", "192.168.1.5:8000", "192.168.1.6:8000"}),
-			// No pods will be picked as none are within the subset.
-			pods: newPodStates(
-				podState{index: 0, queueSize: 0, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}},
-				podState{index: 1, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", modelSQLLoraTarget}},
-				podState{index: 2, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}},
-			),
+				wantResponses: expectRouteTo("192.168.1.3:8000", modelMyModelTarget, "subset-test"),
+			},
+			{
+				name: "returns error when no pods match subset",
+				requests: integrationutils.GenerateStreamedRequestSetWithHeaders(logger, "subset-test", modelMyModel,
+					modelMyModelTarget, []string{"192.168.1.4:8000"}, nil),
+				pods: []podState{{index: 0}, {index: 1}, {index: 2}},
+				wantErr: true,
+				wantResponses: integrationutils.NewImmediateErrorResponse(
+					envoyTypePb.StatusCode_ServiceUnavailable,
+					"inference gateway: ServiceUnavailable - failed to find candidate pods for serving the request",
+				),
+			},
+		}
+		runTestCases(t, testCases, nil)
+	})

-		wantMetrics: map[string]string{},
-		wantErr: true,
-		wantResponses: []*extProcPb.ProcessingResponse{
-			{
-				Response: &extProcPb.ProcessingResponse_ImmediateResponse{
-					ImmediateResponse: &extProcPb.ImmediateResponse{
-						Status: &envoyTypePb.HttpStatus{
-							Code: envoyTypePb.StatusCode_ServiceUnavailable,
-						},
-						Body: []byte("inference gateway: ServiceUnavailable - failed to find candidate pods for serving the request"),
-					},
-				},
+	t.Run("ErrorConditions", func(t *testing.T) {
+		testCases := []testCase{
+			{
+				name: "invalid json in request body",
+				requests: []*extProcPb.ProcessingRequest{
+					{Request: &extProcPb.ProcessingRequest_RequestHeaders{RequestHeaders: &extProcPb.HttpHeaders{
+						Headers: &configPb.HeaderMap{},
+					}}},
+					{Request: &extProcPb.ProcessingRequest_RequestBody{RequestBody: &extProcPb.HttpBody{
+						Body: []byte("not json"), EndOfStream: true,
+					}}},
				},
+				pods: defaultPods,
+				wantErr: true,
+				wantResponses: integrationutils.NewImmediateErrorResponse(
+					envoyTypePb.StatusCode_BadRequest,
+					"inference gateway: BadRequest - Error unmarshaling request body",
+				),
			},
-		},
-		{
-			name: "no backend pods are available",
-			requests: []*extProcPb.ProcessingRequest{
-				{
-					Request: &extProcPb.ProcessingRequest_RequestHeaders{
-						RequestHeaders: &extProcPb.HttpHeaders{
-							Headers: &configPb.HeaderMap{
-								Headers: []*configPb.HeaderValue{
-									{
-										Key:      "content-type",
-										RawValue: []byte("text/event-stream"),
-									},
-									{
-										Key:      "status",
-										RawValue: []byte("200"),
-									},
-								},
-							},
-							EndOfStream: true,
-						},
-					},
-				},
-			},
-			pods: nil,
-			wantMetrics: map[string]string{},
-			wantErr: true,
-			wantResponses: []*extProcPb.ProcessingResponse{
-				{
-					Response: &extProcPb.ProcessingResponse_ImmediateResponse{
-						ImmediateResponse: &extProcPb.ImmediateResponse{
-							Status: &envoyTypePb.HttpStatus{
-								Code: envoyTypePb.StatusCode_InternalServerError,
-							},
-							Body: []byte("inference gateway: Internal - no pods available in datastore"),
-						},
-					},
+			{
+				name: "request body is missing model field",
+				requests: []*extProcPb.ProcessingRequest{
+					{Request: &extProcPb.ProcessingRequest_RequestHeaders{RequestHeaders: &extProcPb.HttpHeaders{
+						Headers: &configPb.HeaderMap{},
+					}}},
+					{Request: &extProcPb.ProcessingRequest_RequestBody{RequestBody: &extProcPb.HttpBody{
+						Body: []byte(`{"prompt":"test"}`), EndOfStream: true,
+					}}},
				},
+				pods: defaultPods,
+				wantErr: true,
+				wantResponses: integrationutils.NewImmediateErrorResponse(
+					envoyTypePb.StatusCode_BadRequest,
+					"inference gateway: BadRequest - model not found in request body",
+				),
			},
-		},
-		{
-			name: "request don't contains invalid payload, model not exist",
-			requests: []*extProcPb.ProcessingRequest{
-				{
-					Request: &extProcPb.ProcessingRequest_RequestBody{
-						RequestBody: &extProcPb.HttpBody{
-							Body: []byte(`{"hello":"world"}`),
-							EndOfStream: true},
-					},
+			{
+				name: "no backend pods available in datastore",
+				requests: []*extProcPb.ProcessingRequest{
+					{Request: &extProcPb.ProcessingRequest_RequestHeaders{RequestHeaders: &extProcPb.HttpHeaders{
+						Headers: &configPb.HeaderMap{}, EndOfStream: true,
+					}}},
				},
+				pods: nil,
+				wantErr: true,
+				wantResponses: integrationutils.NewImmediateErrorResponse(
+					envoyTypePb.StatusCode_InternalServerError,
+					"inference gateway: Internal - no pods available in datastore",
+				),
			},
-			wantErr: true,
-			wantMetrics: map[string]string{},
-			wantResponses: []*extProcPb.ProcessingResponse{
-				{
-					Response: &extProcPb.ProcessingResponse_ImmediateResponse{
-						ImmediateResponse: &extProcPb.ImmediateResponse{
-							Status: &envoyTypePb.HttpStatus{
-								Code: envoyTypePb.StatusCode_BadRequest,
-							},
-							Body: []byte("inference gateway: BadRequest - model not found in request body"),
-						},
-					},
+		}
+		runTestCases(t, testCases, nil)
+	})
+
+	t.Run("RequestTypes", func(t *testing.T) {
+		testCases := []testCase{
+			{
+				name: "simple GET request is passed through",
+				requests: []*extProcPb.ProcessingRequest{
+					{Request: &extProcPb.ProcessingRequest_RequestHeaders{RequestHeaders: &extProcPb.HttpHeaders{
+						Headers: &configPb.HeaderMap{}, EndOfStream: true}}},
				},
+				pods: defaultPods,
+				wantResponses: nil, // Expect no modification, just pass-through.
			},
-		},
-	}
+		}
+		runTestCases(t, testCases, nil)
+	})
+}

-	for _, test := range tests {
-		t.Run(test.name, func(t *testing.T) {
-			client, cleanup := setUpHermeticServer(t, test.pods)
-			t.Cleanup(cleanup)
-			responses, err := integrationutils.StreamedRequest(t, client, test.requests, len(test.wantResponses))

+// setupTestInfrastructure is the core setup engine for a single hermetic test case. It performs the following steps:
+// 1. Creates a new controller-runtime manager with a cache scoped to a unique test ID to ensure resource isolation.
+// 2. Starts the manager in the background.
+// 3. Instantiates a new EPP server runner, configuring it with a fake metrics client and a real scheduler.
+// 4. Creates the fake Kubernetes pod objects and injects their simulated metrics into the fake client.
+// 5. Starts the EPP server in the background on a free port.
+// 6. Waits until the manager's cache has synced and the EPP datastore is populated with the fake pods.
+// 7. Creates and returns a gRPC client connected to the test server.
+//
+// It returns the server runner instance, the gRPC client, a function to stop the server and manager, the gRPC
+// connection, and a function to cancel the client context.
+func setupTestInfrastructure(
+	t *testing.T,
+	podAndMetrics map[*backend.Pod]*backendmetrics.MetricsState,
+	sdConfig *saturationdetector.Config,
+	uniqueSuffix string,
+) (
+	*server.ExtProcServerRunner,
+	extProcPb.ExternalProcessor_ProcessClient,
+	context.CancelFunc,
+	*grpc.ClientConn,
+	context.CancelFunc,
+) {
+	// --- 1. Create a Manager with a Pod Selector for this Test Run ---
+	// We use a unique suffix for each test run to label pods.
+	// The manager's cache is configured to only watch pods with this label, ensuring that tests do not interfere with
+	// each other's backend pods.
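+	// Note: the label key below is an arbitrary, test-only convention; the only hard requirement is that its value
+	// (uniqueSuffix) is unique per test run.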
+	testRunLabel := "test-run-id"
+	podSelector := labels.SelectorFromSet(map[string]string{testRunLabel: uniqueSuffix})
+	mgr, err := server.NewManagerWithOptions(testEnv.Config, managerTestOptions(testNamespace, testPoolName, podSelector))
+	require.NoError(t, err)

-		if err != nil && !test.wantErr {
-			t.Errorf("Unexpected error, got: %v, want error: %v", err, test.wantErr)
-		}
-		if diff := cmp.Diff(test.wantResponses, responses,
-			protocmp.Transform(),
-			protocmp.SortRepeated(func(a, b *configPb.HeaderValueOption) bool {
-				return a.GetHeader().GetKey() < b.GetHeader().GetKey()
-			}),
-		); diff != "" {
-			t.Errorf("Unexpected response, (-want +got): %v", diff)
-		}

+	// --- 2. Start the Manager and EPP Server ---
+	// The manager is started in the background to handle Kubernetes watches.
+	// The EPP server is configured to run on a random free port to avoid conflicts during parallel test execution.
+	managerCtx, stopManager := context.WithCancel(context.Background())
+	serverCtx, stopServer := context.WithCancel(context.Background())

-		if len(test.wantMetrics) != 0 {
-			for metricName, value := range test.wantMetrics {
-				if err := metricsutils.GatherAndCompare(crmetrics.Registry, strings.NewReader(value), metricName); err != nil {
-					t.Error(err)
-				}
-			}
+	go func() {
+		if err := mgr.Start(managerCtx); err != nil {
+			if !errors.Is(err, context.Canceled) {
+				t.Errorf("Failed to start manager: %v", err)
 			}
-			metrics.Reset()
-		})
+		}
+	}()
+
+	kvCacheUtilizationScorer := scorer.NewKVCacheUtilizationScorer()
+	queueingScorer := scorer.NewQueueScorer()
+	prefixCacheScorer := prefix.New(context.Background(), prefix.DefaultConfig)
+	loraAffinityScorer := scorer.NewLoraAffinityScorer()
+	defaultProfile := framework.NewSchedulerProfile().
+		WithScorers(
+			framework.NewWeightedScorer(kvCacheUtilizationScorer, 1),
+			framework.NewWeightedScorer(queueingScorer, 1),
+			framework.NewWeightedScorer(prefixCacheScorer, 1),
+			framework.NewWeightedScorer(loraAffinityScorer, 1),
+		).
+		WithPicker(picker.NewMaxScorePicker(picker.DefaultMaxNumOfEndpoints))
+	profileHandler := profile.NewSingleProfileHandler()
+	schedulerConfig := scheduling.NewSchedulerConfig(profileHandler, map[string]*framework.SchedulerProfile{
+		"default": defaultProfile,
+	})
+	scheduler := scheduling.NewSchedulerWithConfig(schedulerConfig)
+
+	runner := server.NewDefaultExtProcServerRunner()
+	grpcAddress, err := integrationutils.GetFreePort()
+	require.NoError(t, err)
+	runner.GrpcPort = grpcAddress.Port
+	runner.TestPodMetricsClient = &backendmetrics.FakePodMetricsClient{}
+	pmf := backendmetrics.NewPodMetricsFactory(runner.TestPodMetricsClient, 10*time.Millisecond)
+
+	runner.PoolGKNN = common.GKNN{
+		NamespacedName: types.NamespacedName{Namespace: testNamespace, Name: testPoolName},
+		GroupKind:      schema.GroupKind{Group: v1.GroupVersion.Group, Kind: "InferencePool"},
 	}
-}
+	runner.Datastore = datastore.NewDatastore(context.Background(), pmf)
+	runner.SecureServing = false
+
+	if err := runner.SetupWithManager(context.Background(), mgr); err != nil {
+		t.Fatalf("Failed to setup test-local server runner: %v", err)
+	}
+
+	// --- 3. Configure the Director and its Dependencies ---
+	// The core EPP logic (scheduler, saturation detector) is wired up here.
+	// This allows tests to enable/disable features and provide custom configurations.
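+	// If the test case does not supply a saturation detector config, fall back to the package defaults so that
+	// every test exercises the same baseline thresholds.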
+	if sdConfig == nil {
+		sdConfig = &saturationdetector.Config{
+			QueueDepthThreshold:       saturationdetector.DefaultQueueDepthThreshold,
+			KVCacheUtilThreshold:      saturationdetector.DefaultKVCacheUtilThreshold,
+			MetricsStalenessThreshold: saturationdetector.DefaultMetricsStalenessThreshold,
+		}
+	}
+	detector := saturationdetector.NewDetector(sdConfig, logger.WithName("saturation-detector"))
+	runner.SaturationDetector = detector
+	runner.Director = requestcontrol.NewDirectorWithConfig(runner.Datastore, scheduler, detector,
+		requestcontrol.NewConfig())

-func setUpHermeticServer(t *testing.T, podAndMetrics map[*backend.Pod]*backendmetrics.MetricsState) (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) {
-	// Reconfigure the TestPodMetricsClient.
+	// --- 4. Create Fake Backend Pods and Metrics ---
+	// The test harness creates fake Kubernetes pod objects and injects a fake metrics client into the server runner.
+	// This gives each test precise control over the perceived state of the backend.
 	res := map[types.NamespacedName]*backendmetrics.MetricsState{}
 	for pod, metrics := range podAndMetrics {
 		res[pod.NamespacedName] = metrics
 	}
-	serverRunner.TestPodMetricsClient.SetRes(res)
+	runner.TestPodMetricsClient.SetRes(res)

-	serverCtx, stopServer := context.WithCancel(context.Background())
-
-	// TODO: this should be consistent with the inference pool
 	podLabels := map[string]string{
-		"app": testPoolName,
+		"app":        testPoolName,
+		testRunLabel: uniqueSuffix,
 	}
-
 	for pod := range podAndMetrics {
-		pod := epptestutil.MakePod(pod.NamespacedName.Name).
+		podObj := epptestutil.MakePod(pod.NamespacedName.Name).
 			Namespace(pod.NamespacedName.Namespace).
 			ReadyCondition().
 			Labels(podLabels).
 			IP(pod.Address).
 			Complete().
 			ObjRef()
-
-		copy := pod.DeepCopy()
+		copy := podObj.DeepCopy()
 		if err := k8sClient.Create(context.Background(), copy); err != nil {
-			logutil.Fatal(logger, err, "Failed to create pod", "pod", pod)
+			t.Fatalf("Failed to create pod: %v", err)
 		}
-
-		// since no pod controllers deployed in fake environment, we manually update pod status
-		copy.Status = pod.Status
+		copy.Status = podObj.Status
 		if err := k8sClient.Status().Update(context.Background(), copy); err != nil {
-			logutil.Fatal(logger, err, "Failed to update pod status", "pod", pod)
+			t.Fatalf("Failed to update pod status: %v", err)
 		}
 	}
+
 	go func() {
-		if err := serverRunner.AsRunnable(logger.WithName("ext-proc")).Start(serverCtx); err != nil {
-			logutil.Fatal(logger, err, "Failed to start ext-proc server")
+		if err := runner.AsRunnable(logger.WithName("ext-proc")).Start(serverCtx); err != nil {
+			if !errors.Is(err, context.Canceled) {
+				t.Errorf("Failed to start ext-proc server: %v", err)
+			}
 		}
 	}()

-	time.Sleep(serverRunner.RefreshPrometheusMetricsInterval) // wait for metrics to get available before running tests that rely on these metrics
-
-	// check if all pods are synced to datastore
+	// --- 5. Wait for Datastore Sync ---
+	// This is a critical step: block until the manager's cache has synced with the fake API server and propagated
+	// the pod and objective resources to the EPP's datastore. Otherwise, the test would start before the EPP server
+	// is aware of any backend pods, leading to guaranteed failures.
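+	// The 10s ceiling and 100ms poll interval below are deliberately generous; cache sync is usually fast, but CI
+	// environments can be slow.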
 	assert.EventuallyWithT(t, func(t *assert.CollectT) {
-		assert.Len(t, serverRunner.Datastore.PodList(backendmetrics.AllPodsPredicate), len(podAndMetrics), "Datastore not synced")
-	}, 10*time.Second, time.Second)
+		synced := runner.Datastore.PoolHasSynced()
+		assert.True(t, synced, "Pool should be synced")
+		assert.Len(t, runner.Datastore.PodList(backendmetrics.AllPodsPredicate), len(podAndMetrics), "Datastore not synced")
+		assert.NotNil(t, runner.Datastore.ObjectiveGet(modelSheddable), "InferenceObjective not synced")
+	}, 10*time.Second, 100*time.Millisecond)

-	// Create a grpc connection
-	conn, err := grpc.NewClient(testGRPCAddress, grpc.WithTransportCredentials(insecure.NewCredentials()))
+	// --- 6. Set up the gRPC Client ---
+	// Finally, a gRPC client is created to communicate with the EPP server.
+	conn, err := grpc.NewClient(grpcAddress.String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
 	if err != nil {
-		logutil.Fatal(logger, err, "Failed to connect", "address", testGRPCAddress)
+		t.Fatalf("Failed to connect to %s: %v", grpcAddress.String(), err)
 	}

 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-	client, err = extProcPb.NewExternalProcessorClient(conn).Process(ctx)
+	client, err := extProcPb.NewExternalProcessorClient(conn).Process(ctx)
 	if err != nil {
-		logutil.Fatal(logger, err, "Failed to create client")
+		t.Fatalf("Failed to create client: %v", err)
 	}
-	return client, func() {
-		cancel()
-		conn.Close()
-		stopServer()
-		// clear created pods
-		for pod := range podAndMetrics {
-			pod := epptestutil.MakePod(pod.NamespacedName.Name).
-				Namespace(pod.NamespacedName.Namespace).Complete().ObjRef()
-
-			if err := k8sClient.Delete(context.Background(), pod); err != nil {
-				logutil.Fatal(logger, err, "Failed to delete pod", "pod", fakePod)
-			}
-		}
+	stopAll := func() {
+		stopServer()
+		stopManager()
 	}
+
+	return runner, client, stopAll, conn, cancel
 }

-func fakePod(index int) *backend.Pod {
+func fakePod(index int, uniqueSuffix string) *backend.Pod {
 	return &backend.Pod{
-		NamespacedName: types.NamespacedName{Name: fmt.Sprintf("pod-%v", index), Namespace: testNamespace},
+		NamespacedName: types.NamespacedName{Name: fmt.Sprintf("pod-%d-%s", index, uniqueSuffix), Namespace: testNamespace},
 		Address:        fmt.Sprintf("192.168.1.%d", index+1),
-		Labels:         make(map[string]string, 0),
 	}
 }

-// podState is a descriptor for a pod's simulated metrics.
+// podState is a simplified descriptor for a pod's simulated metrics, used to define test scenarios.
 type podState struct {
 	index        int
 	queueSize    int
@@ -1085,28 +904,54 @@ type podState struct {
 	activeModels []string
 }

-// newPodStates generates the backend metrics map required by the test setup.
-func newPodStates(states ...podState) map[*backend.Pod]*backendmetrics.MetricsState {
+// newPodStates is a test helper that converts a slice of simplified podState descriptors into the detailed
+// `map[*backend.Pod]*backendmetrics.MetricsState` required by the fake metrics client.
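+// Fields omitted from a podState take their Go zero values, i.e. an empty queue and 0.0 KV-cache usage.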
+func newPodStates(uniqueSuffix string, states ...podState) map[*backend.Pod]*backendmetrics.MetricsState {
 	res := make(map[*backend.Pod]*backendmetrics.MetricsState)
 	for _, s := range states {
-		pod := fakePod(s.index)
+		pod := fakePod(s.index, uniqueSuffix)
 		activeModelsMap := make(map[string]int)
 		for _, model := range s.activeModels {
 			activeModelsMap[model] = 1
 		}
 		res[pod] = &backendmetrics.MetricsState{
 			WaitingQueueSize:    s.queueSize,
 			KVCacheUsagePercent: s.kvCacheUsage,
 			ActiveModels:        activeModelsMap,
 			WaitingModels:       make(map[string]int),
 		}
 	}
 	return res
 }

-// Sets up a test environment and returns the runner struct
+// validateResponse centralizes the validation logic for test responses. It correctly handles the expected `io.EOF`
+// error for immediate responses and performs a detailed diff of the received protobuf messages against the expected
+// ones.
+func validateResponse(t *testing.T, err error, wantErr bool, wantResponses, responses []*extProcPb.ProcessingResponse) {
+	if wantErr {
+		// For immediate error responses, the server often closes the stream, resulting in `io.EOF`.
+		// A nil error is also acceptable if the server sends the response then waits for the client to close.
+		if err != nil {
+			require.ErrorIs(t, err, io.EOF, "Expected EOF or nil error for immediate response stream")
+		}
+	} else {
+		require.NoError(t, err)
+	}
+
+	if diff := cmp.Diff(wantResponses, responses,
+		protocmp.Transform(),
+		protocmp.SortRepeated(func(a, b *configPb.HeaderValueOption) bool {
+			return a.GetHeader().GetKey() < b.GetHeader().GetKey()
+		}),
+	); diff != "" {
+		t.Errorf("Unexpected response, (-want +got): %v", diff)
+	}
+}

+// BeforeSuite sets up the hermetic test environment for the entire package. It starts a fake API server using envtest,
+// creates a Kubernetes client, and pre-loads the Custom Resource Definitions (CRDs) and a common set of custom
+// resources (like `InferencePool` and `InferenceObjective`) that are required by the test cases.
 func BeforeSuite() func() {
-	// Set up mock k8s API Client
 	testEnv = &envtest.Environment{
 		CRDDirectoryPaths:     []string{filepath.Join("..", "..", "..", "config", "crd", "bases")},
 		ErrorIfCRDPathMissing: true,
@@ -1123,121 +968,40 @@ func BeforeSuite() func() {
 	k8sClient, err = k8sclient.New(cfg, k8sclient.Options{Scheme: scheme})
 	if err != nil {
 		logutil.Fatal(logger, err, "Failed to start k8s Client")
-	} else if k8sClient == nil {
-		logutil.Fatal(logger, nil, "No error, but returned kubernetes client is nil", "config", cfg)
 	}

-	// Init runtime.
-	ctrl.SetLogger(logger)
-
-	metrics.Register()
-	// Register metrics handler.
-	// Metrics endpoint is enabled in 'config/default/kustomization.yaml'. The Metrics options configure the server.
-	// More info:
-	// - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.19.1/pkg/metrics/server
-	// - https://book.kubebuilder.io/reference/metrics.html
-	metricsServerOptions := metricsserver.Options{
-		BindAddress:    fmt.Sprintf(":%d", testMetricsPort),
-		FilterProvider: filters.WithAuthenticationAndAuthorization,
-	}
-	mgr, err := server.NewManagerWithOptions(cfg, managerTestOptions(testNamespace, testPoolName, metricsServerOptions))
-	if err != nil {
-		logutil.Fatal(logger, err, "Failed to create controller manager")
-	}
-
-	serverRunner = server.NewDefaultExtProcServerRunner()
-	serverRunner.TestPodMetricsClient = &backendmetrics.FakePodMetricsClient{}
-	pmf := backendmetrics.NewPodMetricsFactory(serverRunner.TestPodMetricsClient, 10*time.Millisecond)
-	// Adjust from defaults
-	serverRunner.PoolGKNN = common.GKNN{
-		NamespacedName: types.NamespacedName{Namespace: testNamespace, Name: testPoolName},
-		GroupKind:      schema.GroupKind{Group: v1.GroupVersion.Group, Kind: "InferencePool"},
-	}
-	serverRunner.Datastore = datastore.NewDatastore(context.Background(), pmf)
-
-	kvCacheUtilizationScorer := scorer.NewKVCacheUtilizationScorer()
-	queueingScorer := scorer.NewQueueScorer()
-	prefixCacheScorer := prefix.New(context.Background(), prefix.DefaultConfig)
-	loraAffinityScorer := scorer.NewLoraAffinityScorer()
-
-	defaultProfile := framework.NewSchedulerProfile().
-		WithScorers(framework.NewWeightedScorer(kvCacheUtilizationScorer, 1),
-			framework.NewWeightedScorer(queueingScorer, 1),
-			framework.NewWeightedScorer(prefixCacheScorer, 1),
-			framework.NewWeightedScorer(loraAffinityScorer, 1),
-		).
-		WithPicker(picker.NewMaxScorePicker(picker.DefaultMaxNumOfEndpoints))
-
-	profileHandler := profile.NewSingleProfileHandler()
-
-	schedulerConfig := scheduling.NewSchedulerConfig(profileHandler, map[string]*framework.SchedulerProfile{"default": defaultProfile})
-	scheduler := scheduling.NewSchedulerWithConfig(schedulerConfig)
-
-	sdConfig := &saturationdetector.Config{
-		QueueDepthThreshold:       saturationdetector.DefaultQueueDepthThreshold,
-		KVCacheUtilThreshold:      saturationdetector.DefaultKVCacheUtilThreshold,
-		MetricsStalenessThreshold: saturationdetector.DefaultMetricsStalenessThreshold,
-	}
-	detector := saturationdetector.NewDetector(sdConfig, logger.WithName("saturation-detector"))
-	serverRunner.SaturationDetector = detector
-	serverRunner.Director = requestcontrol.NewDirectorWithConfig(serverRunner.Datastore, scheduler, detector, requestcontrol.NewConfig())
-	serverRunner.SecureServing = false
-
-	if err := serverRunner.SetupWithManager(context.Background(), mgr); err != nil {
-		logutil.Fatal(logger, err, "Failed to setup server runner")
-	}
-
-	// Start the controller manager in a go routine, not blocking
-	go func() {
-		if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
-			logutil.Fatal(logger, err, "Failed to start manager")
-		}
-	}()
-
-	logger.Info("Setting up hermetic ExtProc server")
-
-	// Unmarshal CRDs from file into structs
 	manifestsPath := filepath.Join("..", "..", "testdata", "inferencepool-with-model-hermetic.yaml")
 	docs, err := readDocuments(manifestsPath)
 	if err != nil {
 		logutil.Fatal(logger, err, "Can't read object manifests", "path", manifestsPath)
 	}
-
 	for _, doc := range docs {
 		obj := &unstructured.Unstructured{}
 		if err = yaml.Unmarshal(doc, obj); err != nil {
 			logutil.Fatal(logger, err, "Can't unmarshal object", "document", doc)
 		}
-		logger.Info("Creating object", "kind", obj.GetKind(), "object", obj)
 		if err := k8sClient.Create(context.Background(), obj); err != nil {
 			logutil.Fatal(logger, err, "Unable to create object", "object", obj.GetName())
 		}
 	}

-	assert.Eventually(nil, func() bool {
-		modelExist := serverRunner.Datastore.ObjectiveGet(modelMyModel)
-		synced := serverRunner.Datastore.PoolHasSynced() && modelExist != nil
-		return synced
-	}, 10*time.Second, 10*time.Millisecond)
+	ctrl.SetLogger(logger)
+	metrics.Register() // Register global metrics once for the entire test suite.

+	logger.Info("Hermetic test suite setup complete")
 	return func() {
 		_ = testEnv.Stop()
-		_ = k8sClient.DeleteAllOf(context.Background(), &v1.InferencePool{})
-		_ = k8sClient.DeleteAllOf(context.Background(), &v1alpha2.InferenceObjective{})
 	}
 }

-// readDocuments reads documents from file.
 func readDocuments(fp string) ([][]byte, error) {
 	b, err := os.ReadFile(fp)
 	if err != nil {
 		return nil, err
 	}
-
-	docs := [][]byte{}
+	var docs [][]byte
 	reader := k8syaml.NewYAMLReader(bufio.NewReader(bytes.NewReader(b)))
 	for {
-		// Read document
 		doc, err := reader.Read()
 		if err != nil {
 			if errors.Is(err, io.EOF) {
@@ -1250,38 +1014,36 @@ func readDocuments(fp string) ([][]byte, error) {
 	return docs, nil
 }

-// inject options that allow multiple test runs to run
-// https://github.com/kubernetes-sigs/controller-runtime/issues/2937
-func managerTestOptions(namespace, name string, metricsServerOptions metricsserver.Options) ctrl.Options {
+// managerTestOptions configures the controller-runtime manager for test isolation. Its most critical job is to
+// configure the cache to only watch pods that have a specific `test-run-id` label.
+// This ensures that the manager's cache for one test run is completely isolated from the pods created in another,
+// preventing test interference. It also disables the metrics server to avoid port conflicts.
+func managerTestOptions(namespace, name string, podSelector labels.Selector) ctrl.Options {
 	return ctrl.Options{
 		Scheme: scheme,
 		Cache: cache.Options{
 			ByObject: map[k8sclient.Object]cache.ByObject{
 				&corev1.Pod{}: {
 					Namespaces: map[string]cache.Config{
-						namespace: {},
+						namespace: {
+							LabelSelector: podSelector,
+						},
 					},
 				},
 				&v1.InferencePool{}: {
 					Namespaces: map[string]cache.Config{
 						namespace: {
-							FieldSelector: fields.SelectorFromSet(fields.Set{
-								"metadata.name": name,
-							}),
+							FieldSelector: fields.SelectorFromSet(fields.Set{"metadata.name": name}),
 						},
 					},
 				},
-				&v1alpha2.InferenceObjective{}: {
-					Namespaces: map[string]cache.Config{
-						namespace: {},
-					},
-				},
+				&v1alpha2.InferenceObjective{}: {Namespaces: map[string]cache.Config{namespace: {}}},
 			},
 		},
-		Controller: crconfig.Controller{
-			SkipNameValidation: boolPointer(true),
+		Controller: crconfig.Controller{SkipNameValidation: boolPointer(true)},
+		Metrics: metricsserver.Options{
+			BindAddress: "0", // Disable the metrics server in tests
 		},
-		Metrics: metricsServerOptions,
 	}
 }
diff --git a/test/integration/util.go b/test/integration/util.go
index d78b76e28..08610762b 100644
--- a/test/integration/util.go
+++ b/test/integration/util.go
@@ -18,7 +18,10 @@ package integration

 import (
 	"encoding/json"
+	"errors"
+	"fmt"
 	"io"
+	"net"
 	"strconv"
 	"testing"
 	"time"
@@ -38,6 +41,24 @@ const (
 	headerKeyContentLength = "Content-Length"
 )

+// GetFreePort finds and returns an available TCP port on the host.
+// It works by asking the OS to allocate a port by listening on port 0, capturing the assigned address, and then
+// immediately closing the listener.
+func GetFreePort() (*net.TCPAddr, error) {
+	// A port number of 0 instructs the OS to select a random, available port.
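+	// Caveat: the port is released before it is returned, so another process could in principle claim it before the
+	// caller binds it. This small race is generally acceptable in tests.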
+	listener, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		return nil, fmt.Errorf("failed to listen on a free port: %w", err)
+	}
+	defer listener.Close()
+
+	addr, ok := listener.Addr().(*net.TCPAddr)
+	if !ok {
+		return nil, errors.New("failed to cast listener address to TCPAddr")
+	}
+	return addr, nil
+}
+
 func SendRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, error) {
 	t.Logf("Sending request: %v", req)
 	if err := client.Send(req); err != nil {
@@ -54,7 +75,13 @@ func SendRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient,
 	return res, err
 }

-func StreamedRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, requests []*extProcPb.ProcessingRequest, expectedResponses int) ([]*extProcPb.ProcessingResponse, error) {
+// StreamedRequest sends a series of requests and collects the specified number of responses.
+func StreamedRequest(
+	t *testing.T,
+	client extProcPb.ExternalProcessor_ProcessClient,
+	requests []*extProcPb.ProcessingRequest,
+	expectedResponses int,
+) ([]*extProcPb.ProcessingResponse, error) {
 	for _, req := range requests {
 		t.Logf("Sending request: %v", req)
 		if err := client.Send(req); err != nil {
@@ -62,27 +89,35 @@ func StreamedRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessCli
 			return nil, err
 		}
 	}
-	responses := []*extProcPb.ProcessingResponse{}
-
-	// Make an incredible simple timeout func in the case where
-	// there is less than the expected amount of responses; bail and fail.
-	var simpleTimeout bool
-	go func() {
-		time.Sleep(10 * time.Second)
-		simpleTimeout = true
-	}()
-
-	for range expectedResponses {
-		if simpleTimeout {
-			break
+
+	var responses []*extProcPb.ProcessingResponse
+	for i := range expectedResponses {
+		type recvResult struct {
+			res *extProcPb.ProcessingResponse
+			err error
 		}
-		res, err := client.Recv()
-		if err != nil && err != io.EOF {
-			t.Logf("Failed to receive: %v", err)
-			return nil, err
+		recvChan := make(chan recvResult, 1)
+
+		go func() {
+			res, err := client.Recv()
+			recvChan <- recvResult{res, err}
+		}()
+
+		select {
+		case <-time.After(10 * time.Second):
+			t.Logf("Timeout waiting for response %d of %d", i+1, expectedResponses)
+			return responses, nil
+		case result := <-recvChan:
+			if result.err != nil {
+				if result.err == io.EOF {
+					return responses, nil
+				}
+				t.Logf("Failed to receive: %v", result.err)
+				return nil, result.err
+			}
+			t.Logf("Received response %+v", result.res)
+			responses = append(responses, result.res)
 		}
-		t.Logf("Received response %+v", res)
-		responses = append(responses, res)
 	}
 	return responses, nil
 }
@@ -113,43 +148,96 @@ func GenerateRequest(logger logr.Logger, prompt, model string, filterMetadata []
 }

 func GenerateStreamedRequestSet(logger logr.Logger, prompt, model, targetModel string, filterMetadata []string) []*extProcPb.ProcessingRequest {
+	return GenerateStreamedRequestSetWithHeaders(logger, prompt, model, targetModel, filterMetadata, nil)
+}
+
+// GenerateStreamedRequestSetWithHeaders creates a complete set of gRPC messages to simulate a realistic, multi-chunk
+// HTTP request. It includes a headers message, followed by two body messages, which is representative of how Envoy
+// streams request bodies. It allows adding extra headers for specialized test cases.
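+// Extra headers are appended after the standard set and are not deduplicated, so a key that collides with a standard
+// header will appear twice in the generated request.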
+func GenerateStreamedRequestSetWithHeaders(logger logr.Logger, prompt, model, targetModel string, filterMetadata []string, extraHeaders map[string]string) []*extProcPb.ProcessingRequest {
 	requests := []*extProcPb.ProcessingRequest{}
+	headers := []*envoyCorev3.HeaderValue{
+		{
+			Key:   "hi",
+			Value: "mom",
+		},
+		{
+			Key:   metadata.ObjectiveKey,
+			Value: model,
+		},
+		{
+			Key:   metadata.ModelNameRewriteKey,
+			Value: targetModel,
+		},
+		{
+			Key:   requtil.RequestIdHeaderKey,
+			Value: "test-request-id",
+		},
+	}
+
+	for k, v := range extraHeaders {
+		headers = append(headers, &envoyCorev3.HeaderValue{
+			Key:   k,
+			Value: v,
+		})
+	}
+
 	headerReq := &extProcPb.ProcessingRequest{
 		Request: &extProcPb.ProcessingRequest_RequestHeaders{
 			RequestHeaders: &extProcPb.HttpHeaders{
 				Headers: &envoyCorev3.HeaderMap{
-					Headers: []*envoyCorev3.HeaderValue{
-						{
-							Key:   "hi",
-							Value: "mom",
-						},
-						{
-							Key:   metadata.ObjectiveKey,
-							Value: model,
-						},
-						{
-							Key:   metadata.ModelNameRewriteKey,
-							Value: targetModel,
-						},
-						{
-							Key:   requtil.RequestIdHeaderKey,
-							Value: "test-request-id",
-						},
-					},
+					Headers: headers,
 				},
 			},
 		},
+		MetadataContext: &envoyCorev3.Metadata{
+			FilterMetadata: GenerateRequestMetadata(filterMetadata),
+		},
 	}
+	requests = append(requests, headerReq)

-	headerReq.MetadataContext = &envoyCorev3.Metadata{
-		FilterMetadata: GenerateRequestMetadata(filterMetadata),
+	// Create and split the request body.
+	j := map[string]any{
+		"prompt":      prompt,
+		"max_tokens":  100,
+		"temperature": 0,
+	}
+	if model != "" {
+		j["model"] = model
+	}
+	llmReq, err := json.Marshal(j)
+	if err != nil {
+		logutil.Fatal(logger, err, "Failed to marshal LLM request")
 	}
-	requests = append(requests, headerReq)
-	requests = append(requests, GenerateRequest(logger, prompt, model, filterMetadata))
+	// Simulate a multi-chunk body by splitting the marshaled JSON.
+	// This is a more realistic representation of how a streaming body might arrive.
+	splitPoint := len(llmReq) / 2
+	chunk1 := llmReq[:splitPoint]
+	chunk2 := llmReq[splitPoint:]
+
+	requests = append(requests, &extProcPb.ProcessingRequest{
+		Request: &extProcPb.ProcessingRequest_RequestBody{
+			RequestBody: &extProcPb.HttpBody{Body: chunk1, EndOfStream: false},
+		},
+		MetadataContext: &envoyCorev3.Metadata{
+			FilterMetadata: GenerateRequestMetadata(filterMetadata),
+		},
+	})
+	requests = append(requests, &extProcPb.ProcessingRequest{
+		Request: &extProcPb.ProcessingRequest_RequestBody{
+			RequestBody: &extProcPb.HttpBody{Body: chunk2, EndOfStream: true},
+		},
+		MetadataContext: &envoyCorev3.Metadata{
+			FilterMetadata: GenerateRequestMetadata(filterMetadata),
+		},
+	})
+
 	return requests
 }

+// GenerateRequestMetadata constructs the nested metadata structure required by Envoy for subset load balancing.
+// It takes a list of endpoint addresses and embeds them into the `envoy.lb` filter metadata field.
 func GenerateRequestMetadata(filterMetadata []string) map[string]*structpb.Struct {
 	requestMetadata := make(map[string]*structpb.Struct)
 	interfaceList := make([]any, len(filterMetadata))
@@ -165,9 +253,8 @@ func GenerateRequestMetadata(filterMetadata []string) map[string]*structpb.Struc
 	return requestMetadata
 }

-// NewRequestBufferedResponse creates a complete set of responses for the request phase.
-// It modifies request headers (e.g., for routing) and replaces the entire request body.
-// It returns a slice of two messages, representing the complete buffered action.
+// NewRequestBufferedResponse simulates a complete buffered mutation of the request phase. It returns a slice of
+// two messages: one to replace the request headers (for routing) and one to replace the request body.
 func NewRequestBufferedResponse(destinationEndpoint string, rewrittenBody string, otherHeaders ...*envoyCorev3.HeaderValueOption) []*extProcPb.ProcessingResponse {
 	setHeaders := []*envoyCorev3.HeaderValueOption{
 		{
@@ -196,7 +283,7 @@ func NewRequestBufferedResponse(destinationEndpoint string, rewrittenBody string
 			},
 		},
-		DynamicMetadata: makeMetadata(destinationEndpoint),
+		DynamicMetadata: MakeMetadata(destinationEndpoint),
 	}

 	bodyResponse := &extProcPb.ProcessingResponse{
@@ -219,9 +306,8 @@ func NewRequestBufferedResponse(destinationEndpoint string, rewrittenBody string
 	return []*extProcPb.ProcessingResponse{headerResponse, bodyResponse}
 }

-// NewResponseBufferedResponse creates a complete set of responses for the response phase.
-// It modifies response headers and replaces the entire response body.
-// It is used when the processor buffers the upstream response before sending its own.
+// NewResponseBufferedResponse simulates a complete buffered mutation of the response phase. It returns a slice of
+// messages to first modify the response headers and then replace the entire response body.
 func NewResponseBufferedResponse(rewrittenBody string, headersToSet ...*envoyCorev3.HeaderValueOption) []*extProcPb.ProcessingResponse {
 	return []*extProcPb.ProcessingResponse{
 		NewResponseHeaders(headersToSet...),
@@ -245,8 +331,8 @@ func NewResponseHeaders(headersToSet ...*envoyCorev3.HeaderValueOption) *extProc
 	}
 }

-// NewResponseStreamChunk creates a single response for one body chunk in a stream.
-// This is used to test streaming behaviors like text/event-stream pass-through.
+// NewResponseStreamChunk creates a single gRPC message to send one chunk of a streaming response body.
+// This is used to test streaming behaviors, such as passing through a text/event-stream.
 func NewResponseStreamChunk(body string, endOfStream bool) *extProcPb.ProcessingResponse {
 	return &extProcPb.ProcessingResponse{
 		Response: &extProcPb.ProcessingResponse_ResponseBody{
@@ -282,8 +368,8 @@ func NewImmediateErrorResponse(code envoyTypePb.StatusCode, body string) []*extP
 	return []*extProcPb.ProcessingResponse{response}
 }

-// makeMetadata creates the dynamic metadata struct that Envoy uses for routing hints.
-func makeMetadata(endpoint string) *structpb.Struct {
+// MakeMetadata creates the dynamic metadata struct that Envoy uses for routing hints.
+func MakeMetadata(endpoint string) *structpb.Struct {
 	return &structpb.Struct{
 		Fields: map[string]*structpb.Value{
 			metadata.DestinationEndpointNamespace: {