Skip to content

Commit 4e7d496

Browse files
authored
feat: parse out thought summary from gemini models' response when include thought is true (#1521)
**Description** When `includeThought` is true, Gemini also generates a summary of its thinking process. We need to parse this data out and surface it to users separately; otherwise, the thought process would be returned to users mixed together with the output. Depends on/replaces #1461 --------- Signed-off-by: yxia216 <[email protected]>
1 parent 96a2fdd commit 4e7d496

File tree

7 files changed

+185
-53
lines changed

7 files changed

+185
-53
lines changed

internal/apischema/openai/openai.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,7 +1444,7 @@ type ChatCompletionResponseChunkChoiceDelta struct {
14441444
Role string `json:"role,omitempty"`
14451445
ToolCalls []ChatCompletionChunkChoiceDeltaToolCall `json:"tool_calls,omitempty"`
14461446
Annotations *[]Annotation `json:"annotations,omitempty"`
1447-
ReasoningContent *AWSBedrockStreamReasoningContent `json:"reasoning_content,omitempty"`
1447+
ReasoningContent *StreamReasoningContent `json:"reasoning_content,omitempty"`
14481448
}
14491449

14501450
// Error is described in the OpenAI API documentation
@@ -1662,7 +1662,7 @@ func (r *ReasoningContentUnion) UnmarshalJSON(data []byte) error {
16621662
return nil
16631663
}
16641664

1665-
var content *AWSBedrockReasoningContent
1665+
var content *ReasoningContent
16661666
err = json.Unmarshal(data, &content)
16671667
if err == nil {
16681668
r.Value = content
@@ -1675,19 +1675,20 @@ func (r ReasoningContentUnion) MarshalJSON() ([]byte, error) {
16751675
if stringContent, ok := r.Value.(string); ok {
16761676
return json.Marshal(stringContent)
16771677
}
1678-
if reasoningContent, ok := r.Value.(*AWSBedrockReasoningContent); ok {
1678+
if reasoningContent, ok := r.Value.(*ReasoningContent); ok {
16791679
return json.Marshal(reasoningContent)
16801680
}
16811681

16821682
return nil, errors.New("no reasoning content to marshal")
16831683
}
16841684

1685-
type AWSBedrockReasoningContent struct {
1685+
// ReasoningContent is used on both aws bedrock and gemini's reasoning
1686+
type ReasoningContent struct {
16861687
// See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html for more information.
16871688
ReasoningContent *awsbedrock.ReasoningContentBlock `json:"reasoningContent,omitzero"`
16881689
}
16891690

1690-
type AWSBedrockStreamReasoningContent struct {
1691+
type StreamReasoningContent struct {
16911692
Text string `json:"text,omitzero"`
16921693
Signature string `json:"signature,omitzero"`
16931694
RedactedContent []byte `json:"redactedContent,omitzero"`

internal/translator/gemini_helper.go

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
openaisdk "github.com/openai/openai-go/v2"
1919
"google.golang.org/genai"
2020

21+
"github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock"
2122
"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
2223
"github.com/envoyproxy/ai-gateway/internal/internalapi"
2324
)
@@ -640,9 +641,22 @@ func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate, responseMode
640641
message := openai.ChatCompletionResponseChoiceMessage{
641642
Role: openai.ChatMessageRoleAssistant,
642643
}
643-
// Extract text from parts.
644-
content := extractTextFromGeminiParts(candidate.Content.Parts, responseMode)
645-
message.Content = &content
644+
// Extract thought summary and text from parts.
645+
thoughtSummary, content := extractTextAndThoughtSummaryFromGeminiParts(candidate.Content.Parts, responseMode)
646+
if thoughtSummary != "" {
647+
message.ReasoningContent = &openai.ReasoningContentUnion{
648+
Value: &openai.ReasoningContent{
649+
ReasoningContent: &awsbedrock.ReasoningContentBlock{
650+
ReasoningText: &awsbedrock.ReasoningTextBlock{
651+
Text: thoughtSummary,
652+
},
653+
},
654+
},
655+
}
656+
}
657+
if content != "" {
658+
message.Content = &content
659+
}
646660

647661
// Extract tool calls if any.
648662
toolCalls, err = extractToolCallsFromGeminiParts(toolCalls, candidate.Content.Parts)
@@ -657,6 +671,7 @@ func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate, responseMode
657671
}
658672

659673
choice.Message = message
674+
660675
}
661676

662677
if candidate.SafetyRatings != nil {
@@ -704,23 +719,28 @@ func geminiFinishReasonToOpenAI[T toolCallSlice](reason genai.FinishReason, tool
704719
}
705720
}
706721

707-
// extractTextFromGeminiParts extracts text from Gemini parts.
708-
func extractTextFromGeminiParts(parts []*genai.Part, responseMode geminiResponseMode) string {
709-
var text string
722+
// extractTextAndThoughtSummaryFromGeminiParts extracts thought summary and text from Gemini parts.
723+
func extractTextAndThoughtSummaryFromGeminiParts(parts []*genai.Part, responseMode geminiResponseMode) (string, string) {
724+
text := ""
725+
thoughtSummary := ""
710726
for _, part := range parts {
711727
if part != nil && part.Text != "" {
712-
if responseMode == responseModeRegex {
713-
// GCP doesn't natively support REGEX response modes, so we instead express them as json schema.
714-
// This causes the response to be wrapped in double-quotes.
715-
// E.g. `"positive"` (the double-quotes at the start and end are unwanted)
716-
// Here we remove the wrapping double-quotes.
717-
part.Text = strings.TrimPrefix(part.Text, "\"")
718-
part.Text = strings.TrimSuffix(part.Text, "\"")
728+
if part.Thought {
729+
thoughtSummary += part.Text
730+
} else {
731+
if responseMode == responseModeRegex {
732+
// GCP doesn't natively support REGEX response modes, so we instead express them as json schema.
733+
// This causes the response to be wrapped in double-quotes.
734+
// E.g. `"positive"` (the double-quotes at the start and end are unwanted)
735+
// Here we remove the wrapping double-quotes.
736+
part.Text = strings.TrimPrefix(part.Text, "\"")
737+
part.Text = strings.TrimSuffix(part.Text, "\"")
738+
}
739+
text += part.Text
719740
}
720-
text += part.Text
721741
}
722742
}
723-
return text
743+
return thoughtSummary, text
724744
}
725745

726746
// extractToolCallsFromGeminiParts extracts tool calls from Gemini parts.

internal/translator/gemini_helper_test.go

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1829,33 +1829,37 @@ func TestGeminiFinishReasonToOpenAI(t *testing.T) {
18291829
}
18301830
}
18311831

1832-
func TestExtractTextFromGeminiParts(t *testing.T) {
1832+
func TestExtractTextAndThoughtSummaryFromGeminiParts(t *testing.T) {
18331833
tests := []struct {
1834-
name string
1835-
parts []*genai.Part
1836-
responseMode geminiResponseMode
1837-
expected string
1834+
name string
1835+
parts []*genai.Part
1836+
responseMode geminiResponseMode
1837+
expectedThoughtSummary string
1838+
expectedText string
18381839
}{
18391840
{
1840-
name: "nil parts",
1841-
parts: nil,
1842-
responseMode: responseModeNone,
1843-
expected: "",
1841+
name: "nil parts",
1842+
parts: nil,
1843+
responseMode: responseModeNone,
1844+
expectedThoughtSummary: "",
1845+
expectedText: "",
18441846
},
18451847
{
1846-
name: "empty parts",
1847-
parts: []*genai.Part{},
1848-
responseMode: responseModeNone,
1849-
expected: "",
1848+
name: "empty parts",
1849+
parts: []*genai.Part{},
1850+
responseMode: responseModeNone,
1851+
expectedThoughtSummary: "",
1852+
expectedText: "",
18501853
},
18511854
{
18521855
name: "multiple text parts without regex mode",
18531856
parts: []*genai.Part{
18541857
{Text: "Hello, "},
18551858
{Text: "world!"},
18561859
},
1857-
responseMode: responseModeJSON,
1858-
expected: "Hello, world!",
1860+
responseMode: responseModeJSON,
1861+
expectedThoughtSummary: "",
1862+
expectedText: "Hello, world!",
18591863
},
18601864
{
18611865
name: "regex mode with mixed quoted and unquoted text",
@@ -1864,40 +1868,56 @@ func TestExtractTextFromGeminiParts(t *testing.T) {
18641868
{Text: `unquoted`},
18651869
{Text: `"negative"`},
18661870
},
1867-
responseMode: responseModeRegex,
1868-
expected: "positiveunquotednegative",
1871+
responseMode: responseModeRegex,
1872+
expectedThoughtSummary: "",
1873+
expectedText: "positiveunquotednegative",
18691874
},
18701875
{
18711876
name: "regex mode with only double-quoted first and last words",
18721877
parts: []*genai.Part{
18731878
{Text: "\"\"ERROR\" Unable to connect to database \"DatabaseModule\"\""},
18741879
},
1875-
responseMode: responseModeRegex,
1876-
expected: "\"ERROR\" Unable to connect to database \"DatabaseModule\"",
1880+
responseMode: responseModeRegex,
1881+
expectedThoughtSummary: "",
1882+
expectedText: "\"ERROR\" Unable to connect to database \"DatabaseModule\"",
18771883
},
18781884
{
18791885
name: "non-regex mode with double-quoted text (should not remove quotes)",
18801886
parts: []*genai.Part{
18811887
{Text: `"positive"`},
18821888
},
1883-
responseMode: responseModeJSON,
1884-
expected: `"positive"`,
1889+
responseMode: responseModeJSON,
1890+
expectedThoughtSummary: "",
1891+
expectedText: `"positive"`,
18851892
},
18861893
{
18871894
name: "regex mode with text containing internal quotes",
18881895
parts: []*genai.Part{
18891896
{Text: `"He said \"hello\" to me"`},
18901897
},
1891-
responseMode: responseModeRegex,
1892-
expected: `He said \"hello\" to me`,
1898+
responseMode: responseModeRegex,
1899+
expectedThoughtSummary: "",
1900+
expectedText: `He said \"hello\" to me`,
1901+
},
1902+
{
1903+
name: "test thought summary",
1904+
parts: []*genai.Part{
1905+
{Text: "Let me think step by step", Thought: true},
1906+
{Text: "Here is the conclusion"},
1907+
},
1908+
expectedThoughtSummary: "Let me think step by step",
1909+
expectedText: "Here is the conclusion",
18931910
},
18941911
}
18951912

18961913
for _, tc := range tests {
18971914
t.Run(tc.name, func(t *testing.T) {
1898-
result := extractTextFromGeminiParts(tc.parts, tc.responseMode)
1899-
if result != tc.expected {
1900-
t.Errorf("extractTextFromGeminiParts() = %q, want %q", result, tc.expected)
1915+
thoughtSummary, text := extractTextAndThoughtSummaryFromGeminiParts(tc.parts, tc.responseMode)
1916+
if thoughtSummary != tc.expectedThoughtSummary {
1917+
t.Errorf("thought summary result of extractTextAndThoughtSummaryFromGeminiParts() = %q, want %q", thoughtSummary, tc.expectedText)
1918+
}
1919+
if text != tc.expectedText {
1920+
t.Errorf("text result of extractTextAndThoughtSummaryFromGeminiParts() = %q, want %q", text, tc.expectedText)
19011921
}
19021922
})
19031923
}

internal/translator/openai_awsbedrock.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,7 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(_ map[string
702702
}
703703
case output.ReasoningContent != nil:
704704
choice.Message.ReasoningContent = &openai.ReasoningContentUnion{
705-
Value: &openai.AWSBedrockReasoningContent{
705+
Value: &openai.ReasoningContent{
706706
ReasoningContent: output.ReasoningContent,
707707
},
708708
}
@@ -819,7 +819,7 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) convertEvent(event *awsbe
819819
},
820820
})
821821
case event.Delta.ReasoningContent != nil:
822-
reasoningDelta := &openai.AWSBedrockStreamReasoningContent{}
822+
reasoningDelta := &openai.StreamReasoningContent{}
823823

824824
// Map all relevant fields from the Bedrock delta to our flattened OpenAI delta struct.
825825
if event.Delta.ReasoningContent != nil {

internal/translator/openai_awsbedrock_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,7 +1665,7 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T)
16651665
Role: awsbedrock.ConversationRoleAssistant,
16661666
Content: ptr.To("This is the final answer."),
16671667
ReasoningContent: &openai.ReasoningContentUnion{
1668-
Value: &openai.AWSBedrockReasoningContent{
1668+
Value: &openai.ReasoningContent{
16691669
ReasoningContent: &awsbedrock.ReasoningContentBlock{
16701670
ReasoningText: &awsbedrock.ReasoningTextBlock{
16711671
Text: "This is the model's thought process.",
@@ -1990,7 +1990,7 @@ func TestOpenAIToAWSBedrockTranslator_convertEvent(t *testing.T) {
19901990
Choices: []openai.ChatCompletionResponseChunkChoice{
19911991
{
19921992
Delta: &openai.ChatCompletionResponseChunkChoiceDelta{
1993-
ReasoningContent: &openai.AWSBedrockStreamReasoningContent{
1993+
ReasoningContent: &openai.StreamReasoningContent{
19941994
Text: "thinking...",
19951995
},
19961996
},
@@ -2171,7 +2171,7 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody_WithReasoning
21712171
require.Equal(t, "9.11 is greater than 9.8.", *message.Content)
21722172

21732173
require.NotNil(t, message.ReasoningContent, "Reasoning content should not be nil")
2174-
reasoningBlock, _ := message.ReasoningContent.Value.(*openai.AWSBedrockReasoningContent)
2174+
reasoningBlock, _ := message.ReasoningContent.Value.(*openai.ReasoningContent)
21752175
require.NotNil(t, reasoningBlock, "The nested reasoning content block should not be nil")
21762176
require.NotEmpty(t, reasoningBlock.ReasoningContent.ReasoningText.Text, "The reasoning text itself should not be empty")
21772177

internal/translator/openai_gcpvertexai.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,14 @@ func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) geminiCandidatesToOpenAI
344344
Role: openai.ChatMessageRoleAssistant,
345345
}
346346

347-
// Extract text from parts for streaming (delta).
348-
content := extractTextFromGeminiParts(candidate.Content.Parts, responseMode)
347+
// Extract thought summary and text from parts for streaming (delta).
348+
thoughtSummary, content := extractTextAndThoughtSummaryFromGeminiParts(candidate.Content.Parts, responseMode)
349+
if thoughtSummary != "" {
350+
delta.ReasoningContent = &openai.StreamReasoningContent{
351+
Text: thoughtSummary,
352+
}
353+
}
354+
349355
if content != "" {
350356
delta.Content = &content
351357
}

internal/translator/openai_gcpvertexai_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,6 +1099,91 @@ data: [DONE]
10991099
}`),
11001100
wantTokenUsage: tokenUsageFrom(8, 0, 12, 20),
11011101
},
1102+
{
1103+
name: "response with thought summary",
1104+
respHeaders: map[string]string{
1105+
"content-type": "application/json",
1106+
},
1107+
body: `{
1108+
"candidates": [
1109+
{
1110+
"content": {
1111+
"parts": [
1112+
{
1113+
"text": "Let me think step by step.",
1114+
"thought": true
1115+
},
1116+
{
1117+
"text": "AI Gateways act as intermediaries between clients and LLM services."
1118+
}
1119+
]
1120+
},
1121+
"finishReason": "STOP",
1122+
"safetyRatings": []
1123+
}
1124+
],
1125+
"promptFeedback": {
1126+
"safetyRatings": []
1127+
},
1128+
"usageMetadata": {
1129+
"promptTokenCount": 10,
1130+
"candidatesTokenCount": 15,
1131+
"totalTokenCount": 25,
1132+
"cachedContentTokenCount": 10,
1133+
"thoughtsTokenCount": 10
1134+
}
1135+
}`,
1136+
endOfStream: true,
1137+
wantError: false,
1138+
wantHeaderMut: []internalapi.Header{{contentLengthHeaderName, "450"}},
1139+
wantBodyMut: []byte(`{
1140+
"choices": [
1141+
{
1142+
"finish_reason": "stop",
1143+
"index": 0,
1144+
"message": {
1145+
"content": "AI Gateways act as intermediaries between clients and LLM services.",
1146+
"reasoning_content": {"reasoningContent": {"reasoningText": {"text": "Let me think step by step."}}},
1147+
"role": "assistant"
1148+
}
1149+
}
1150+
],
1151+
"object": "chat.completion",
1152+
"usage": {
1153+
"completion_tokens": 25,
1154+
"completion_tokens_details": {
1155+
"reasoning_tokens": 10
1156+
},
1157+
"prompt_tokens": 10,
1158+
"prompt_tokens_details": {
1159+
"cached_tokens": 10
1160+
},
1161+
"total_tokens": 25
1162+
}
1163+
}`),
1164+
1165+
wantTokenUsage: tokenUsageFrom(10, 10, 15, 25),
1166+
},
1167+
{
1168+
name: "stream chunks with thought summary",
1169+
respHeaders: map[string]string{
1170+
"content-type": "application/json",
1171+
},
1172+
body: `data: {"candidates":[{"content":{"parts":[{"text":"let me think step by step and reply you.", "thought": true}]}}]}
1173+
1174+
data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3,"totalTokenCount":8}}`,
1175+
stream: true,
1176+
endOfStream: true,
1177+
wantError: false,
1178+
wantHeaderMut: nil,
1179+
wantBodyMut: []byte(`data: {"choices":[{"index":0,"delta":{"role":"assistant","reasoning_content":{"text":"let me think step by step and reply you."}}}],"object":"chat.completion.chunk"}
1180+
1181+
data: {"choices":[{"index":0,"delta":{"content":"Hello","role":"assistant"}}],"object":"chat.completion.chunk","usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8,"completion_tokens_details":{},"prompt_tokens_details":{}}}
1182+
1183+
data: [DONE]
1184+
`),
1185+
wantTokenUsage: tokenUsageFrom(5, 0, 3, 8),
1186+
},
11021187
}
11031188

11041189
for _, tc := range tests {

0 commit comments

Comments
 (0)