48 changes: 33 additions & 15 deletions internal/extproc/messages_processor.go
@@ -6,10 +6,12 @@
package extproc

import (
"bytes"
"cmp"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"

corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@@ -147,6 +149,21 @@ type messagesProcessorUpstreamFilter struct {
stream bool
metrics metrics.ChatCompletionMetrics
costs translator.LLMTokenUsage
gzipBuffer []byte
}

// createResponseBodyResponse creates a ProcessingResponse for response body processing.
func createResponseBodyResponse(headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation) *extprocv3.ProcessingResponse {
return &extprocv3.ProcessingResponse{
Response: &extprocv3.ProcessingResponse_ResponseBody{
ResponseBody: &extprocv3.BodyResponse{
Response: &extprocv3.CommonResponse{
HeaderMutation: headerMutation,
BodyMutation: bodyMutation,
},
},
},
}
}

// selectTranslator selects the translator based on the output schema.
@@ -266,34 +283,35 @@ func (c *messagesProcessorUpstreamFilter) ProcessResponseBody(ctx context.Contex
}
}()

// Decompress the body if needed using common utility.
decodingResult, err := decodeContentIfNeeded(body.Body, c.responseEncoding)
// Decompress the response body, with buffering support for streaming.
decodingResult, err := decodeContentWithBuffering(body.Body, c.responseEncoding, &c.gzipBuffer, body.EndOfStream)
if err != nil {
return nil, err
}

// headerMutation, bodyMutation, tokenUsage, err := c.translator.ResponseBody(c.responseHeaders, br, body.EndOfStream).
headerMutation, bodyMutation, tokenUsage, responseModel, err := c.translator.ResponseBody(c.responseHeaders, decodingResult.reader, body.EndOfStream)
// Check if we got decompressed data or are still buffering
data, _ := io.ReadAll(decodingResult.reader)
if len(data) == 0 && c.stream && !body.EndOfStream {
// Still buffering incomplete data: return early with no mutations and skip metrics.
return createResponseBodyResponse(nil, nil), nil
}

// Process the decompressed data
decodingResult.reader = bytes.NewReader(data)
headerMutation, bodyMutation, tokenUsage, responseModel, err := c.translator.ResponseBody(c.responseHeaders, decodingResult.reader, c.stream)
if err != nil {
return nil, fmt.Errorf("failed to transform response: %w", err)
}

c.metrics.SetResponseModel(responseModel)

// Remove the content-encoding header if the original body was encoded but was mutated in the processor.
headerMutation = removeContentEncodingIfNeeded(headerMutation, bodyMutation, decodingResult.isEncoded)

resp := &extprocv3.ProcessingResponse{
Response: &extprocv3.ProcessingResponse_ResponseBody{
ResponseBody: &extprocv3.BodyResponse{
Response: &extprocv3.CommonResponse{
HeaderMutation: headerMutation,
BodyMutation: bodyMutation,
},
},
},
if !c.stream || body.EndOfStream {
headerMutation = removeContentEncodingIfNeeded(headerMutation, bodyMutation, decodingResult.isEncoded)
}

resp := createResponseBodyResponse(headerMutation, bodyMutation)

c.costs.InputTokens += tokenUsage.InputTokens
c.costs.OutputTokens += tokenUsage.OutputTokens
c.costs.TotalTokens += tokenUsage.TotalTokens
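For background on why ProcessResponseBody now accumulates gzipped chunks before translating: a gzip stream split across ext_proc chunks cannot be fully decompressed until the final chunk arrives, which is exactly the failure tryDecompressGzipBuffer retries on. A minimal standalone Go sketch, not part of this PR (the payload is illustrative):

package main

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"io"
)

func main() {
	// Compress a payload in one shot, as an upstream would.
	var buf bytes.Buffer
	zw := gzip.NewWriter(&buf)
	_, _ = zw.Write([]byte(`{"type":"message_start","message":{"usage":{"input_tokens":10}}}`))
	_ = zw.Close()
	full := buf.Bytes()

	// A single ext_proc chunk usually carries only part of the gzip stream.
	partial := full[:len(full)/2]
	zr, err := gzip.NewReader(bytes.NewReader(partial))
	if err == nil {
		_, err = io.ReadAll(zr) // fails: the deflate stream is truncated
	}
	fmt.Println("partial chunk:", err) // unexpected EOF

	// Once every chunk has been buffered, decompression succeeds.
	zr, _ = gzip.NewReader(bytes.NewReader(full))
	out, _ := io.ReadAll(zr)
	fmt.Println("full buffer:", string(out))
}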
34 changes: 14 additions & 20 deletions internal/extproc/translator/anthropic_gcpanthropic.go
@@ -90,7 +90,7 @@ func (a *anthropicToGCPAnthropicTranslator) ResponseHeaders(_ map[string]string)

// ResponseBody implements [AnthropicMessagesTranslator.ResponseBody] for Anthropic to GCP Anthropic.
// This is essentially a passthrough since both use the same Anthropic response format.
func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, body io.Reader, endOfStream bool) (
func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, body io.Reader, isStreaming bool) (
Contributor:

Why rename this? endOfStream is different from isStreaming.

Contributor Author:

Good question, but I do mean isStreaming rather than endOfStream. We basically have two cases:

  1. translation for streaming
  2. translation for a regular request

This boolean tells us which one to use.

Since you raised the question, let me elaborate: previously we translated while the response was still streaming; now that we accumulate chunks, we no longer care about endOfStream. Let me know if that makes sense.

Contributor:

Can you leave a comment so other maintainers are aware of the difference?

headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, tokenUsage LLMTokenUsage, responseModel string, err error,
) {
// Read the response body for both streaming and non-streaming.
@@ -99,8 +99,8 @@ func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, bo
return nil, nil, LLMTokenUsage{}, "", fmt.Errorf("failed to read response body: %w", err)
}

// For streaming chunks, parse SSE format to extract token usage.
if !endOfStream {
// For streaming requests, parse SSE format to extract token usage.
if isStreaming {
// Parse SSE format - split by lines and look for data: lines.
for line := range bytes.Lines(bodyBytes) {
line = bytes.TrimSpace(line)
@@ -117,14 +117,16 @@ func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, bo
switch eventType {
case "message_start":
// Extract input tokens from message.usage.
// Now handles a complete response with potentially multiple events.
if messageData, ok := eventData["message"].(map[string]any); ok {
if usageData, ok := messageData["usage"].(map[string]any); ok {
if inputTokens, ok := usageData["input_tokens"].(float64); ok {
tokenUsage.InputTokens = uint32(inputTokens) //nolint:gosec
// Accumulate input tokens (though typically only one message_start per conversation)
tokenUsage.InputTokens += uint32(inputTokens) //nolint:gosec
}
// Some message_start events may include initial output tokens.
if outputTokens, ok := usageData["output_tokens"].(float64); ok && outputTokens > 0 {
tokenUsage.OutputTokens = uint32(outputTokens) //nolint:gosec
tokenUsage.OutputTokens += uint32(outputTokens) //nolint:gosec
}
tokenUsage.TotalTokens = tokenUsage.InputTokens + tokenUsage.OutputTokens
}
@@ -143,18 +145,16 @@ func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, bo
}
}

return nil, &extprocv3.BodyMutation{
Mutation: &extprocv3.BodyMutation_Body{Body: bodyBytes},
}, tokenUsage, a.requestModel, nil
// For streaming responses, we only extract token usage and do not modify the body.
// Return a nil bodyMutation so the original (potentially gzipped) data passes through.
return nil, nil, tokenUsage, a.requestModel, nil
}

// Parse the Anthropic response to extract token usage.
var anthropicResp anthropic.Message
if err = json.Unmarshal(bodyBytes, &anthropicResp); err != nil {
// If we can't parse as Anthropic format, pass through as-is.
return nil, &extprocv3.BodyMutation{
Mutation: &extprocv3.BodyMutation_Body{Body: bodyBytes},
}, LLMTokenUsage{}, a.requestModel, nil
// If we can't parse as Anthropic format, pass through as-is without modification.
return nil, nil, LLMTokenUsage{}, a.requestModel, nil
}

// Extract token usage from the response.
@@ -164,12 +164,6 @@ func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, bo
TotalTokens: uint32(anthropicResp.Usage.InputTokens + anthropicResp.Usage.OutputTokens), //nolint:gosec
}

// Pass through the response body unchanged since both input and output are Anthropic format.
headerMutation = &extprocv3.HeaderMutation{}
setContentLength(headerMutation, bodyBytes)
bodyMutation = &extprocv3.BodyMutation{
Mutation: &extprocv3.BodyMutation_Body{Body: bodyBytes},
}

return headerMutation, bodyMutation, tokenUsage, a.requestModel, nil
// Pass through the response body unchanged - don't create body mutation to preserve original encoding.
return nil, nil, tokenUsage, a.requestModel, nil
}
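Regarding the review thread above asking for a code comment on isStreaming: one possible doc comment to sit above the new signature, offered as a suggestion rather than as part of this diff (the exact wording is an assumption, grounded in the author's explanation):

// ResponseBody implements [AnthropicMessagesTranslator.ResponseBody] for
// Anthropic to GCP Anthropic.
//
// isStreaming selects between the two translation paths: true means body
// holds SSE events from a streaming response and token usage is extracted
// from them; false means body is a single non-streaming JSON response.
// It deliberately replaces the old endOfStream flag: the processor now
// accumulates compressed chunks and translates once, so per-chunk
// end-of-stream state is no longer meaningful at this layer.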
12 changes: 6 additions & 6 deletions internal/extproc/translator/anthropic_gcpanthropic_test.go
@@ -569,12 +569,12 @@ func TestAnthropicToGCPAnthropicTranslator_ResponseBody_StreamingTokenUsage(t *t
bodyReader := bytes.NewReader([]byte(tt.chunk))
respHeaders := map[string]string{"content-type": "application/json"}

headerMutation, bodyMutation, tokenUsage, _, err := translator.ResponseBody(respHeaders, bodyReader, tt.endOfStream)
// These tests are for streaming SSE chunks, so use isStreaming=true
headerMutation, bodyMutation, tokenUsage, _, err := translator.ResponseBody(respHeaders, bodyReader, true)

require.NoError(t, err)
require.Nil(t, headerMutation)
require.NotNil(t, bodyMutation)
require.Equal(t, tt.expectedBody, string(bodyMutation.GetBody()))
require.Nil(t, bodyMutation) // No body mutation to preserve original encoding
require.Equal(t, tt.expectedUsage, tokenUsage)
})
}
@@ -649,12 +649,12 @@ func TestAnthropicToGCPAnthropicTranslator_ResponseBody_StreamingEdgeCases(t *te
bodyReader := bytes.NewReader([]byte(tt.chunk))
respHeaders := map[string]string{"content-type": "application/json"}

headerMutation, bodyMutation, tokenUsage, _, err := translator.ResponseBody(respHeaders, bodyReader, false)
// These are streaming edge case tests, so use isStreaming=true
headerMutation, bodyMutation, tokenUsage, _, err := translator.ResponseBody(respHeaders, bodyReader, true)

require.NoError(t, err)
require.Nil(t, headerMutation)
require.NotNil(t, bodyMutation)
require.Equal(t, tt.chunk, string(bodyMutation.GetBody()))
require.Nil(t, bodyMutation) // No body mutation to preserve original encoding
require.Equal(t, tt.expectedUsage, tokenUsage)
})
}
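For context on what these streaming tests feed the translator: an illustrative accumulated SSE body in Anthropic's Messages event format (the field values here are made up; message_delta usage events are part of the upstream format, though only message_start handling is visible in the hunk above). Token usage is extracted from the data: lines, and the body itself now passes through unmutated.

const exampleStreamingBody = `event: message_start
data: {"type":"message_start","message":{"usage":{"input_tokens":25,"output_tokens":1}}}

event: message_delta
data: {"type":"message_delta","usage":{"output_tokens":15}}

event: message_stop
data: {"type":"message_stop"}
`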
71 changes: 71 additions & 0 deletions internal/extproc/util.go
@@ -10,6 +10,7 @@ import (
"compress/gzip"
"fmt"
"io"
"log/slog"

extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
)
@@ -42,6 +43,76 @@ func decodeContentIfNeeded(body []byte, contentEncoding string) (contentDecoding
}
}

// tryDecompressGzipBuffer attempts to decompress a gzip buffer.
// Returns the decompressed data or an error if decompression fails.
func tryDecompressGzipBuffer(buffer []byte) ([]byte, error) {
gzipReader, err := gzip.NewReader(bytes.NewReader(buffer))
if err != nil {
return nil, fmt.Errorf("gzip header: %w", err)
}
defer gzipReader.Close()

decompressed, err := io.ReadAll(gzipReader)
if err != nil {
return nil, fmt.Errorf("gzip decompression: %w", err)
}

return decompressed, nil
}

// decodeContentWithBuffering decompresses response body with buffering support for streaming.
// Accumulates chunks in the provided buffer until complete gzip data is available.
// Returns a reader for the (potentially decompressed) body and metadata about the encoding.
func decodeContentWithBuffering(body []byte, contentEncoding string, gzipBuffer *[]byte, endOfStream bool) (contentDecodingResult, error) {
switch contentEncoding {
case "gzip":
// Accumulate chunks in buffer
*gzipBuffer = append(*gzipBuffer, body...)

// Try to decompress the accumulated buffer
if len(*gzipBuffer) > 0 {
decompressedBody, err := tryDecompressGzipBuffer(*gzipBuffer)
if err != nil {
// If it's not endOfStream, keep buffering
if !endOfStream {
return contentDecodingResult{
reader: bytes.NewReader(nil), // Empty reader to signal buffering in progress
isEncoded: true,
}, nil
}
// If endOfStream and decompression failed, pass through buffered data
slog.Info("gzip buffering: decompression failed at end of stream, passing through buffered data",
"error", err,
"buffer_size", len(*gzipBuffer))
result := contentDecodingResult{
reader: bytes.NewReader(*gzipBuffer),
isEncoded: true,
}
*gzipBuffer = nil // Clear buffer
return result, nil
}

// Successfully decompressed!
*gzipBuffer = nil // Clear buffer
return contentDecodingResult{
reader: bytes.NewReader(decompressedBody),
isEncoded: true,
}, nil
}

// Empty buffer, return empty
return contentDecodingResult{
reader: bytes.NewReader(nil), // Empty reader for empty buffer
isEncoded: true,
}, nil
default:
return contentDecodingResult{
reader: bytes.NewReader(body),
isEncoded: false,
}, nil
}
}

// removeContentEncodingIfNeeded removes the content-encoding header if the body was modified and was encoded.
// This is needed when the transformation modifies the body content but the response was originally compressed.
func removeContentEncodingIfNeeded(headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, isEncoded bool) *extprocv3.HeaderMutation {
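A hypothetical companion test, not in this PR, sketching the buffering contract of decodeContentWithBuffering across a split gzip stream (in-package test for extproc, using testify's require as the existing tests do):

package extproc

import (
	"bytes"
	"compress/gzip"
	"io"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestDecodeContentWithBuffering_SplitGzip(t *testing.T) {
	var buf bytes.Buffer
	zw := gzip.NewWriter(&buf)
	_, err := zw.Write([]byte(`{"ok":true}`))
	require.NoError(t, err)
	require.NoError(t, zw.Close())
	full := buf.Bytes()

	var gzipBuffer []byte

	// First half: an incomplete gzip stream, so the helper keeps buffering
	// and hands back an empty reader to signal "still accumulating".
	res, err := decodeContentWithBuffering(full[:len(full)/2], "gzip", &gzipBuffer, false)
	require.NoError(t, err)
	require.True(t, res.isEncoded)
	data, err := io.ReadAll(res.reader)
	require.NoError(t, err)
	require.Empty(t, data)

	// Second half: the accumulated buffer now decompresses, and the buffer is cleared.
	res, err = decodeContentWithBuffering(full[len(full)/2:], "gzip", &gzipBuffer, true)
	require.NoError(t, err)
	data, err = io.ReadAll(res.reader)
	require.NoError(t, err)
	require.JSONEq(t, `{"ok":true}`, string(data))
	require.Nil(t, gzipBuffer)
}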