diff --git a/internal/llminternal/contents_processor_test.go b/internal/llminternal/contents_processor_test.go
index 4220d21ff..f3f66ba33 100644
--- a/internal/llminternal/contents_processor_test.go
+++ b/internal/llminternal/contents_processor_test.go
@@ -15,6 +15,7 @@ package llminternal_test
 import (
+	"bytes"
 	"iter"
 	"slices"
 	"strings"
@@ -241,6 +242,54 @@ func TestContentsRequestProcessor_IncludeContents(t *testing.T) {
 	}
 }
 
+func TestContentsRequestProcessor_PreservesThoughtSignature(t *testing.T) {
+	const agentName = "testAgent"
+	testModel := &testModel{}
+
+	sig := []byte("signature-bytes")
+	part := genai.NewPartFromFunctionCall("do_work", map[string]any{"a": "b"})
+	part.FunctionCall.ID = "call-1"
+	part.ThoughtSignature = append([]byte(nil), sig...)
+
+	events := []*session.Event{
+		{
+			Author: agentName,
+			LLMResponse: model.LLMResponse{
+				Content: &genai.Content{
+					Role:  "model",
+					Parts: []*genai.Part{part},
+				},
+			},
+		},
+	}
+
+	testAgent := utils.Must(llmagent.New(llmagent.Config{
+		Name:  agentName,
+		Model: testModel,
+	}))
+
+	ctx := icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{
+		Agent: testAgent,
+		Session: &fakeSession{
+			events: events,
+		},
+	})
+
+	req := &model.LLMRequest{}
+	if err := llminternal.ContentsRequestProcessor(ctx, req); err != nil {
+		t.Fatalf("ContentsRequestProcessor failed: %v", err)
+	}
+
+	if len(req.Contents) != 1 || len(req.Contents[0].Parts) != 1 {
+		t.Fatalf("unexpected contents shape: got %+v, want 1 content with 1 part", req.Contents)
+	}
+
+	gotSig := req.Contents[0].Parts[0].ThoughtSignature
+	if !bytes.Equal(gotSig, sig) {
+		t.Fatalf("thought signature mismatch: got %v want %v", gotSig, sig)
+	}
+}
+
 func TestContentsRequestProcessor(t *testing.T) {
 	const agentName = "testAgent"
 	testModel := &testModel{}
diff --git a/internal/llminternal/stream_aggregator.go b/internal/llminternal/stream_aggregator.go
index 2b1865451..ee30a0ce4 100644
--- a/internal/llminternal/stream_aggregator.go
+++ b/internal/llminternal/stream_aggregator.go
@@ -19,6 +19,8 @@ import (
 	"fmt"
 	"iter"
 	"reflect"
+	"strconv"
+	"strings"
 
 	"google.golang.org/genai"
 
@@ -30,10 +32,19 @@ import (
 // It aggregates content from partial responses, and generates LlmResponses for
 // individual (partial) model responses, as well as for aggregated content.
 type streamingResponseAggregator struct {
-	text        string
-	thoughtText string
-	response    *model.LLMResponse
-	role        string
+	response *model.LLMResponse
+	role     string
+
+	textParts            []*genai.Part
+	currentTextBuffer    string
+	currentTextIsThought *bool
+
+	currentFunctionCalls         map[string]*functionCallState
+	activeFunctionCallOrder      []string
+	activeFunctionCallKeysByName map[string][]string
+	lastFunctionCallKey          string
+	unnamedSequence              int
+	unnamedCursor                int
}

// NewStreamingResponseAggregator creates a new, initialized streamingResponseAggregator.
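The new per-call state above exists to merge streamed function-call arguments, which now arrive as JSON-path-addressed fragments rather than complete Args maps (see parseJSONPath/setValueAtPath later in this file). A minimal, self-contained sketch of the merge rule for flat `$.key` paths — string fragments addressed to the same path concatenate, other value types overwrite. The names here are illustrative only; the real helpers also handle quoted keys and array indices:

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Each streamed chunk addresses one argument fragment by JSON path.
	type fragment struct{ path, val string }
	chunks := []fragment{
		{"$.location", "New "},
		{"$.location", "York"},
		{"$.unit", "celsius"},
	}

	args := map[string]any{}
	for _, c := range chunks {
		key := strings.TrimPrefix(c.path, "$.")
		// String fragments at an existing path concatenate; any other
		// value type would simply overwrite the previous value.
		if prev, ok := args[key].(string); ok {
			args[key] = prev + c.val
		} else {
			args[key] = c.val
		}
	}
	fmt.Println(args) // map[location:New York unit:celsius]
}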
@@ -71,35 +82,53 @@ func (s *streamingResponseAggregator) ProcessResponse(ctx context.Context, genRe func (s *streamingResponseAggregator) aggregateResponse(llmResponse *model.LLMResponse) *model.LLMResponse { s.response = llmResponse - var part0 *genai.Part - if llmResponse.Content != nil && len(llmResponse.Content.Parts) > 0 { - part0 = llmResponse.Content.Parts[0] + if llmResponse.Content != nil { s.role = llmResponse.Content.Role } - // If part is text append it - if part0 != nil && part0.Text != "" { - if part0.Thought { - s.thoughtText += part0.Text - } else { - s.text += part0.Text + if llmResponse.Content == nil || len(llmResponse.Content.Parts) == 0 { + if s.hasPendingTextParts() { + return s.createAggregateResponse() } - llmResponse.Partial = true return nil } - // gemini 3 in streaming returns a last response with an empty part. We need to filter it out. - if part0 != nil && reflect.ValueOf(*part0).IsZero() { - llmResponse.Partial = true - return nil + parts := llmResponse.Content.Parts + sawNonEmptyText := false + sawFunctionCall := false + sawInlineData := false + + for _, part := range parts { + if part == nil { + continue + } + + if part.FunctionCall != nil { + sawFunctionCall = true + s.flushTextBuffer() + s.handleFunctionCall(part, llmResponse) + continue + } + + if part.Text != "" || len(part.ThoughtSignature) > 0 { + if part.Text != "" { + sawNonEmptyText = true + } + s.handleTextPart(part) + llmResponse.Partial = true + continue + } + + if reflect.ValueOf(*part).IsZero() { + llmResponse.Partial = true + continue + } + + sawInlineData = true + s.flushTextBuffer() } - // If there is aggregated text and there is no content or parts return aggregated response - if (s.thoughtText != "" || s.text != "") && - (llmResponse.Content == nil || - len(llmResponse.Content.Parts) == 0 || - // don't yield the merged text event when receiving audio data - (len(llmResponse.Content.Parts) > 0 && llmResponse.Content.Parts[0].InlineData == nil)) { + if s.hasPendingTextParts() && (sawInlineData || (!sawNonEmptyText && !sawFunctionCall)) { return s.createAggregateResponse() } @@ -109,37 +138,579 @@ func (s *streamingResponseAggregator) aggregateResponse(llmResponse *model.LLMRe // Close generates an aggregated response at the end, if needed, // this should be called after all the model responses are processed. 
func (s *streamingResponseAggregator) Close() *model.LLMResponse { - return s.createAggregateResponse() + if resp := s.createAggregateResponse(); resp != nil { + return resp + } + if resp := s.createPendingFunctionCallResponse(); resp != nil { + return resp + } + s.clearTextBuffers() + return nil } func (s *streamingResponseAggregator) createAggregateResponse() *model.LLMResponse { - if (s.text != "" || s.thoughtText != "") && s.response != nil { - var parts []*genai.Part - if s.thoughtText != "" { - parts = append(parts, &genai.Part{Text: s.thoughtText, Thought: true}) + s.flushTextBuffer() + if len(s.textParts) == 0 || s.response == nil { + return nil + } + + parts := make([]*genai.Part, len(s.textParts)) + copy(parts, s.textParts) + + response := &model.LLMResponse{ + Content: &genai.Content{Parts: parts, Role: s.role}, + ErrorCode: s.response.ErrorCode, + ErrorMessage: s.response.ErrorMessage, + UsageMetadata: s.response.UsageMetadata, + GroundingMetadata: s.response.GroundingMetadata, + FinishReason: s.response.FinishReason, + } + s.clearTextBuffers() + return response +} + +func (s *streamingResponseAggregator) clearTextBuffers() { + s.response = nil + s.textParts = nil + s.currentTextBuffer = "" + s.currentTextIsThought = nil + s.role = "" +} + +func (s *streamingResponseAggregator) handleTextPart(part *genai.Part) { + if len(part.ThoughtSignature) > 0 { + s.flushTextBuffer() + s.textParts = append(s.textParts, cloneTextPart(part)) + return + } + + if part.Text == "" { + return + } + + if s.currentTextIsThought == nil || *s.currentTextIsThought != part.Thought { + s.flushTextBuffer() + val := part.Thought + s.currentTextIsThought = &val + } + s.currentTextBuffer += part.Text +} + +func (s *streamingResponseAggregator) flushTextBuffer() { + if s.currentTextBuffer == "" { + return + } + thought := false + if s.currentTextIsThought != nil { + thought = *s.currentTextIsThought + } + s.textParts = append(s.textParts, &genai.Part{Text: s.currentTextBuffer, Thought: thought}) + s.currentTextBuffer = "" + s.currentTextIsThought = nil +} + +func (s *streamingResponseAggregator) hasPendingTextParts() bool { + return s.currentTextBuffer != "" || len(s.textParts) > 0 +} + +func (s *streamingResponseAggregator) handleFunctionCall(part *genai.Part, llmResponse *model.LLMResponse) { + fc := part.FunctionCall + if fc == nil { + return + } + + if !isStreamingFunctionCall(fc) { + if !s.hasPendingFunctionCall() || fc.Name != "" || fc.ID != "" || len(fc.PartialArgs) > 0 { + return } - if s.text != "" { - parts = append(parts, &genai.Part{Text: s.text, Thought: false}) + // Empty functionCall chunk can mark the end of streamed args. + } + + state := s.ensureFunctionCallState(fc) + + if fc.Name != "" { + state.name = fc.Name + } + if fc.ID != "" { + state.id = fc.ID + } + if len(part.ThoughtSignature) > 0 && len(state.thoughtSignature) == 0 { + state.thoughtSignature = append([]byte(nil), part.ThoughtSignature...) 
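+		// First signature wins: per the Gemini docs quoted in the tests,
+		// the signature arrives on the first functionCall chunk and later
+		// chunks must not overwrite it.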
+ } + if state.args == nil { + state.args = make(map[string]any) + } + + for _, partialArg := range fc.PartialArgs { + value, ok := convertPartialArgValue(partialArg) + if !ok { + continue + } + pathTokens, err := parseJSONPath(partialArg.JsonPath) + if err != nil { + continue + } + if strVal, isString := value.(string); isString { + if existing, ok := getValueAtPath(state.args, pathTokens); ok { + if existingStr, ok := existing.(string); ok { + value = existingStr + strVal + } + } + } + updated := setValueAtPath(state.args, pathTokens, value) + if root, ok := updated.(map[string]any); ok { + state.args = root } + } + + if fcWillContinue(fc) { + llmResponse.Partial = true + return + } + + if !state.hasData() { + return + } - response := &model.LLMResponse{ - Content: &genai.Content{Parts: parts, Role: s.role}, - ErrorCode: s.response.ErrorCode, - ErrorMessage: s.response.ErrorMessage, - UsageMetadata: s.response.UsageMetadata, - GroundingMetadata: s.response.GroundingMetadata, - FinishReason: s.response.FinishReason, + if finalPart := s.buildFunctionCallPart(state); finalPart != nil { + if llmResponse.Content == nil { + llmResponse.Content = &genai.Content{Role: s.role} } - s.clear() - return response + llmResponse.Content.Parts = []*genai.Part{finalPart} + llmResponse.Partial = false } - s.clear() - return nil + s.clearFunctionCallState(state.key) } -func (s *streamingResponseAggregator) clear() { - s.response = nil - s.text = "" - s.thoughtText = "" - s.role = "" +func (s *streamingResponseAggregator) buildFunctionCallPart(state *functionCallState) *genai.Part { + if state == nil || !state.hasData() { + return nil + } + args := cloneValue(state.args).(map[string]any) + part := genai.NewPartFromFunctionCall(state.name, args) + if part.FunctionCall != nil { + part.FunctionCall.ID = state.id + } + if len(state.thoughtSignature) > 0 { + part.ThoughtSignature = append([]byte(nil), state.thoughtSignature...) + } + return part +} + +func (s *streamingResponseAggregator) clearFunctionCallState(key string) { + if key == "" || s.currentFunctionCalls == nil { + return + } + delete(s.currentFunctionCalls, key) + s.removeActiveFunctionCallKey(key) +} + +func (s *streamingResponseAggregator) hasPendingFunctionCall() bool { + return s.currentFunctionCalls != nil && len(s.currentFunctionCalls) > 0 +} + +func (s *streamingResponseAggregator) createPendingFunctionCallResponse() *model.LLMResponse { + if !s.hasPendingFunctionCall() || s.response == nil { + return nil + } + parts := s.buildPendingFunctionCallParts() + if len(parts) == 0 { + return nil + } + response := &model.LLMResponse{ + Content: &genai.Content{Parts: parts, Role: s.role}, + ErrorCode: s.response.ErrorCode, + ErrorMessage: s.response.ErrorMessage, + UsageMetadata: s.response.UsageMetadata, + GroundingMetadata: s.response.GroundingMetadata, + FinishReason: s.response.FinishReason, + } + s.clearAllFunctionCallState() + s.clearTextBuffers() + return response +} + +func cloneTextPart(part *genai.Part) *genai.Part { + if part == nil { + return nil + } + out := &genai.Part{ + Text: part.Text, + Thought: part.Thought, + } + if len(part.ThoughtSignature) > 0 { + out.ThoughtSignature = append([]byte(nil), part.ThoughtSignature...) 
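+		// Deep-copy so the aggregated part cannot alias a buffer the
+		// caller may reuse for later chunks.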
+ } + return out +} + +type functionCallState struct { + key string + name string + id string + args map[string]any + thoughtSignature []byte +} + +func (s *functionCallState) hasData() bool { + return s.name != "" || len(s.args) > 0 +} + +func (s *streamingResponseAggregator) ensureFunctionCallState(fc *genai.FunctionCall) *functionCallState { + key := s.resolveFunctionCallKey(fc) + if s.currentFunctionCalls == nil { + s.currentFunctionCalls = make(map[string]*functionCallState) + } + state, ok := s.currentFunctionCalls[key] + if !ok { + state = &functionCallState{key: key} + s.currentFunctionCalls[key] = state + s.trackActiveFunctionCall(fc, key) + } + s.lastFunctionCallKey = key + return state +} + +func (s *streamingResponseAggregator) resolveFunctionCallKey(fc *genai.FunctionCall) string { + if fc == nil { + return "__default__" + } + if fc.ID != "" { + return fc.ID + } + if fc.Name != "" { + if s.shouldStartNewUnnamedCall(fc) { + return s.newSyntheticKey(fc.Name) + } + if key := s.singleActiveKeyForName(fc.Name); key != "" { + return key + } + if key := s.mostRecentKeyForName(fc.Name); key != "" { + return key + } + return s.newSyntheticKey(fc.Name) + } + if key := s.nextUnnamedFunctionCallKey(); key != "" { + return key + } + return "__default__" +} + +func isStreamingFunctionCall(fc *genai.FunctionCall) bool { + if fc == nil { + return false + } + return len(fc.PartialArgs) > 0 || fc.WillContinue != nil +} + +func (s *streamingResponseAggregator) trackActiveFunctionCall(fc *genai.FunctionCall, key string) { + if key == "" { + return + } + s.activeFunctionCallOrder = append(s.activeFunctionCallOrder, key) + name := "" + if fc != nil { + name = fc.Name + } + if name == "" { + return + } + if s.activeFunctionCallKeysByName == nil { + s.activeFunctionCallKeysByName = make(map[string][]string) + } + s.activeFunctionCallKeysByName[name] = append(s.activeFunctionCallKeysByName[name], key) +} + +func (s *streamingResponseAggregator) removeActiveFunctionCallKey(key string) { + if key == "" { + return + } + if len(s.activeFunctionCallOrder) > 0 { + out := s.activeFunctionCallOrder[:0] + for _, k := range s.activeFunctionCallOrder { + if k != key { + out = append(out, k) + } + } + s.activeFunctionCallOrder = out + if s.unnamedCursor >= len(s.activeFunctionCallOrder) { + s.unnamedCursor = 0 + } + } + if len(s.activeFunctionCallKeysByName) == 0 { + return + } + for name, keys := range s.activeFunctionCallKeysByName { + out := keys[:0] + for _, k := range keys { + if k != key { + out = append(out, k) + } + } + if len(out) == 0 { + delete(s.activeFunctionCallKeysByName, name) + } else { + s.activeFunctionCallKeysByName[name] = out + } + } +} + +func (s *streamingResponseAggregator) clearAllFunctionCallState() { + s.currentFunctionCalls = nil + s.activeFunctionCallOrder = nil + s.activeFunctionCallKeysByName = nil + s.lastFunctionCallKey = "" + s.unnamedSequence = 0 + s.unnamedCursor = 0 +} + +func (s *streamingResponseAggregator) singleActiveKeyForName(name string) string { + if name == "" || len(s.activeFunctionCallKeysByName) == 0 { + return "" + } + keys := s.activeFunctionCallKeysByName[name] + if len(keys) == 1 { + return keys[0] + } + return "" +} + +func (s *streamingResponseAggregator) mostRecentKeyForName(name string) string { + if name == "" || len(s.activeFunctionCallKeysByName) == 0 { + return "" + } + keys := s.activeFunctionCallKeysByName[name] + if len(keys) == 0 { + return "" + } + return keys[len(keys)-1] +} + +func (s *streamingResponseAggregator) shouldStartNewUnnamedCall(fc 
*genai.FunctionCall) bool { + if fc == nil || fc.Name == "" { + return false + } + keys := s.activeFunctionCallKeysByName[fc.Name] + return len(keys) == 0 +} + +func (s *streamingResponseAggregator) newSyntheticKey(name string) string { + s.unnamedSequence++ + if name == "" { + return fmt.Sprintf("__call_%d__", s.unnamedSequence) + } + return fmt.Sprintf("%s#%d", name, s.unnamedSequence) +} + +func (s *streamingResponseAggregator) nextUnnamedFunctionCallKey() string { + if len(s.activeFunctionCallOrder) == 0 { + return "" + } + if s.unnamedCursor >= len(s.activeFunctionCallOrder) { + s.unnamedCursor = 0 + } + key := s.activeFunctionCallOrder[s.unnamedCursor] + s.unnamedCursor++ + return key +} + +func (s *streamingResponseAggregator) buildPendingFunctionCallParts() []*genai.Part { + if s.currentFunctionCalls == nil { + return nil + } + var parts []*genai.Part + for _, key := range s.activeFunctionCallOrder { + state := s.currentFunctionCalls[key] + if state == nil || !state.hasData() { + continue + } + if part := s.buildFunctionCallPart(state); part != nil { + parts = append(parts, part) + } + } + if len(parts) == 0 { + for _, state := range s.currentFunctionCalls { + if state == nil || !state.hasData() { + continue + } + if part := s.buildFunctionCallPart(state); part != nil { + parts = append(parts, part) + } + } + } + return parts +} + +type jsonPathToken struct { + key string + index int + isIndex bool +} + +func parseJSONPath(path string) ([]jsonPathToken, error) { + if path == "" { + return nil, fmt.Errorf("json path cannot be empty") + } + if !strings.HasPrefix(path, "$") { + return nil, fmt.Errorf("json path must start with '$'") + } + var tokens []jsonPathToken + i := 1 + for i < len(path) { + switch path[i] { + case '.': + i++ + start := i + for i < len(path) && path[i] != '.' 
&& path[i] != '[' { + i++ + } + if start == i { + return nil, fmt.Errorf("invalid json path %q", path) + } + tokens = append(tokens, jsonPathToken{key: path[start:i]}) + case '[': + i++ + if i >= len(path) { + return nil, fmt.Errorf("unterminated '[' in json path %q", path) + } + if path[i] == '\'' || path[i] == '"' { + quote := path[i] + i++ + start := i + for i < len(path) && path[i] != quote { + i++ + } + if i >= len(path) { + return nil, fmt.Errorf("unterminated quoted key in json path %q", path) + } + key := path[start:i] + i++ + if i >= len(path) || path[i] != ']' { + return nil, fmt.Errorf("missing closing bracket in json path %q", path) + } + tokens = append(tokens, jsonPathToken{key: key}) + i++ + continue + } + start := i + for i < len(path) && path[i] != ']' { + i++ + } + if i >= len(path) { + return nil, fmt.Errorf("unterminated array index in json path %q", path) + } + idx, err := strconv.Atoi(path[start:i]) + if err != nil { + return nil, fmt.Errorf("invalid array index in json path %q: %w", path, err) + } + if idx < 0 { + return nil, fmt.Errorf("negative array indices are not supported in json path %q", path) + } + tokens = append(tokens, jsonPathToken{index: idx, isIndex: true}) + i++ + default: + return nil, fmt.Errorf("unexpected character %q in json path %q", path[i], path) + } + } + return tokens, nil +} + +func convertPartialArgValue(arg *genai.PartialArg) (any, bool) { + if arg == nil { + return nil, false + } + switch { + case arg.StringValue != "": + return arg.StringValue, true + case arg.NumberValue != nil: + return *arg.NumberValue, true + case arg.BoolValue != nil: + return *arg.BoolValue, true + case arg.NULLValue != "": + return nil, true + default: + return nil, false + } +} + +func getValueAtPath(current any, tokens []jsonPathToken) (any, bool) { + if len(tokens) == 0 { + return current, true + } + head := tokens[0] + rest := tokens[1:] + if head.isIndex { + slice, ok := current.([]any) + if !ok || slice == nil { + return nil, false + } + if head.index < 0 || head.index >= len(slice) { + return nil, false + } + return getValueAtPath(slice[head.index], rest) + } + obj, ok := current.(map[string]any) + if !ok || obj == nil { + return nil, false + } + val, ok := obj[head.key] + if !ok { + return nil, false + } + return getValueAtPath(val, rest) +} + +func setValueAtPath(current any, tokens []jsonPathToken, value any) any { + if len(tokens) == 0 { + return value + } + head := tokens[0] + rest := tokens[1:] + if head.isIndex { + var slice []any + if existing, ok := current.([]any); ok && existing != nil { + slice = existing + } + if len(slice) <= head.index { + newSlice := make([]any, head.index+1) + copy(newSlice, slice) + slice = newSlice + } + slice[head.index] = setValueAtPath(slice[head.index], rest, value) + return slice + } + var obj map[string]any + if existing, ok := current.(map[string]any); ok && existing != nil { + obj = existing + } else { + obj = make(map[string]any) + } + obj[head.key] = setValueAtPath(obj[head.key], rest, value) + return obj +} + +func cloneValue(v any) any { + switch val := v.(type) { + case map[string]any: + out := make(map[string]any, len(val)) + for k, v := range val { + out[k] = cloneValue(v) + } + return out + case []any: + out := make([]any, len(val)) + for i, v := range val { + out[i] = cloneValue(v) + } + return out + default: + return v + } +} + +func fcWillContinue(fc *genai.FunctionCall) bool { + if fc == nil || fc.WillContinue == nil { + return false + } + return *fc.WillContinue } diff --git 
a/internal/llminternal/stream_aggregator_test.go b/internal/llminternal/stream_aggregator_test.go index dc8f1ae97..c514267a2 100644 --- a/internal/llminternal/stream_aggregator_test.go +++ b/internal/llminternal/stream_aggregator_test.go @@ -15,11 +15,15 @@ package llminternal_test import ( + "bytes" + "context" + "iter" "testing" "github.com/google/go-cmp/cmp" "google.golang.org/genai" + llminternal "google.golang.org/adk/internal/llminternal" "google.golang.org/adk/internal/testutil" "google.golang.org/adk/model" ) @@ -33,6 +37,10 @@ type streamAggregatorTest struct { wantPartial []bool } +// Doc: https://ai.google.dev/gemini-api/docs/thought-signatures +// Quote: "It is advisable to parse the entire request until the `finish_reason` is returned by the model." +// Summary: Streaming consumers should parse all chunks before final aggregation. +// Scenario: Two streaming calls emit partial responses followed by aggregated output. func TestStreamAggregator(t *testing.T) { ctx := t.Context() testCases := []streamAggregatorTest{ @@ -205,3 +213,555 @@ func TestStreamAggregator(t *testing.T) { }) } } + +// Doc: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling +// Quote: "Intermediate responses will contain a `functionCall` object with `partialArgs` and `willContinue` fields." +// Summary: Streamed function-call args arrive across multiple chunks and must be aggregated. +// Scenario: A single function call streams args across multiple chunks; signature is on the first chunk. +func TestStreamAggregatorStreamingFunctionCallArguments(t *testing.T) { + t.Parallel() + ctx := context.Background() + aggregator := llminternal.NewStreamingResponseAggregator() + thoughtSignature := []byte("signature-bytes") + var finalResponse *model.LLMResponse + + process := func(resp *genai.GenerateContentResponse) { + for llmResp, err := range aggregator.ProcessResponse(ctx, resp) { + if err != nil { + t.Fatalf("ProcessResponse returned error: %v", err) + } + if llmResp == nil || llmResp.Content == nil || len(llmResp.Content.Parts) == 0 { + continue + } + part := llmResp.Content.Parts[0] + if part.FunctionCall != nil && !llmResp.Partial { + finalResponse = llmResp + } + } + } + + process(newFunctionCallChunk("get_weather", "fc_001", thoughtSignature, true, []*genai.PartialArg{ + {JsonPath: "$.location", StringValue: "New "}, + }...)) + + if finalResponse != nil { + t.Fatalf("got final response before stream finished: %+v", finalResponse) + } + + process(newFunctionCallChunk("", "fc_001", nil, true, []*genai.PartialArg{ + {JsonPath: "$.location", StringValue: "York"}, + }...)) + + process(newFunctionCallChunk("", "fc_001", nil, false, []*genai.PartialArg{ + {JsonPath: "$.unit", StringValue: "celsius"}, + }...)) + + if finalResponse == nil { + t.Fatalf("expected final response after streaming function call") + } + + if len(finalResponse.Content.Parts) != 1 { + t.Fatalf("expected 1 part, got %d", len(finalResponse.Content.Parts)) + } + fcPart := finalResponse.Content.Parts[0] + if fcPart.FunctionCall == nil { + t.Fatalf("expected function call part in final response") + } + if got, want := fcPart.FunctionCall.Args["location"], "New York"; got != want { + t.Fatalf("location arg mismatch: got %v want %v", got, want) + } + if got, want := fcPart.FunctionCall.Args["unit"], "celsius"; got != want { + t.Fatalf("unit arg mismatch: got %v want %v", got, want) + } + if got := fcPart.FunctionCall.ID; got != "fc_001" { + t.Fatalf("function call id mismatch: got %q want %q", got, "fc_001") + } + 
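+	// The signature was present only on the first chunk; it must survive
+	// onto the finalized function-call part.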
if !bytes.Equal(fcPart.ThoughtSignature, thoughtSignature) { + t.Fatalf("thought signature mismatch: got %v want %v", fcPart.ThoughtSignature, thoughtSignature) + } + + if closeResp := aggregator.Close(); closeResp != nil { + t.Fatalf("expected no additional response from Close, got %+v", closeResp) + } +} + +// Doc: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling +// Quote: "\"functionCall\": {}" +// Summary: The streaming example includes an empty functionCall chunk as a marker. +// Scenario: An empty chunk should finalize the pending call so signatures/args are not dropped. +func TestStreamAggregatorFinalizesFunctionCallWhenEmptyChunkArrives(t *testing.T) { + t.Parallel() + ctx := context.Background() + aggregator := llminternal.NewStreamingResponseAggregator() + thoughtSignature := []byte("signature-bytes") + + process := func(resp *genai.GenerateContentResponse) *model.LLMResponse { + for llmResp, err := range aggregator.ProcessResponse(ctx, resp) { + if err != nil { + t.Fatalf("ProcessResponse returned error: %v", err) + } + if llmResp != nil { + return llmResp + } + } + return nil + } + + // Start a streaming function call with partial args. + process(newFunctionCallChunk("get_weather", "fc_001", thoughtSignature, true, []*genai.PartialArg{ + {JsonPath: "$.location", StringValue: "San"}, + }...)) + + // Simulate an empty function call chunk (no name/id/partialArgs/willContinue). + emptyChunk := newPartsChunk([]*genai.Part{ + {FunctionCall: &genai.FunctionCall{}}, + }, genai.FinishReason("")) + finalResp := process(emptyChunk) + + if finalResp == nil || finalResp.Content == nil || len(finalResp.Content.Parts) == 0 { + t.Fatalf("expected final response after empty chunk, got nil/empty") + } + + part := finalResp.Content.Parts[0] + if part.FunctionCall == nil { + t.Fatalf("expected function call part in final response") + } + if got, want := part.FunctionCall.Args["location"], "San"; got != want { + t.Fatalf("location arg mismatch: got %v want %v", got, want) + } + if !bytes.Equal(part.ThoughtSignature, thoughtSignature) { + t.Fatalf("thought signature mismatch: got %v want %v", part.ThoughtSignature, thoughtSignature) + } + + if closeResp := aggregator.Close(); closeResp != nil { + t.Fatalf("expected no additional response from Close, got %+v", closeResp) + } +} + +// Doc: https://ai.google.dev/gemini-api/docs/thought-signatures +// Quote: "the `thought_signature` is attached only to the first `functionCall` part." +// Summary: Parallel tool calls keep order and only the first part carries the signature. +// Scenario: Two parallel function calls in one response preserve order/signature placement. 
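+// Note: complete (non-streaming) functionCall parts pass through aggregation
+// untouched, which is what keeps ordering and signature placement intact here.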
+func TestStreamAggregatorParallelFunctionCallsPreserveOrderAndSignature(t *testing.T) { + t.Parallel() + ctx := context.Background() + aggregator := llminternal.NewStreamingResponseAggregator() + thoughtSignature := []byte("signature-bytes") + + part1 := genai.NewPartFromFunctionCall("fn_one", map[string]any{"x": "1"}) + part1.ThoughtSignature = thoughtSignature + part2 := genai.NewPartFromFunctionCall("fn_two", map[string]any{"y": "2"}) + + responses := collectResponses(t, aggregator, ctx, newPartsChunk([]*genai.Part{part1, part2}, genai.FinishReasonStop)) + finalResp := lastNonPartial(responses) + if finalResp == nil || finalResp.Content == nil { + t.Fatalf("expected final response with content") + } + + if got := len(finalResp.Content.Parts); got != 2 { + t.Fatalf("expected 2 parts, got %d", got) + } + if finalResp.Content.Parts[0].FunctionCall == nil || finalResp.Content.Parts[1].FunctionCall == nil { + t.Fatalf("expected function call parts") + } + if !bytes.Equal(finalResp.Content.Parts[0].ThoughtSignature, thoughtSignature) { + t.Fatalf("first function call signature mismatch: got %v want %v", finalResp.Content.Parts[0].ThoughtSignature, thoughtSignature) + } + if len(finalResp.Content.Parts[1].ThoughtSignature) != 0 { + t.Fatalf("expected second function call to have no signature") + } + if got, want := finalResp.Content.Parts[0].FunctionCall.Name, "fn_one"; got != want { + t.Fatalf("first function call name mismatch: got %q want %q", got, want) + } + if got, want := finalResp.Content.Parts[1].FunctionCall.Name, "fn_two"; got != want { + t.Fatalf("second function call name mismatch: got %q want %q", got, want) + } +} + +// Doc: https://ai.google.dev/gemini-api/docs/thought-signatures +// Quote: "each function call will have a signature and you must pass all signatures back." +// Summary: Sequential function calls must each preserve their own signatures. +// Scenario: Two sequential function calls each keep their own signatures. +func TestStreamAggregatorSequentialFunctionCallsPreserveSignatures(t *testing.T) { + t.Parallel() + ctx := context.Background() + aggregator := llminternal.NewStreamingResponseAggregator() + sig1 := []byte("sig-one") + sig2 := []byte("sig-two") + + part1 := genai.NewPartFromFunctionCall("fn_first", map[string]any{"a": "1"}) + part1.ThoughtSignature = sig1 + part2 := genai.NewPartFromFunctionCall("fn_second", map[string]any{"b": "2"}) + part2.ThoughtSignature = sig2 + + responses := collectResponses(t, aggregator, ctx, + newPartsChunk([]*genai.Part{part1}, genai.FinishReasonStop), + newPartsChunk([]*genai.Part{part2}, genai.FinishReasonStop), + ) + + var gotSigs [][]byte + for _, resp := range responses { + if resp == nil || resp.Partial || resp.Content == nil || len(resp.Content.Parts) == 0 { + continue + } + part := resp.Content.Parts[0] + if part.FunctionCall != nil { + gotSigs = append(gotSigs, part.ThoughtSignature) + } + } + if len(gotSigs) != 2 { + t.Fatalf("expected 2 function call responses, got %d", len(gotSigs)) + } + if !bytes.Equal(gotSigs[0], sig1) || !bytes.Equal(gotSigs[1], sig2) { + t.Fatalf("function call signatures mismatch: got %v want [%v %v]", gotSigs, sig1, sig2) + } +} + +// Doc: https://ai.google.dev/gemini-api/docs/thought-signatures +// Quote: "the model may return the thought signature in a part with an empty text content part." +// Summary: A final empty text part can carry the signature and must be preserved. +// Scenario: Preserve the empty final part to keep the signature. 
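+// Note: handleTextPart treats any part carrying a ThoughtSignature as
+// unmergeable, so an empty signed part is kept as its own aggregated part
+// instead of being folded into the text buffer.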
+func TestStreamAggregatorPreservesSignatureOnEmptyFinalTextPart(t *testing.T) { + t.Parallel() + ctx := context.Background() + aggregator := llminternal.NewStreamingResponseAggregator() + signature := []byte("final-signature") + + responses := collectResponses(t, aggregator, ctx, + newTextChunk("Hello", nil, false, genai.FinishReason("")), + newTextChunk("", signature, false, genai.FinishReasonStop), + ) + + finalResp := lastNonPartial(responses) + if finalResp == nil || finalResp.Content == nil { + t.Fatalf("expected final response with content") + } + if got := len(finalResp.Content.Parts); got != 2 { + t.Fatalf("expected 2 parts, got %d", got) + } + if finalResp.Content.Parts[0].Text != "Hello" { + t.Fatalf("unexpected first part text: %q", finalResp.Content.Parts[0].Text) + } + if len(finalResp.Content.Parts[1].ThoughtSignature) == 0 { + t.Fatalf("expected signature on final empty part") + } + if !bytes.Equal(finalResp.Content.Parts[1].ThoughtSignature, signature) { + t.Fatalf("signature mismatch: got %v want %v", finalResp.Content.Parts[1].ThoughtSignature, signature) + } +} + +// Doc: https://ai.google.dev/gemini-api/docs/function_calling +// Quote: "Don't merge a `Part` containing a signature with one that does not." +// Summary: Signed parts must remain distinct from adjacent text. +// Scenario: Signed text part remains distinct from adjacent text. +func TestStreamAggregatorDoesNotMergeSignedTextPart(t *testing.T) { + t.Parallel() + ctx := context.Background() + aggregator := llminternal.NewStreamingResponseAggregator() + signature := []byte("text-signature") + + responses := collectResponses(t, aggregator, ctx, + newTextChunk("A", nil, false, genai.FinishReason("")), + newTextChunk("B", signature, false, genai.FinishReason("")), + newTextChunk("C", nil, false, genai.FinishReasonStop), + ) + + finalResp := lastNonPartial(responses) + if finalResp == nil || finalResp.Content == nil { + t.Fatalf("expected final response with content") + } + if got := len(finalResp.Content.Parts); got != 3 { + t.Fatalf("expected 3 parts, got %d", got) + } + if finalResp.Content.Parts[0].Text != "A" || finalResp.Content.Parts[1].Text != "B" || finalResp.Content.Parts[2].Text != "C" { + t.Fatalf("unexpected text parts: %q %q %q", finalResp.Content.Parts[0].Text, finalResp.Content.Parts[1].Text, finalResp.Content.Parts[2].Text) + } + if !bytes.Equal(finalResp.Content.Parts[1].ThoughtSignature, signature) { + t.Fatalf("signed part signature mismatch: got %v want %v", finalResp.Content.Parts[1].ThoughtSignature, signature) + } +} + +// Doc: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling +// Quote: "If `true`, more `partialArgs` are expected for the overall function call in subsequent streamed responses." +// Summary: Streaming args may span multiple chunks; keep each call's args isolated. +// Scenario: Interleaved chunks for two calls are merged by ID and keep signatures. 
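+// Note: chunks that carry an ID are keyed by that ID, so interleaved streams
+// cannot cross-contaminate each other's partial args.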
+func TestStreamAggregatorInterleavedStreamingFunctionCalls(t *testing.T) { + t.Parallel() + ctx := context.Background() + aggregator := llminternal.NewStreamingResponseAggregator() + sig1 := []byte("sig-one") + sig2 := []byte("sig-two") + + responses := collectResponses(t, aggregator, ctx, + newFunctionCallChunk("fn_one", "fc_1", sig1, true, []*genai.PartialArg{{JsonPath: "$.x", StringValue: "A"}}...), + newFunctionCallChunk("fn_two", "fc_2", sig2, true, []*genai.PartialArg{{JsonPath: "$.y", StringValue: "X"}}...), + newFunctionCallChunk("", "fc_1", nil, false, []*genai.PartialArg{{JsonPath: "$.x", StringValue: "B"}}...), + newFunctionCallChunk("", "fc_2", nil, false, []*genai.PartialArg{{JsonPath: "$.y", StringValue: "Y"}}...), + ) + + var finals []*genai.Part + for _, resp := range responses { + if resp == nil || resp.Partial || resp.Content == nil || len(resp.Content.Parts) == 0 { + continue + } + part := resp.Content.Parts[0] + if part.FunctionCall != nil { + finals = append(finals, part) + } + } + if len(finals) != 2 { + t.Fatalf("expected 2 finalized function calls, got %d", len(finals)) + } + if got, want := finals[0].FunctionCall.Name, "fn_one"; got != want { + t.Fatalf("first function call name mismatch: got %q want %q", got, want) + } + if got, want := finals[0].FunctionCall.Args["x"], "AB"; got != want { + t.Fatalf("first function call args mismatch: got %v want %v", got, want) + } + if !bytes.Equal(finals[0].ThoughtSignature, sig1) { + t.Fatalf("first function call signature mismatch: got %v want %v", finals[0].ThoughtSignature, sig1) + } + if got, want := finals[1].FunctionCall.Name, "fn_two"; got != want { + t.Fatalf("second function call name mismatch: got %q want %q", got, want) + } + if got, want := finals[1].FunctionCall.Args["y"], "XY"; got != want { + t.Fatalf("second function call args mismatch: got %v want %v", got, want) + } + if !bytes.Equal(finals[1].ThoughtSignature, sig2) { + t.Fatalf("second function call signature mismatch: got %v want %v", finals[1].ThoughtSignature, sig2) + } +} + +// Doc: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling +// Quote: "Intermediate responses will contain a `functionCall` object with `partialArgs` and `willContinue` fields." +// Summary: Streaming examples omit IDs, so clients should not depend on them. +// Scenario: With missing IDs, fall back to deterministic ordering for parallel streams. 
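+// Note: nameless continuation chunks are matched round-robin against the
+// active calls in arrival order (see nextUnnamedFunctionCallKey), which makes
+// the pairing below deterministic.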
+func TestStreamAggregatorInterleavedFunctionCallsWithoutIDs(t *testing.T) { + t.Parallel() + ctx := context.Background() + aggregator := llminternal.NewStreamingResponseAggregator() + sig1 := []byte("sig-one") + sig2 := []byte("sig-two") + + responses := collectResponses(t, aggregator, ctx, + newFunctionCallChunk("fn_one", "", sig1, true, []*genai.PartialArg{{JsonPath: "$.x", StringValue: "A"}}...), + newFunctionCallChunk("fn_two", "", sig2, true, []*genai.PartialArg{{JsonPath: "$.y", StringValue: "X"}}...), + newFunctionCallChunk("", "", nil, false, []*genai.PartialArg{{JsonPath: "$.x", StringValue: "B"}}...), + newFunctionCallChunk("", "", nil, false, []*genai.PartialArg{{JsonPath: "$.y", StringValue: "Y"}}...), + ) + + var finals []*genai.Part + for _, resp := range responses { + if resp == nil || resp.Partial || resp.Content == nil || len(resp.Content.Parts) == 0 { + continue + } + part := resp.Content.Parts[0] + if part.FunctionCall != nil { + finals = append(finals, part) + } + } + if len(finals) != 2 { + t.Fatalf("expected 2 finalized function calls, got %d", len(finals)) + } + if got, want := finals[0].FunctionCall.Name, "fn_one"; got != want { + t.Fatalf("first function call name mismatch: got %q want %q", got, want) + } + if got, want := finals[0].FunctionCall.Args["x"], "AB"; got != want { + t.Fatalf("first function call args mismatch: got %v want %v", got, want) + } + if !bytes.Equal(finals[0].ThoughtSignature, sig1) { + t.Fatalf("first function call signature mismatch: got %v want %v", finals[0].ThoughtSignature, sig1) + } + if got, want := finals[1].FunctionCall.Name, "fn_two"; got != want { + t.Fatalf("second function call name mismatch: got %q want %q", got, want) + } + if got, want := finals[1].FunctionCall.Args["y"], "XY"; got != want { + t.Fatalf("second function call args mismatch: got %v want %v", got, want) + } + if !bytes.Equal(finals[1].ThoughtSignature, sig2) { + t.Fatalf("second function call signature mismatch: got %v want %v", finals[1].ThoughtSignature, sig2) + } +} + +// Doc: https://ai.google.dev/gemini-api/docs/thought-signatures +// Quote: "You must return this signature in the exact part where it was received when sending the conversation history back to the model." +// Summary: Pending function-call parts must be preserved to avoid dropping signatures. +// Scenario: If the stream ends early, flush all pending calls so signatures are preserved. 
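+// Note: Close drains buffered text first and only falls back to
+// createPendingFunctionCallResponse when no text is pending.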
+func TestStreamAggregatorCloseFlushesMultiplePendingFunctionCalls(t *testing.T) { + t.Parallel() + ctx := context.Background() + aggregator := llminternal.NewStreamingResponseAggregator() + sig1 := []byte("sig-one") + sig2 := []byte("sig-two") + + process := func(resp *genai.GenerateContentResponse) { + for llmResp, err := range aggregator.ProcessResponse(ctx, resp) { + if err != nil { + t.Fatalf("ProcessResponse returned error: %v", err) + } + if llmResp == nil { + continue + } + } + } + + process(newFunctionCallChunk("fn_one", "", sig1, true, []*genai.PartialArg{{JsonPath: "$.x", StringValue: "A"}}...)) + process(newFunctionCallChunk("fn_two", "", sig2, true, []*genai.PartialArg{{JsonPath: "$.y", StringValue: "X"}}...)) + + finalResp := aggregator.Close() + if finalResp == nil || finalResp.Content == nil { + t.Fatalf("expected final response from Close") + } + if got := len(finalResp.Content.Parts); got != 2 { + t.Fatalf("expected 2 function call parts, got %d", got) + } + + first := finalResp.Content.Parts[0] + second := finalResp.Content.Parts[1] + if first.FunctionCall == nil || second.FunctionCall == nil { + t.Fatalf("expected function call parts in Close response") + } + if got, want := first.FunctionCall.Name, "fn_one"; got != want { + t.Fatalf("first function call name mismatch: got %q want %q", got, want) + } + if got, want := first.FunctionCall.Args["x"], "A"; got != want { + t.Fatalf("first function call args mismatch: got %v want %v", got, want) + } + if !bytes.Equal(first.ThoughtSignature, sig1) { + t.Fatalf("first function call signature mismatch: got %v want %v", first.ThoughtSignature, sig1) + } + if got, want := second.FunctionCall.Name, "fn_two"; got != want { + t.Fatalf("second function call name mismatch: got %q want %q", got, want) + } + if got, want := second.FunctionCall.Args["y"], "X"; got != want { + t.Fatalf("second function call args mismatch: got %v want %v", got, want) + } + if !bytes.Equal(second.ThoughtSignature, sig2) { + t.Fatalf("second function call signature mismatch: got %v want %v", second.ThoughtSignature, sig2) + } +} + +// Doc: https://ai.google.dev/gemini-api/docs/function_calling +// Quote: "include the responses in the same order as they were requested." +// Summary: Preserve part ordering when mixing text and function calls. +// Scenario: Mixed text + functionCall + text stays in original order. 
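+// Note: a complete functionCall part does not rewrite llmResponse.Content, so
+// the chunk's own part ordering survives in the passthrough response.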
+func TestStreamAggregatorMixedTextAndFunctionCallOrder(t *testing.T) { + t.Parallel() + ctx := context.Background() + aggregator := llminternal.NewStreamingResponseAggregator() + + fcPart := genai.NewPartFromFunctionCall("fn_call", map[string]any{"k": "v"}) + resp := newPartsChunk([]*genai.Part{ + {Text: "Hello"}, + fcPart, + {Text: "World"}, + }, genai.FinishReasonStop) + + responses := collectResponses(t, aggregator, ctx, resp) + var withFn *model.LLMResponse + for _, r := range responses { + if r == nil || r.Content == nil { + continue + } + for _, p := range r.Content.Parts { + if p.FunctionCall != nil { + withFn = r + break + } + } + if withFn != nil { + break + } + } + if withFn == nil || withFn.Content == nil { + t.Fatalf("expected response containing function call") + } + parts := withFn.Content.Parts + if len(parts) != 3 { + t.Fatalf("expected 3 parts, got %d", len(parts)) + } + if parts[0].Text != "Hello" || parts[1].FunctionCall == nil || parts[2].Text != "World" { + t.Fatalf("unexpected part order: %+v", parts) + } +} + +func newFunctionCallChunk(name, id string, sig []byte, willContinue bool, args ...*genai.PartialArg) *genai.GenerateContentResponse { + part := &genai.Part{ + FunctionCall: &genai.FunctionCall{ + Name: name, + ID: id, + PartialArgs: args, + WillContinue: genai.Ptr(willContinue), + }, + } + if len(sig) > 0 { + part.ThoughtSignature = sig + } + finishReason := genai.FinishReason("") + if !willContinue { + finishReason = genai.FinishReasonStop + } + return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{ + { + Content: &genai.Content{ + Role: "model", + Parts: []*genai.Part{part}, + }, + FinishReason: finishReason, + }, + }, + } +} + +func newTextChunk(text string, sig []byte, thought bool, finishReason genai.FinishReason) *genai.GenerateContentResponse { + part := &genai.Part{Text: text, Thought: thought} + if len(sig) > 0 { + part.ThoughtSignature = sig + } + return newPartsChunk([]*genai.Part{part}, finishReason) +} + +func newPartsChunk(parts []*genai.Part, finishReason genai.FinishReason) *genai.GenerateContentResponse { + return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{ + { + Content: &genai.Content{Role: "model", Parts: parts}, + FinishReason: finishReason, + }, + }, + } +} + +type responseAggregator interface { + ProcessResponse(ctx context.Context, genResp *genai.GenerateContentResponse) iter.Seq2[*model.LLMResponse, error] + Close() *model.LLMResponse +} + +func collectResponses(t *testing.T, aggregator responseAggregator, ctx context.Context, responses ...*genai.GenerateContentResponse) []*model.LLMResponse { + t.Helper() + var out []*model.LLMResponse + for _, resp := range responses { + for llmResp, err := range aggregator.ProcessResponse(ctx, resp) { + if err != nil { + t.Fatalf("ProcessResponse returned error: %v", err) + } + if llmResp != nil { + out = append(out, llmResp) + } + } + } + if final := aggregator.Close(); final != nil { + out = append(out, final) + } + return out +} + +func lastNonPartial(responses []*model.LLMResponse) *model.LLMResponse { + for i := len(responses) - 1; i >= 0; i-- { + if responses[i] != nil && !responses[i].Partial { + return responses[i] + } + } + return nil +} diff --git a/session/database/service_test.go b/session/database/service_test.go index 3e7290273..7ecf7de7e 100644 --- a/session/database/service_test.go +++ b/session/database/service_test.go @@ -15,7 +15,9 @@ package database import ( + "bytes" "maps" + "os" "strconv" "testing" "time" @@ -30,6 +32,14 @@ import ( 
"google.golang.org/adk/session" ) +// TestMain sets up the test environment. +// We set time.Local to UTC to ensure consistent timestamp formatting in tests, +// since the database stores timestamps without timezone info and tests expect UTC format. +func TestMain(m *testing.M) { + time.Local = time.UTC + os.Exit(m.Run()) +} + func Test_databaseService_Create(t *testing.T) { tests := []struct { name string @@ -358,6 +368,7 @@ func Test_databaseService_Get(t *testing.T) { if tt.wantEvents != nil { opts := []cmp.Option{ cmpopts.SortSlices(func(a, b *session.Event) bool { return a.Timestamp.Before(b.Timestamp) }), + cmp.Comparer(func(a, b time.Time) bool { return a.Equal(b) }), } if diff := cmp.Diff(events(tt.wantEvents), got.Session.Events(), opts...); diff != "" { t.Errorf("Get session events mismatch: (-want +got):\n%s", diff) @@ -473,6 +484,66 @@ func Test_databaseService_List(t *testing.T) { } } +func Test_databaseService_ThoughtSignatureRoundTrip(t *testing.T) { + s := emptyService(t) + ctx := t.Context() + + created, err := s.Create(ctx, &session.CreateRequest{ + AppName: "app", + UserID: "user", + SessionID: "s1", + }) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + + // Update to avoid stale validation during AppendEvent. + created.Session.(*localSession).updatedAt = time.Now() + + sig := []byte("signature-bytes") + part := genai.NewPartFromFunctionCall("do_work", map[string]any{"a": "b"}) + part.ThoughtSignature = append([]byte(nil), sig...) + + event := &session.Event{ + ID: "event1", + Author: "model", + Timestamp: time.Now(), + LLMResponse: model.LLMResponse{ + Content: &genai.Content{ + Role: "model", + Parts: []*genai.Part{part}, + }, + }, + } + + if err := s.AppendEvent(ctx, created.Session.(*localSession), event); err != nil { + t.Fatalf("AppendEvent failed: %v", err) + } + + got, err := s.Get(ctx, &session.GetRequest{ + AppName: "app", + UserID: "user", + SessionID: "s1", + }) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + + if got.Session.Events().Len() != 1 { + t.Fatalf("Get returned %d events, want 1", got.Session.Events().Len()) + } + + gotEvent := got.Session.Events().At(0) + if gotEvent == nil || gotEvent.Content == nil || len(gotEvent.Content.Parts) == 0 { + t.Fatalf("Get returned empty event content") + } + + gotSig := gotEvent.Content.Parts[0].ThoughtSignature + if !bytes.Equal(gotSig, sig) { + t.Fatalf("thought signature mismatch: got %v want %v", gotSig, sig) + } +} + func Test_databaseService_AppendEvent(t *testing.T) { tests := []struct { name string diff --git a/tool/agenttool/agent_tool_test.go b/tool/agenttool/agent_tool_test.go index 4ffa4ba85..1c2146d58 100644 --- a/tool/agenttool/agent_tool_test.go +++ b/tool/agenttool/agent_tool_test.go @@ -234,7 +234,7 @@ func TestAgentTool_Run_WithoutSchema(t *testing.T) { if err != nil { t.Fatalf("Run() failed unexpectedly: %v", err) } - want := map[string]any{"result": "First text part is returned"} + want := map[string]any{"result": "First text part is returnedThis should be ignored"} if diff := cmp.Diff(want, result); diff != "" { t.Errorf("Run() result diff (-want +got):\n%s", diff) }