diff --git a/internal/extproc/messages_processor.go b/internal/extproc/messages_processor.go
index 6888980a48..d39a017aa6 100644
--- a/internal/extproc/messages_processor.go
+++ b/internal/extproc/messages_processor.go
@@ -6,10 +6,12 @@ package extproc

 import (
+	"bytes"
 	"cmp"
 	"context"
 	"encoding/json"
 	"fmt"
+	"io"
 	"log/slog"

 	corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@@ -147,6 +149,21 @@ type messagesProcessorUpstreamFilter struct {
 	stream     bool
 	metrics    metrics.ChatCompletionMetrics
 	costs      translator.LLMTokenUsage
+	gzipBuffer []byte
+}
+
+// createResponseBodyResponse creates a ProcessingResponse for response body processing.
+func createResponseBodyResponse(headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation) *extprocv3.ProcessingResponse {
+	return &extprocv3.ProcessingResponse{
+		Response: &extprocv3.ProcessingResponse_ResponseBody{
+			ResponseBody: &extprocv3.BodyResponse{
+				Response: &extprocv3.CommonResponse{
+					HeaderMutation: headerMutation,
+					BodyMutation:   bodyMutation,
+				},
+			},
+		},
+	}
 }

 // selectTranslator selects the translator based on the output schema.
@@ -269,14 +286,22 @@ func (c *messagesProcessorUpstreamFilter) ProcessResponseBody(ctx context.Contex
 		}
 	}()

-	// Decompress the body if needed using common utility.
-	decodingResult, err := decodeContentIfNeeded(body.Body, c.responseEncoding)
+	// Decompress the response body, buffering gzip chunks for streaming responses.
+	decodingResult, err := decodeContentWithBuffering(body.Body, c.responseEncoding, &c.gzipBuffer, body.EndOfStream)
 	if err != nil {
 		return nil, err
 	}

-	// headerMutation, bodyMutation, tokenUsage, err := c.translator.ResponseBody(c.responseHeaders, br, body.EndOfStream).
-	headerMutation, bodyMutation, tokenUsage, responseModel, err := c.translator.ResponseBody(c.responseHeaders, decodingResult.reader, body.EndOfStream)
+	// Check whether we received decompressed data or are still buffering.
+	data, _ := io.ReadAll(decodingResult.reader)
+	if len(data) == 0 && c.stream && !body.EndOfStream {
+		// Still buffering an incomplete body: return early with no mutations and skip metrics.
+		return createResponseBodyResponse(nil, nil), nil
+	}
+
+	// Process the decompressed data.
+	decodingResult.reader = bytes.NewReader(data)
+	headerMutation, bodyMutation, tokenUsage, responseModel, err := c.translator.ResponseBody(c.responseHeaders, decodingResult.reader, c.stream)
 	if err != nil {
 		return nil, fmt.Errorf("failed to transform response: %w", err)
 	}
@@ -284,19 +309,12 @@ func (c *messagesProcessorUpstreamFilter) ProcessResponseBody(ctx context.Contex
 	c.metrics.SetResponseModel(responseModel)

 	// Remove content-encoding header if original body encoded but was mutated in the processor.
-	headerMutation = removeContentEncodingIfNeeded(headerMutation, bodyMutation, decodingResult.isEncoded)
-
-	resp := &extprocv3.ProcessingResponse{
-		Response: &extprocv3.ProcessingResponse_ResponseBody{
-			ResponseBody: &extprocv3.BodyResponse{
-				Response: &extprocv3.CommonResponse{
-					HeaderMutation: headerMutation,
-					BodyMutation:   bodyMutation,
-				},
-			},
-		},
+	if !c.stream || body.EndOfStream {
+		headerMutation = removeContentEncodingIfNeeded(headerMutation, bodyMutation, decodingResult.isEncoded)
 	}
+	resp := createResponseBodyResponse(headerMutation, bodyMutation)
+
 	c.costs.InputTokens += tokenUsage.InputTokens
 	c.costs.OutputTokens += tokenUsage.OutputTokens
 	c.costs.TotalTokens += tokenUsage.TotalTokens
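Review note: taken together, the two hunks above give `ProcessResponseBody` a simple per-chunk decision: reply with a no-op response while gzip chunks are still accumulating, run the translator once decompressed data is available, and strip `content-encoding` only once the response is complete. A minimal standalone sketch of that decision flow (the helper and its return strings are illustrative, not part of this change):

```go
package main

import "fmt"

// chunkAction mirrors the new ProcessResponseBody flow in simplified form.
// dataLen is the number of bytes decodeContentWithBuffering produced for
// this chunk; it is zero while a gzip body is still being buffered.
func chunkAction(dataLen int, streaming, endOfStream bool) string {
	if dataLen == 0 && streaming && !endOfStream {
		// Mirrors createResponseBodyResponse(nil, nil): no mutations,
		// and metrics/costs are left untouched.
		return "buffering: reply with no-op response"
	}
	if !streaming || endOfStream {
		// Only a complete response may drop the content-encoding header.
		return "translate, then removeContentEncodingIfNeeded"
	}
	return "translate intermediate chunk (e.g. uncompressed SSE)"
}

func main() {
	fmt.Println(chunkAction(0, true, false))  // buffering: reply with no-op response
	fmt.Println(chunkAction(42, true, false)) // translate intermediate chunk (e.g. uncompressed SSE)
	fmt.Println(chunkAction(42, true, true))  // translate, then removeContentEncodingIfNeeded
}
```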
diff --git a/internal/extproc/translator/anthropic_gcpanthropic.go b/internal/extproc/translator/anthropic_gcpanthropic.go
index ea0d2ce56f..ee7e764e9a 100644
--- a/internal/extproc/translator/anthropic_gcpanthropic.go
+++ b/internal/extproc/translator/anthropic_gcpanthropic.go
@@ -90,7 +90,7 @@ func (a *anthropicToGCPAnthropicTranslator) ResponseHeaders(_ map[string]string)

 // ResponseBody implements [AnthropicMessagesTranslator.ResponseBody] for Anthropic to GCP Anthropic.
 // This is essentially a passthrough since both use the same Anthropic response format.
-func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, body io.Reader, endOfStream bool) (
+func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, body io.Reader, isStreaming bool) (
 	headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, tokenUsage LLMTokenUsage, responseModel string, err error,
 ) {
 	// Read the response body for both streaming and non-streaming.
@@ -99,8 +99,8 @@ func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, bo
 		return nil, nil, LLMTokenUsage{}, "", fmt.Errorf("failed to read response body: %w", err)
 	}

-	// For streaming chunks, parse SSE format to extract token usage.
-	if !endOfStream {
+	// For streaming requests, parse the SSE format to extract token usage.
+	if isStreaming {
 		// Parse SSE format - split by lines and look for data: lines.
 		for line := range bytes.Lines(bodyBytes) {
 			line = bytes.TrimSpace(line)
@@ -117,14 +117,16 @@ func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, bo
 			switch eventType {
 			case "message_start":
 				// Extract input tokens from message.usage.
+				// Handles a complete response that may contain multiple events.
 				if messageData, ok := eventData["message"].(map[string]any); ok {
 					if usageData, ok := messageData["usage"].(map[string]any); ok {
 						if inputTokens, ok := usageData["input_tokens"].(float64); ok {
-							tokenUsage.InputTokens = uint32(inputTokens) //nolint:gosec
+							// Accumulate input tokens (though typically there is only one message_start per conversation).
+							tokenUsage.InputTokens += uint32(inputTokens) //nolint:gosec
 						}
 						// Some message_start events may include initial output tokens.
 						if outputTokens, ok := usageData["output_tokens"].(float64); ok && outputTokens > 0 {
-							tokenUsage.OutputTokens = uint32(outputTokens) //nolint:gosec
+							tokenUsage.OutputTokens += uint32(outputTokens) //nolint:gosec
 						}
 						tokenUsage.TotalTokens = tokenUsage.InputTokens + tokenUsage.OutputTokens
 					}
@@ -143,18 +145,16 @@ func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, bo
 			}
 		}

-		return nil, &extprocv3.BodyMutation{
-			Mutation: &extprocv3.BodyMutation_Body{Body: bodyBytes},
-		}, tokenUsage, a.requestModel, nil
+		// For streaming responses, only extract token usage; do not modify the body.
+		// Return a nil bodyMutation to pass through the original data (potentially gzipped).
+		return nil, nil, tokenUsage, a.requestModel, nil
 	}

 	// Parse the Anthropic response to extract token usage.
 	var anthropicResp anthropic.Message
 	if err = json.Unmarshal(bodyBytes, &anthropicResp); err != nil {
-		// If we can't parse as Anthropic format, pass through as-is.
-		return nil, &extprocv3.BodyMutation{
-			Mutation: &extprocv3.BodyMutation_Body{Body: bodyBytes},
-		}, LLMTokenUsage{}, a.requestModel, nil
+		// If we can't parse the body as Anthropic format, pass it through as-is without modification.
+		return nil, nil, LLMTokenUsage{}, a.requestModel, nil
 	}

 	// Extract token usage from the response.
@@ -164,12 +164,6 @@ func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, bo
 		TotalTokens: uint32(anthropicResp.Usage.InputTokens + anthropicResp.Usage.OutputTokens), //nolint:gosec
 	}

-	// Pass through the response body unchanged since both input and output are Anthropic format.
-	headerMutation = &extprocv3.HeaderMutation{}
-	setContentLength(headerMutation, bodyBytes)
-	bodyMutation = &extprocv3.BodyMutation{
-		Mutation: &extprocv3.BodyMutation_Body{Body: bodyBytes},
-	}
-
-	return headerMutation, bodyMutation, tokenUsage, a.requestModel, nil
+	// Pass through the response body unchanged; don't create a body mutation, so the original encoding is preserved.
+	return nil, nil, tokenUsage, a.requestModel, nil
 }
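Review note: the streaming branch above only folds usage counters out of the SSE events and leaves the payload untouched. For readers unfamiliar with the Anthropic SSE shape, here is a standalone sketch of the same extraction pattern, using typed structs instead of the translator's `map[string]any` assertions. The `message_delta` handling is outside the visible hunks, so its shape here follows Anthropic's documented events; this is an illustration, not the PR's code:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

// usage holds the two counters the translator reads from Anthropic events.
type usage struct {
	InputTokens  uint32 `json:"input_tokens"`
	OutputTokens uint32 `json:"output_tokens"`
}

// extractUsage scans an SSE payload: message_start carries input tokens
// (and sometimes initial output tokens), message_delta carries output tokens.
func extractUsage(body []byte) (in, out uint32) {
	for line := range bytes.Lines(body) {
		line = bytes.TrimSpace(line)
		data, ok := bytes.CutPrefix(line, []byte("data: "))
		if !ok || bytes.Equal(data, []byte("[DONE]")) {
			continue
		}
		var ev struct {
			Type    string `json:"type"`
			Message *struct {
				Usage usage `json:"usage"`
			} `json:"message"`
			Usage *usage `json:"usage"`
		}
		if err := json.Unmarshal(data, &ev); err != nil {
			continue // Skip unparseable data lines, as the translator does.
		}
		switch ev.Type {
		case "message_start":
			if ev.Message != nil {
				in += ev.Message.Usage.InputTokens
				out += ev.Message.Usage.OutputTokens
			}
		case "message_delta":
			if ev.Usage != nil {
				out += ev.Usage.OutputTokens
			}
		}
	}
	return in, out
}

func main() {
	sse := []byte(`data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}
data: {"type":"message_delta","usage":{"output_tokens":5}}
data: [DONE]
`)
	fmt.Println(extractUsage(sse)) // 10 5
}
```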
diff --git a/internal/extproc/translator/anthropic_gcpanthropic_test.go b/internal/extproc/translator/anthropic_gcpanthropic_test.go
index 8b8281eef4..7591501d84 100644
--- a/internal/extproc/translator/anthropic_gcpanthropic_test.go
+++ b/internal/extproc/translator/anthropic_gcpanthropic_test.go
@@ -569,12 +569,12 @@ func TestAnthropicToGCPAnthropicTranslator_ResponseBody_StreamingTokenUsage(t *t
 			bodyReader := bytes.NewReader([]byte(tt.chunk))
 			respHeaders := map[string]string{"content-type": "application/json"}

-			headerMutation, bodyMutation, tokenUsage, _, err := translator.ResponseBody(respHeaders, bodyReader, tt.endOfStream)
+			// These tests exercise streaming SSE chunks, so use isStreaming=true.
+			headerMutation, bodyMutation, tokenUsage, _, err := translator.ResponseBody(respHeaders, bodyReader, true)
 			require.NoError(t, err)

 			require.Nil(t, headerMutation)
-			require.NotNil(t, bodyMutation)
-			require.Equal(t, tt.expectedBody, string(bodyMutation.GetBody()))
+			require.Nil(t, bodyMutation) // No body mutation, so the original encoding is preserved.
 			require.Equal(t, tt.expectedUsage, tokenUsage)
 		})
 	}
@@ -649,12 +649,12 @@ func TestAnthropicToGCPAnthropicTranslator_ResponseBody_StreamingEdgeCases(t *te
 			bodyReader := bytes.NewReader([]byte(tt.chunk))
 			respHeaders := map[string]string{"content-type": "application/json"}

-			headerMutation, bodyMutation, tokenUsage, _, err := translator.ResponseBody(respHeaders, bodyReader, false)
+			// These are streaming edge-case tests, so use isStreaming=true.
+			headerMutation, bodyMutation, tokenUsage, _, err := translator.ResponseBody(respHeaders, bodyReader, true)
 			require.NoError(t, err)

 			require.Nil(t, headerMutation)
-			require.NotNil(t, bodyMutation)
-			require.Equal(t, tt.chunk, string(bodyMutation.GetBody()))
+			require.Nil(t, bodyMutation) // No body mutation, so the original encoding is preserved.
 			require.Equal(t, tt.expectedUsage, tokenUsage)
 		})
 	}
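Review note: the util.go change that follows exists because a gzip body split across streamed chunks cannot be decompressed piecemeal. A standalone demonstration of the failure mode the buffer works around (the payload and split point are arbitrary, chosen only for illustration):

```go
package main

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"io"
)

func main() {
	// Compress a small payload, then keep only the first half to mimic
	// an in-flight streamed chunk.
	var buf bytes.Buffer
	zw := gzip.NewWriter(&buf)
	_, _ = zw.Write([]byte(`data: {"type":"message_delta","usage":{"output_tokens":5}}`))
	_ = zw.Close()
	full := buf.Bytes()
	truncated := full[:len(full)/2]

	// The 10-byte gzip header is intact, so NewReader succeeds...
	zr, err := gzip.NewReader(bytes.NewReader(truncated))
	fmt.Println("NewReader error:", err) // <nil>

	// ...but draining the body fails (typically io.ErrUnexpectedEOF),
	// because the DEFLATE stream and the trailing CRC/size are missing.
	// This is why decodeContentWithBuffering accumulates chunks and only
	// emits data once a decompression attempt succeeds.
	_, err = io.ReadAll(zr)
	fmt.Println("ReadAll error:", err)
}
```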
diff --git a/internal/extproc/util.go b/internal/extproc/util.go
index 651437c0bf..44fb763fba 100644
--- a/internal/extproc/util.go
+++ b/internal/extproc/util.go
@@ -10,6 +10,7 @@ import (
 	"compress/gzip"
 	"fmt"
 	"io"
+	"log/slog"

 	extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
 )
@@ -42,6 +43,76 @@ func decodeContentIfNeeded(body []byte, contentEncoding string) (contentDecoding
 	}
 }

+// tryDecompressGzipBuffer attempts to decompress a gzip buffer.
+// It returns the decompressed data, or an error if decompression fails.
+func tryDecompressGzipBuffer(buffer []byte) ([]byte, error) {
+	gzipReader, err := gzip.NewReader(bytes.NewReader(buffer))
+	if err != nil {
+		return nil, fmt.Errorf("gzip header: %w", err)
+	}
+	defer gzipReader.Close()
+
+	decompressed, err := io.ReadAll(gzipReader)
+	if err != nil {
+		return nil, fmt.Errorf("gzip decompression: %w", err)
+	}
+
+	return decompressed, nil
+}
+
+// decodeContentWithBuffering decompresses a response body with buffering support for streaming.
+// It accumulates chunks in the provided buffer until complete gzip data is available.
+// It returns a reader for the (potentially decompressed) body and metadata about the encoding.
+func decodeContentWithBuffering(body []byte, contentEncoding string, gzipBuffer *[]byte, endOfStream bool) (contentDecodingResult, error) {
+	switch contentEncoding {
+	case "gzip":
+		// Accumulate chunks in the buffer.
+		*gzipBuffer = append(*gzipBuffer, body...)
+
+		// Try to decompress the accumulated buffer.
+		if len(*gzipBuffer) > 0 {
+			decompressedBody, err := tryDecompressGzipBuffer(*gzipBuffer)
+			if err != nil {
+				// If this is not the end of the stream, keep buffering.
+				if !endOfStream {
+					return contentDecodingResult{
+						reader:    bytes.NewReader(nil), // Empty reader to signal that buffering is in progress.
+						isEncoded: true,
+					}, nil
+				}
+				// If decompression failed at the end of the stream, pass through the buffered data.
+				slog.Info("gzip buffering: decompression failed at end of stream, passing through buffered data",
+					"error", err,
+					"buffer_size", len(*gzipBuffer))
+				result := contentDecodingResult{
+					reader:    bytes.NewReader(*gzipBuffer),
+					isEncoded: true,
+				}
+				*gzipBuffer = nil // Clear the buffer.
+				return result, nil
+			}
+
+			// Successfully decompressed.
+			*gzipBuffer = nil // Clear the buffer.
+			return contentDecodingResult{
+				reader:    bytes.NewReader(decompressedBody),
+				isEncoded: true,
+			}, nil
+		}
+
+		// Empty buffer: return an empty reader.
+		return contentDecodingResult{
+			reader:    bytes.NewReader(nil),
+			isEncoded: true,
+		}, nil
+	default:
+		return contentDecodingResult{
+			reader:    bytes.NewReader(body),
+			isEncoded: false,
+		}, nil
+	}
+}
+
 // removeContentEncodingIfNeeded removes the content-encoding header if the body was modified and was encoded.
 // This is needed when the transformation modifies the body content but the response was originally compressed.
 func removeContentEncodingIfNeeded(headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, isEncoded bool) *extprocv3.HeaderMutation {
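Review note: one subtlety in decodeContentWithBuffering's contract: "buffering in progress" is signalled by an immediately-EOF reader rather than a nil reader or an error. On the caller side, io.ReadAll then yields zero bytes with a nil error, which is exactly the condition ProcessResponseBody keys on before replying with a no-op. A tiny standalone check of that behavior:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
)

func main() {
	// bytes.NewReader(nil) is a valid reader that is already at EOF:
	// io.ReadAll returns an empty slice and a nil error, not io.EOF.
	r := bytes.NewReader(nil)
	data, err := io.ReadAll(r)
	fmt.Printf("len=%d err=%v\n", len(data), err) // len=0 err=<nil>
}
```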
diff --git a/internal/extproc/util_test.go b/internal/extproc/util_test.go
index c63131aaa8..8125f81d01 100644
--- a/internal/extproc/util_test.go
+++ b/internal/extproc/util_test.go
@@ -95,6 +95,165 @@ func TestDecodeContentIfNeeded(t *testing.T) {
 	}
 }

+func TestDecodeContentWithBuffering(t *testing.T) {
+	t.Run("successful buffering and decompression", func(t *testing.T) {
+		// Test the complete gzip buffering scenario.
+		completeSSEResponse := `data: {"type":"message_start","message":{"usage":{"input_tokens":10,"output_tokens":0}}}
+
+data: {"type":"message_delta","usage":{"output_tokens":5}}
+
+data: {"type":"message_delta","usage":{"output_tokens":3}}
+
+data: [DONE]
+
+`
+
+		// Create the gzipped version.
+		var gzipBuf bytes.Buffer
+		gzipWriter := gzip.NewWriter(&gzipBuf)
+		_, err := gzipWriter.Write([]byte(completeSSEResponse))
+		require.NoError(t, err)
+		gzipWriter.Close()
+		gzippedData := gzipBuf.Bytes()
+
+		// Split into chunks to simulate streaming.
+		chunk1 := gzippedData[:len(gzippedData)/3]
+		chunk2 := gzippedData[len(gzippedData)/3 : 2*len(gzippedData)/3]
+		chunk3 := gzippedData[2*len(gzippedData)/3:]
+
+		var buffer []byte
+
+		// Chunk 1: should buffer.
+		res1, err := decodeContentWithBuffering(chunk1, "gzip", &buffer, false)
+		require.NoError(t, err)
+		require.True(t, res1.isEncoded)
+		data1, _ := io.ReadAll(res1.reader)
+		require.Empty(t, data1)          // Should return an empty reader while buffering.
+		require.Equal(t, chunk1, buffer) // Should be in the buffer.
+
+		// Chunk 2: should continue buffering.
+		res2, err := decodeContentWithBuffering(chunk2, "gzip", &buffer, false)
+		require.NoError(t, err)
+		require.True(t, res2.isEncoded)
+		data2, _ := io.ReadAll(res2.reader)
+		require.Empty(t, data2) // Should return an empty reader while buffering.
+		expectedBuffer := make([]byte, 0, len(chunk1)+len(chunk2))
+		expectedBuffer = append(expectedBuffer, chunk1...)
+		expectedBuffer = append(expectedBuffer, chunk2...)
+		require.Equal(t, expectedBuffer, buffer) // Should accumulate in the buffer.
+
+		// Chunk 3 (endOfStream=true): should decompress the complete buffer.
+		res3, err := decodeContentWithBuffering(chunk3, "gzip", &buffer, true)
+		require.NoError(t, err)
+		require.True(t, res3.isEncoded)
+		data3, _ := io.ReadAll(res3.reader)
+		require.Equal(t, []byte(completeSSEResponse), data3) // Should return the decompressed data.
+		require.Empty(t, buffer)                             // Buffer should be cleared.
+	})
+
+	t.Run("invalid gzip header at end of stream", func(t *testing.T) {
+		var buffer []byte
+		invalidGzipData := []byte("not a gzip header")
+
+		// Should pass the data through when endOfStream=true and the gzip header is invalid.
+		res, err := decodeContentWithBuffering(invalidGzipData, "gzip", &buffer, true)
+		require.NoError(t, err)
+		require.True(t, res.isEncoded)
+
+		data, _ := io.ReadAll(res.reader)
+		require.Equal(t, invalidGzipData, data)
+		require.Empty(t, buffer) // Buffer should be cleared.
+	})
+
+	t.Run("decompression fails at end of stream", func(t *testing.T) {
+		var buffer []byte
+		// Create truncated gzip data with a valid header but incomplete content.
+		var gzipBuf bytes.Buffer
+		gzipWriter := gzip.NewWriter(&gzipBuf)
+		_, err := gzipWriter.Write([]byte("test data"))
+		require.NoError(t, err)
+		gzipWriter.Close()
+		truncatedGzip := gzipBuf.Bytes()[:15] // Truncate to make decompression fail.
+
+		// Should pass the data through when endOfStream=true and decompression fails.
+		res, err := decodeContentWithBuffering(truncatedGzip, "gzip", &buffer, true)
+		require.NoError(t, err)
+		require.True(t, res.isEncoded)
+
+		data, _ := io.ReadAll(res.reader)
+		require.Equal(t, truncatedGzip, data)
+		require.Empty(t, buffer) // Buffer should be cleared.
+	})
+
+	t.Run("empty buffer case", func(t *testing.T) {
+		var buffer []byte
+		emptyBody := []byte{}
+
+		// Should handle an empty body gracefully.
+		res, err := decodeContentWithBuffering(emptyBody, "gzip", &buffer, false)
+		require.NoError(t, err)
+		require.True(t, res.isEncoded)
+
+		data, _ := io.ReadAll(res.reader)
+		require.Empty(t, data)   // Should return an empty reader for an empty body.
+		require.Empty(t, buffer) // Buffer should remain empty.
+	})
+
+	t.Run("non-gzip encoding", func(t *testing.T) {
+		var buffer []byte
+		testData := []byte("plain text data")
+
+		// Should pass non-gzip data through unchanged.
+		res, err := decodeContentWithBuffering(testData, "deflate", &buffer, false)
+		require.NoError(t, err)
+		require.False(t, res.isEncoded)
+
+		data, _ := io.ReadAll(res.reader)
+		require.Equal(t, testData, data)
+		require.Empty(t, buffer) // Buffer should remain empty for non-gzip encodings.
+	})
+
+	t.Run("empty encoding", func(t *testing.T) {
+		var buffer []byte
+		testData := []byte("plain text data")
+
+		// Should pass the data through when no encoding is set.
+		res, err := decodeContentWithBuffering(testData, "", &buffer, false)
+		require.NoError(t, err)
+		require.False(t, res.isEncoded)
+
+		data, _ := io.ReadAll(res.reader)
+		require.Equal(t, testData, data)
+		require.Empty(t, buffer) // Buffer should remain empty for non-gzip encodings.
+	})
+
+	t.Run("invalid gzip header with endOfStream=false", func(t *testing.T) {
+		var buffer []byte
+		invalidGzipData := []byte("not a gzip header")
+
+		// Should buffer and return an empty reader when endOfStream=false and the gzip header is invalid.
+		res, err := decodeContentWithBuffering(invalidGzipData, "gzip", &buffer, false)
+		require.NoError(t, err)
+		require.True(t, res.isEncoded)
+
+		data, _ := io.ReadAll(res.reader)
+		require.Empty(t, data) // Should return an empty reader while buffering.
+		require.Equal(t, invalidGzipData, buffer) // Should accumulate in the buffer.
+	})
+}
+
+func TestNonGzipPassthrough(t *testing.T) {
+	// Test that non-gzip data passes through unchanged.
+	testData := []byte(`{"type":"message_start","usage":{"input_tokens":10}}`)
+
+	res, err := decodeContentIfNeeded(testData, "")
+	require.NoError(t, err)
+	require.False(t, res.isEncoded)
+
+	output, _ := io.ReadAll(res.reader)
+	require.Equal(t, testData, output)
+}
+
 func TestRemoveContentEncodingIfNeeded(t *testing.T) {
 	tests := []struct {
 		name string