60 changes: 47 additions & 13 deletions internal/extproc/messages_processor.go
@@ -6,10 +6,12 @@
package extproc

import (
"bytes"
"cmp"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"

corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@@ -147,6 +149,7 @@ type messagesProcessorUpstreamFilter struct {
stream bool
metrics metrics.ChatCompletionMetrics
costs translator.LLMTokenUsage
gzipBuffer []byte
}

// selectTranslator selects the translator based on the output schema.
@@ -266,22 +269,53 @@ func (c *messagesProcessorUpstreamFilter) ProcessResponseBody(ctx context.Contex
}
}()

// Decompress the body if needed using common utility.
decodingResult, err := decodeContentIfNeeded(body.Body, c.responseEncoding)
if err != nil {
return nil, err
}
var headerMutation *extprocv3.HeaderMutation
var bodyMutation *extprocv3.BodyMutation
var tokenUsage translator.LLMTokenUsage
var responseModel internalapi.ResponseModel

// headerMutation, bodyMutation, tokenUsage, err := c.translator.ResponseBody(c.responseHeaders, br, body.EndOfStream).
headerMutation, bodyMutation, tokenUsage, responseModel, err := c.translator.ResponseBody(c.responseHeaders, decodingResult.reader, body.EndOfStream)
if err != nil {
return nil, fmt.Errorf("failed to transform response: %w", err)
}
if c.stream && !body.EndOfStream {
// For streaming intermediate chunks: buffer data and check if decompression succeeded
decodingResult, err := decodeContentWithBuffering(body.Body, c.responseEncoding, &c.gzipBuffer, body.EndOfStream)
if err != nil {
return nil, err
}

// Check if we got decompressed data (successful buffering completion)
data, _ := io.ReadAll(decodingResult.reader)
if len(data) > 0 {
// Decompression succeeded! Process the complete response
decodingResult.reader = bytes.NewReader(data)
headerMutation, bodyMutation, tokenUsage, responseModel, err = c.translator.ResponseBody(c.responseHeaders, decodingResult.reader, c.stream)
if err != nil {
return nil, fmt.Errorf("failed to transform response: %w", err)
}
c.metrics.SetResponseModel(responseModel)
} else {
// Still buffering incomplete data - pass through with no mutations
headerMutation, bodyMutation = nil, nil
tokenUsage = translator.LLMTokenUsage{}
}
} else {
// For non-streaming OR final streaming chunk: decompress and translate
decodingResult, err := decodeContentWithBuffering(body.Body, c.responseEncoding, &c.gzipBuffer, body.EndOfStream)
if err != nil {
return nil, err
}

c.metrics.SetResponseModel(responseModel)
// Process the decompressed data
data, _ := io.ReadAll(decodingResult.reader)
decodingResult.reader = bytes.NewReader(data)
headerMutation, bodyMutation, tokenUsage, responseModel, err = c.translator.ResponseBody(c.responseHeaders, decodingResult.reader, c.stream)
if err != nil {
return nil, fmt.Errorf("failed to transform response: %w", err)
}

c.metrics.SetResponseModel(responseModel)

// Remove the content-encoding header if the original body was encoded but mutated in the processor.
headerMutation = removeContentEncodingIfNeeded(headerMutation, bodyMutation, decodingResult.isEncoded)
// Remove the content-encoding header if the original body was encoded but mutated in the processor.
headerMutation = removeContentEncodingIfNeeded(headerMutation, bodyMutation, decodingResult.isEncoded)
}

resp := &extprocv3.ProcessingResponse{
Response: &extprocv3.ProcessingResponse_ResponseBody{
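To make the new contract concrete, here is a minimal sketch of how a caller drives the buffering loop (hypothetical helper, assuming it lives inside the extproc package next to decodeContentWithBuffering): intermediate chunks come back as an empty reader until the accumulated gzip stream decompresses cleanly or the stream ends.

```go
package extproc

import "io"

// exampleBufferedDecode illustrates the buffering contract: feed gzip chunks
// until decodeContentWithBuffering returns a non-empty reader, which signals
// that decompression (or the end-of-stream passthrough) produced usable bytes.
func exampleBufferedDecode(chunks [][]byte) ([]byte, error) {
	var gzipBuffer []byte
	for i, chunk := range chunks {
		endOfStream := i == len(chunks)-1
		result, err := decodeContentWithBuffering(chunk, "gzip", &gzipBuffer, endOfStream)
		if err != nil {
			return nil, err
		}
		data, err := io.ReadAll(result.reader)
		if err != nil {
			return nil, err
		}
		if len(data) > 0 {
			// Decompressed body, or raw buffered passthrough at end of stream.
			return data, nil
		}
		// Empty reader: the gzip stream is still incomplete, keep feeding chunks.
	}
	return nil, nil
}
```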
34 changes: 14 additions & 20 deletions internal/extproc/translator/anthropic_gcpanthropic.go
@@ -90,7 +90,7 @@ func (a *anthropicToGCPAnthropicTranslator) ResponseHeaders(_ map[string]string)

// ResponseBody implements [AnthropicMessagesTranslator.ResponseBody] for Anthropic to GCP Anthropic.
// This is essentially a passthrough since both use the same Anthropic response format.
func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, body io.Reader, endOfStream bool) (
func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, body io.Reader, isStreaming bool) (
Contributor:
Why rename this? endOfStream is different from isStreaming.

Contributor Author:
Good question, but I do mean isStreaming rather than endOfStream. We basically have two cases:

  1. translation for streaming
  2. translation for a regular request

This boolean tells us which one to use.

Since you raised the question, let me elaborate: previously we did the translation while streaming; now, because we accumulate the chunks, we no longer care about endOfStream. Let me know if that makes sense.

Contributor:
Can you leave a comment so other maintainers are aware of the difference?

headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, tokenUsage LLMTokenUsage, responseModel string, err error,
) {
// Read the response body for both streaming and non-streaming.
@@ -99,8 +99,8 @@ func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, bo
return nil, nil, LLMTokenUsage{}, "", fmt.Errorf("failed to read response body: %w", err)
}

// For streaming chunks, parse SSE format to extract token usage.
if !endOfStream {
// For streaming requests, parse SSE format to extract token usage.
if isStreaming {
// Parse SSE format - split by lines and look for data: lines.
for line := range bytes.Lines(bodyBytes) {
line = bytes.TrimSpace(line)
Expand All @@ -117,14 +117,16 @@ func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, bo
switch eventType {
case "message_start":
// Extract input tokens from message.usage.
// Now handles complete response with potentially multiple events.
if messageData, ok := eventData["message"].(map[string]any); ok {
if usageData, ok := messageData["usage"].(map[string]any); ok {
if inputTokens, ok := usageData["input_tokens"].(float64); ok {
tokenUsage.InputTokens = uint32(inputTokens) //nolint:gosec
// Accumulate input tokens (though typically only one message_start per conversation)
tokenUsage.InputTokens += uint32(inputTokens) //nolint:gosec
}
// Some message_start events may include initial output tokens.
if outputTokens, ok := usageData["output_tokens"].(float64); ok && outputTokens > 0 {
tokenUsage.OutputTokens = uint32(outputTokens) //nolint:gosec
tokenUsage.OutputTokens += uint32(outputTokens) //nolint:gosec
}
tokenUsage.TotalTokens = tokenUsage.InputTokens + tokenUsage.OutputTokens
}
@@ -143,18 +145,16 @@ func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, bo
}
}

return nil, &extprocv3.BodyMutation{
Mutation: &extprocv3.BodyMutation_Body{Body: bodyBytes},
}, tokenUsage, a.requestModel, nil
// For streaming responses, we only extract token usage, don't modify the body
// Return nil bodyMutation to pass through original data (potentially gzipped)
return nil, nil, tokenUsage, a.requestModel, nil
}

// Parse the Anthropic response to extract token usage.
var anthropicResp anthropic.Message
if err = json.Unmarshal(bodyBytes, &anthropicResp); err != nil {
// If we can't parse as Anthropic format, pass through as-is.
return nil, &extprocv3.BodyMutation{
Mutation: &extprocv3.BodyMutation_Body{Body: bodyBytes},
}, LLMTokenUsage{}, a.requestModel, nil
// If we can't parse as Anthropic format, pass through as-is without modification.
return nil, nil, LLMTokenUsage{}, a.requestModel, nil
}

// Extract token usage from the response.
Expand All @@ -164,12 +164,6 @@ func (a *anthropicToGCPAnthropicTranslator) ResponseBody(_ map[string]string, bo
TotalTokens: uint32(anthropicResp.Usage.InputTokens + anthropicResp.Usage.OutputTokens), //nolint:gosec
}

// Pass through the response body unchanged since both input and output are Anthropic format.
headerMutation = &extprocv3.HeaderMutation{}
setContentLength(headerMutation, bodyBytes)
bodyMutation = &extprocv3.BodyMutation{
Mutation: &extprocv3.BodyMutation_Body{Body: bodyBytes},
}

return headerMutation, bodyMutation, tokenUsage, a.requestModel, nil
// Pass through the response body unchanged - don't create body mutation to preserve original encoding.
return nil, nil, tokenUsage, a.requestModel, nil
}
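A hedged sketch of the doc comment the reviewer asked for above (suggested wording only, not committed in this PR):

```go
// ResponseBody implements [AnthropicMessagesTranslator.ResponseBody] for
// Anthropic to GCP Anthropic.
//
// Note: the final parameter is isStreaming, not endOfStream. It selects one
// of two translation modes: when true, body holds accumulated SSE events and
// only token usage is extracted; when false, body is a single JSON response.
// Because chunks are buffered upstream, the translator no longer needs to
// know where the stream ends.
```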
12 changes: 6 additions & 6 deletions internal/extproc/translator/anthropic_gcpanthropic_test.go
@@ -569,12 +569,12 @@ func TestAnthropicToGCPAnthropicTranslator_ResponseBody_StreamingTokenUsage(t *t
bodyReader := bytes.NewReader([]byte(tt.chunk))
respHeaders := map[string]string{"content-type": "application/json"}

headerMutation, bodyMutation, tokenUsage, _, err := translator.ResponseBody(respHeaders, bodyReader, tt.endOfStream)
// These tests are for streaming SSE chunks, so use isStreaming=true
headerMutation, bodyMutation, tokenUsage, _, err := translator.ResponseBody(respHeaders, bodyReader, true)

require.NoError(t, err)
require.Nil(t, headerMutation)
require.NotNil(t, bodyMutation)
require.Equal(t, tt.expectedBody, string(bodyMutation.GetBody()))
require.Nil(t, bodyMutation) // No body mutation to preserve original encoding
require.Equal(t, tt.expectedUsage, tokenUsage)
})
}
@@ -649,12 +649,12 @@ func TestAnthropicToGCPAnthropicTranslator_ResponseBody_StreamingEdgeCases(t *te
bodyReader := bytes.NewReader([]byte(tt.chunk))
respHeaders := map[string]string{"content-type": "application/json"}

headerMutation, bodyMutation, tokenUsage, _, err := translator.ResponseBody(respHeaders, bodyReader, false)
// These are streaming edge case tests, so use isStreaming=true
headerMutation, bodyMutation, tokenUsage, _, err := translator.ResponseBody(respHeaders, bodyReader, true)

require.NoError(t, err)
require.Nil(t, headerMutation)
require.NotNil(t, bodyMutation)
require.Equal(t, tt.chunk, string(bodyMutation.GetBody()))
require.Nil(t, bodyMutation) // No body mutation to preserve original encoding
require.Equal(t, tt.expectedUsage, tokenUsage)
})
}
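For context, the chunks these tests feed are Anthropic-style SSE events; a representative message_start event (field values illustrative, not taken from the tests) looks like:

```text
event: message_start
data: {"type":"message_start","message":{"usage":{"input_tokens":25,"output_tokens":1}}}
```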
77 changes: 77 additions & 0 deletions internal/extproc/util.go
@@ -10,6 +10,7 @@ import (
"compress/gzip"
"fmt"
"io"
"log/slog"

extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
)
@@ -42,6 +43,82 @@ func decodeContentIfNeeded(body []byte, contentEncoding string) (contentDecoding
}
}

// decodeContentWithBuffering decompresses response body with buffering support for streaming.
// Accumulates chunks in the provided buffer until complete gzip data is available.
// Returns a reader for the (potentially decompressed) body and metadata about the encoding.
func decodeContentWithBuffering(body []byte, contentEncoding string, gzipBuffer *[]byte, endOfStream bool) (contentDecodingResult, error) {
switch contentEncoding {
case "gzip":

// Accumulate chunks in buffer
*gzipBuffer = append(*gzipBuffer, body...)

// Try to decompress the accumulated buffer
if len(*gzipBuffer) > 0 {
gzipReader, err := gzip.NewReader(bytes.NewReader(*gzipBuffer))
if err != nil {
Contributor:
Do we want to check for a specific error like "unexpected EOF"?

Contributor Author:
Not sure; I'm not that familiar with gzip errors, so I don't know what the other errors are or whether we can keep buffering through them.

// If it's not endOfStream, keep buffering
if !endOfStream {
return contentDecodingResult{
reader: bytes.NewReader(nil), // Empty reader to signal buffering in progress
isEncoded: true,
}, nil
}
// If endOfStream and still can't read, pass through buffered data
slog.Info("gzip buffering: invalid header at end of stream, passing through buffered data",
"error", err,
"buffer_size", len(*gzipBuffer))
result := contentDecodingResult{
reader: bytes.NewReader(*gzipBuffer),
isEncoded: true,
}
*gzipBuffer = nil // Clear buffer
return result, nil
}
defer gzipReader.Close()

decompressedBody, err := io.ReadAll(gzipReader)
if err != nil {
// If it's not endOfStream, keep buffering
if !endOfStream {
return contentDecodingResult{
reader: bytes.NewReader(nil), // Empty reader to signal buffering in progress
isEncoded: true,
}, nil
}
// If endOfStream and decompression failed, pass through buffered data
slog.Info("gzip buffering: decompression failed at end of stream, passing through buffered data",
"error", err,
"buffer_size", len(*gzipBuffer))
result := contentDecodingResult{
reader: bytes.NewReader(*gzipBuffer),
isEncoded: true,
}
*gzipBuffer = nil // Clear buffer
return result, nil
}

// Successfully decompressed!
*gzipBuffer = nil // Clear buffer
return contentDecodingResult{
reader: bytes.NewReader(decompressedBody),
isEncoded: true,
}, nil
}

// Empty buffer, return empty
return contentDecodingResult{
reader: bytes.NewReader(nil), // Empty reader for empty buffer
isEncoded: true,
}, nil
default:
return contentDecodingResult{
reader: bytes.NewReader(body),
isEncoded: false,
}, nil
}
}

// removeContentEncodingIfNeeded removes the content-encoding header if the body was modified and was encoded.
// This is needed when the transformation modifies the body content but the response was originally compressed.
func removeContentEncodingIfNeeded(headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, isEncoded bool) *extprocv3.HeaderMutation {
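On the open question above: the standard library does export error sentinels that could let the buffering path distinguish a truncated stream from corrupt data. A minimal sketch of one possible refinement (hypothetical helper, not part of this PR):

```go
package extproc

import (
	"compress/gzip"
	"errors"
	"io"
)

// isRecoverableGzipError reports whether a decompression error could be
// resolved by buffering more bytes, versus data that will never decompress.
func isRecoverableGzipError(err error) bool {
	switch {
	case errors.Is(err, io.ErrUnexpectedEOF):
		// The stream is valid so far but truncated; later chunks may complete it.
		return true
	case errors.Is(err, gzip.ErrHeader), errors.Is(err, gzip.ErrChecksum):
		// The data is malformed; no amount of buffering will fix it.
		return false
	default:
		// Unknown errors (e.g. corrupt deflate data) are treated as unrecoverable.
		return false
	}
}
```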