diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs
index 2d093c2207cd..1bbbe5839787 100644
--- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs
+++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs
@@ -7,6 +7,7 @@
 using System.Linq;
 using System.Net.Http;
 using System.Runtime.CompilerServices;
+using System.Text;
 using System.Text.Json;
 using System.Threading;
 using System.Threading.Tasks;
@@ -164,13 +165,21 @@ public async Task<IReadOnlyList<ChatMessageContent>> GenerateChatMessageAsync(
     {
         var state = this.ValidateInputAndCreateChatCompletionState(chatHistory, kernel, executionSettings);

+        // Aggregation state for multi-iteration function calling loops.
+        // Text content from intermediate iterations (before tool calls) would otherwise be lost.
+        // Token usage must be summed across all API calls to report accurate totals.
+        StringBuilder? aggregatedContent = null;
+        int totalPromptTokens = 0;
+        int totalCandidatesTokens = 0;
+        bool hadMultipleIterations = false;
+
         for (state.Iteration = 1; ; state.Iteration++)
         {
             List<GeminiChatMessageContent> chatResponses;
+            GeminiResponse geminiResponse;
             using (var activity = ModelDiagnostics.StartCompletionActivity(
                 this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, state.ExecutionSettings))
             {
-                GeminiResponse geminiResponse;
                 try
                 {
                     geminiResponse = await this.SendRequestAndReturnValidGeminiResponseAsync(
@@ -190,19 +199,42 @@ public async Task<IReadOnlyList<ChatMessageContent>> GenerateChatMessageAsync(
                     geminiResponse.UsageMetadata?.CandidatesTokenCount);
             }

+            // Aggregate usage across all iterations.
+            totalPromptTokens += geminiResponse.UsageMetadata?.PromptTokenCount ?? 0;
+            totalCandidatesTokens += geminiResponse.UsageMetadata?.CandidatesTokenCount ?? 0;
+
             // If we don't want to attempt to invoke any functions, just return the result.
             // Or if we are auto-invoking but we somehow end up with other than 1 choice even though only 1 was requested, similarly bail.
             if (!state.AutoInvoke || chatResponses.Count != 1)
             {
+                // Apply aggregated content and usage to the final message.
+                this.ApplyAggregatedState(chatResponses, aggregatedContent, totalPromptTokens, totalCandidatesTokens, hadMultipleIterations);
                 return chatResponses;
             }

             state.LastMessage = chatResponses[0];
             if (state.LastMessage.ToolCalls is null || state.LastMessage.ToolCalls.Count == 0)
             {
+                // Apply aggregated content and usage to the final message.
+                this.ApplyAggregatedState(chatResponses, aggregatedContent, totalPromptTokens, totalCandidatesTokens, hadMultipleIterations);
                 return chatResponses;
             }

+            // We're about to process tool calls and continue the loop, so mark that we have multiple iterations.
+            hadMultipleIterations = true;
+
+            // Accumulate text content from this iteration before processing tool calls.
+            // The LLM may generate text (e.g., "Let me check that for you...") before tool calls.
+            if (!string.IsNullOrEmpty(state.LastMessage.Content))
+            {
+                aggregatedContent ??= new StringBuilder();
+                if (aggregatedContent.Length > 0)
+                {
+                    aggregatedContent.Append("\n\n");
+                }
+                aggregatedContent.Append(state.LastMessage.Content);
+            }
+
             // ToolCallBehavior is not null because we are in auto-invoke mode but we check it again to be sure it wasn't changed in the meantime
             Verify.NotNull(state.ExecutionSettings.ToolCallBehavior);
@@ -213,7 +245,10 @@ public async Task<IReadOnlyList<ChatMessageContent>> GenerateChatMessageAsync(
             // and return the last chat message content that was added to chat history
             if (state.FilterTerminationRequested)
             {
-                return [state.ChatHistory.Last()];
+                var lastMessage = state.ChatHistory.Last();
+                // Apply aggregated content and usage to the filter-terminated message.
+                this.ApplyAggregatedState(lastMessage, aggregatedContent, totalPromptTokens, totalCandidatesTokens, hadMultipleIterations);
+                return [lastMessage];
             }
         }
     }
@@ -889,6 +924,85 @@ private static GeminiMetadata GetResponseMetadata(
             ResponseSafetyRatings = candidate.SafetyRatings?.ToList(),
         };

+    /// <summary>
+    /// Applies aggregated text content and usage from previous iterations to the final message(s).
+    /// This ensures that text generated before tool calls is not lost and that token usage
+    /// reflects the total across all API calls in the function calling loop.
+    /// </summary>
+    /// <param name="messages">The list of messages to update.</param>
+    /// <param name="aggregatedContent">Accumulated text content from previous iterations, or null if none.</param>
+    /// <param name="totalPromptTokens">Total prompt tokens across all iterations.</param>
+    /// <param name="totalCandidatesTokens">Total candidates tokens across all iterations.</param>
+    /// <param name="hadMultipleIterations">Whether the function calling loop had multiple iterations.</param>
+    private void ApplyAggregatedState(
+        List<GeminiChatMessageContent> messages,
+        StringBuilder? aggregatedContent,
+        int totalPromptTokens,
+        int totalCandidatesTokens,
+        bool hadMultipleIterations)
+    {
+        if (messages.Count == 0)
+        {
+            return;
+        }
+
+        this.ApplyAggregatedStateToMessage(messages[0], aggregatedContent, totalPromptTokens, totalCandidatesTokens, hadMultipleIterations);
+    }
+
+    /// <summary>
+    /// Applies aggregated text content and usage from previous iterations to a single message.
+    /// </summary>
+    /// <param name="message">The message to update.</param>
+    /// <param name="aggregatedContent">Accumulated text content from previous iterations, or null if none.</param>
+    /// <param name="totalPromptTokens">Total prompt tokens across all iterations.</param>
+    /// <param name="totalCandidatesTokens">Total candidates tokens across all iterations.</param>
+    /// <param name="hadMultipleIterations">Whether the function calling loop had multiple iterations.</param>
+    private void ApplyAggregatedState(
+        ChatMessageContent message,
+        StringBuilder? aggregatedContent,
+        int totalPromptTokens,
+        int totalCandidatesTokens,
+        bool hadMultipleIterations)
+    {
+        this.ApplyAggregatedStateToMessage(message, aggregatedContent, totalPromptTokens, totalCandidatesTokens, hadMultipleIterations);
+    }
+
+    /// <summary>
+    /// Core implementation for applying aggregated state to a message.
+    /// </summary>
+    private void ApplyAggregatedStateToMessage(
+        ChatMessageContent message,
+        StringBuilder? aggregatedContent,
+        int totalPromptTokens,
+        int totalCandidatesTokens,
+        bool hadMultipleIterations)
+    {
+        // Prepend aggregated content from previous iterations.
+        if (aggregatedContent is { Length: > 0 })
+        {
+            if (!string.IsNullOrEmpty(message.Content))
+            {
+                aggregatedContent.Append("\n\n");
+                aggregatedContent.Append(message.Content);
+            }
+            message.Content = aggregatedContent.ToString();
+        }
+
+        // Update metadata with aggregated usage if we had multiple iterations.
+        // This ensures token counts are accurate even when intermediate iterations had no text content.
+        if (hadMultipleIterations && message.Metadata is GeminiMetadata existingMetadata)
+        {
+            // Create a new metadata dictionary with aggregated values.
+            var updatedDict = new Dictionary<string, object?>(existingMetadata)
+            {
+                [nameof(GeminiMetadata.PromptTokenCount)] = totalPromptTokens,
+                [nameof(GeminiMetadata.CandidatesTokenCount)] = totalCandidatesTokens,
+                [nameof(GeminiMetadata.TotalTokenCount)] = totalPromptTokens + totalCandidatesTokens
+            };
+            message.Metadata = GeminiMetadata.FromDictionary(updatedDict);
+        }
+    }
+
     private sealed class ChatCompletionState
     {
         internal ChatHistory ChatHistory { get; set; } = null!;
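For context, here is a minimal consumer-side sketch of what the Gemini change surfaces. It is not part of this diff: the kernel/history setup is assumed, and the helper method name is illustrative. GeminiPromptExecutionSettings, GeminiToolCallBehavior, and GeminiMetadata are the connector's existing public types.

    // Sketch only: assumes a Kernel already configured with a Gemini chat
    // completion service and at least one plugin to auto-invoke.
    private static async Task PrintAggregatedGeminiUsageAsync(Kernel kernel, ChatHistory history)
    {
        var settings = new GeminiPromptExecutionSettings
        {
            ToolCallBehavior = GeminiToolCallBehavior.AutoInvokeKernelFunctions
        };
        var chat = kernel.GetRequiredService<IChatCompletionService>();
        var result = await chat.GetChatMessageContentAsync(history, settings, kernel);

        // After this change, a multi-iteration run reports totals across all
        // API calls in the loop, not just the last one.
        if (result.Metadata is GeminiMetadata metadata)
        {
            Console.WriteLine($"Prompt: {metadata.PromptTokenCount}, Candidates: {metadata.CandidatesTokenCount}, Total: {metadata.TotalTokenCount}");
        }
    }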
diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs
index cd3192c65b8e..44651d042f05 100644
--- a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs
+++ b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs
@@ -55,6 +55,14 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
         var endpoint = this.GetEndpoint(mistralExecutionSettings, path: "chat/completions");
         var autoInvoke = kernel is not null && mistralExecutionSettings.ToolCallBehavior?.MaximumAutoInvokeAttempts > 0 && s_inflightAutoInvokes.Value < MaxInflightAutoInvokes;

+        // Aggregation state for multi-iteration function calling loops.
+        // Text content from intermediate iterations (before tool calls) would otherwise be lost.
+        // Token usage must be summed across all API calls to report accurate totals.
+        StringBuilder? aggregatedContent = null;
+        int totalPromptTokens = 0;
+        int totalCompletionTokens = 0;
+        bool hadMultipleIterations = false;
+
         for (int requestIndex = 1; ; requestIndex++)
         {
             var chatRequest = this.CreateChatCompletionRequest(modelId, stream: false, chatHistory, mistralExecutionSettings, kernel);
@@ -105,10 +113,16 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
                 activity?.SetCompletionResponse(responseContent, responseData.Usage?.PromptTokens, responseData.Usage?.CompletionTokens);
             }

+            // Aggregate usage across all iterations.
+            totalPromptTokens += responseData.Usage?.PromptTokens ?? 0;
+            totalCompletionTokens += responseData.Usage?.CompletionTokens ?? 0;
+
             // If we don't want to attempt to invoke any functions, just return the result.
             // Or if we are auto-invoking but we somehow end up with other than 1 choice even though only 1 was requested, similarly bail.
             if (!autoInvoke || responseData.Choices.Count != 1)
             {
+                // Apply aggregated content and usage to the final message.
+                ApplyAggregatedState(responseContent, aggregatedContent, totalPromptTokens, totalCompletionTokens, hadMultipleIterations);
                 return responseContent;
             }

@@ -120,9 +134,27 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
             MistralChatChoice chatChoice = responseData.Choices[0]; // TODO Handle multiple choices
             if (!chatChoice.IsToolCall)
             {
+                // Apply aggregated content and usage to the final message.
+                ApplyAggregatedState(responseContent, aggregatedContent, totalPromptTokens, totalCompletionTokens, hadMultipleIterations);
                 return responseContent;
             }

+            // We're about to process tool calls and continue the loop, so mark that we have multiple iterations.
+            hadMultipleIterations = true;
+
+            // Accumulate text content from this iteration before processing tool calls.
+            // The LLM may generate text (e.g., "Let me check that for you...") before tool calls.
+            var currentContent = responseContent.Count > 0 ? responseContent[0].Content : null;
+            if (!string.IsNullOrEmpty(currentContent))
+            {
+                aggregatedContent ??= new StringBuilder();
+                if (aggregatedContent.Length > 0)
+                {
+                    aggregatedContent.Append("\n\n");
+                }
+                aggregatedContent.Append(currentContent);
+            }
+
             if (this._logger.IsEnabled(LogLevel.Debug))
             {
                 this._logger.LogDebug("Tool requests: {Requests}", chatChoice.ToolCallCount);
@@ -226,7 +258,10 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
                     this._logger.LogDebug("Filter requested termination of automatic function invocation.");
                 }

-                return [chatHistory.Last()];
+                var lastMessage = chatHistory.Last();
+                // Apply aggregated content and usage to the filter-terminated message.
+                ApplyAggregatedState(lastMessage, aggregatedContent, totalPromptTokens, totalCompletionTokens, hadMultipleIterations);
+                return [lastMessage];
             }
         }

@@ -1088,5 +1123,73 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(context
             await functionCallCallback(context).ConfigureAwait(false);
         }
     }
+
+    /// <summary>
+    /// Applies aggregated text content and usage from previous iterations to the final message(s).
+    /// This ensures that text generated before tool calls is not lost and that token usage
+    /// reflects the total across all API calls in the function calling loop.
+    /// </summary>
+    /// <param name="messages">The list of messages to update.</param>
+    /// <param name="aggregatedContent">Accumulated text content from previous iterations, or null if none.</param>
+    /// <param name="totalPromptTokens">Total prompt tokens across all iterations.</param>
+    /// <param name="totalCompletionTokens">Total completion tokens across all iterations.</param>
+    /// <param name="hadMultipleIterations">Whether the function calling loop had multiple iterations.</param>
+    private static void ApplyAggregatedState(
+        List<ChatMessageContent> messages,
+        StringBuilder? aggregatedContent,
+        int totalPromptTokens,
+        int totalCompletionTokens,
+        bool hadMultipleIterations)
+    {
+        if (messages.Count == 0)
+        {
+            return;
+        }
+
+        ApplyAggregatedState(messages[0], aggregatedContent, totalPromptTokens, totalCompletionTokens, hadMultipleIterations);
+    }
+
+    /// <summary>
+    /// Applies aggregated text content and usage from previous iterations to a single message.
+    /// </summary>
+    /// <param name="message">The message to update.</param>
+    /// <param name="aggregatedContent">Accumulated text content from previous iterations, or null if none.</param>
+    /// <param name="totalPromptTokens">Total prompt tokens across all iterations.</param>
+    /// <param name="totalCompletionTokens">Total completion tokens across all iterations.</param>
+    /// <param name="hadMultipleIterations">Whether the function calling loop had multiple iterations.</param>
+    private static void ApplyAggregatedState(
+        ChatMessageContent message,
+        StringBuilder? aggregatedContent,
+        int totalPromptTokens,
+        int totalCompletionTokens,
+        bool hadMultipleIterations)
+    {
+        // Prepend aggregated content from previous iterations.
+        if (aggregatedContent is { Length: > 0 })
+        {
+            if (!string.IsNullOrEmpty(message.Content))
+            {
+                aggregatedContent.Append("\n\n");
+                aggregatedContent.Append(message.Content);
+            }
+            message.Content = aggregatedContent.ToString();
+        }
+
+        // Update metadata with aggregated usage if we had multiple iterations.
+        // This ensures token counts are accurate even when intermediate iterations had no text content.
+        if (hadMultipleIterations && message.Metadata is not null)
+        {
+            var updatedMetadata = new Dictionary<string, object?>(message.Metadata)
+            {
+                ["AggregatedUsage"] = new Dictionary<string, int>
+                {
+                    ["PromptTokens"] = totalPromptTokens,
+                    ["CompletionTokens"] = totalCompletionTokens,
+                    ["TotalTokens"] = totalPromptTokens + totalCompletionTokens
+                }
+            };
+            message.Metadata = new System.Collections.ObjectModel.ReadOnlyDictionary<string, object?>(updatedMetadata);
+        }
+    }
 #endregion
 }
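A consumer-side sketch of reading the new metadata entry on a Mistral result. The key names ("AggregatedUsage", "PromptTokens", "CompletionTokens", "TotalTokens") come from this diff; obtaining `result` (the kernel, history, and settings) is assumed, and the inner dictionary type mirrors what this patch stores.

    // Sketch only: the entry is added only when the function calling loop ran
    // more than one iteration (hadMultipleIterations), so probe defensively.
    private static void PrintMistralAggregatedUsage(ChatMessageContent result)
    {
        if (result.Metadata?.TryGetValue("AggregatedUsage", out var usage) == true
            && usage is Dictionary<string, int> tokens)
        {
            Console.WriteLine($"Prompt: {tokens["PromptTokens"]}, Completion: {tokens["CompletionTokens"]}, Total: {tokens["TotalTokens"]}");
        }
    }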
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj
index faf94420f29b..f5bbc229249b 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj
@@ -102,6 +102,12 @@
       <CopyToOutputDirectory>Always</CopyToOutputDirectory>
     </None>
+    <None Include="TestData\aggregation_final_response.json">
+      <CopyToOutputDirectory>Always</CopyToOutputDirectory>
+    </None>
+    <None Include="TestData\aggregation_function_call_with_text_response.json">
+      <CopyToOutputDirectory>Always</CopyToOutputDirectory>
+    </None>
   </ItemGroup>
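The csproj entries above copy the two new JSON fixtures next to the test binaries. That matters because the tests below open them with relative paths, roughly as in this sketch (illustrative fragment, not part of the diff):

    // Without CopyToOutputDirectory=Always, File.OpenRead would fail at test
    // run time: paths are resolved against the test binary's output directory.
    using var stream = File.OpenRead(Path.Combine("TestData", "aggregation_final_response.json"));
    using var response = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(stream) };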
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/FunctionCallingContentAggregationTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/FunctionCallingContentAggregationTests.cs
new file mode 100644
index 000000000000..ff40a91cef28
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/FunctionCallingContentAggregationTests.cs
@@ -0,0 +1,298 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Net;
+using System.Net.Http;
+using System.Threading.Tasks;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.Connectors.OpenAI;
+using Xunit;
+
+namespace SemanticKernel.Connectors.OpenAI.UnitTests.Core;
+
+/// <summary>
+/// Tests for non-streaming function calling content and usage aggregation.
+/// Verifies that text content generated before tool calls is preserved and that
+/// token usage is correctly summed across all API calls in the function calling loop.
+/// </summary>
+public sealed class FunctionCallingContentAggregationTests : IDisposable
+{
+    private readonly MultipleHttpMessageHandlerStub _messageHandlerStub;
+    private readonly HttpClient _httpClient;
+
+    public FunctionCallingContentAggregationTests()
+    {
+        this._messageHandlerStub = new MultipleHttpMessageHandlerStub();
+        this._httpClient = new HttpClient(this._messageHandlerStub, false);
+    }
+
+    [Fact]
+    public async Task NonStreaming_IntermediateTextBeforeToolCall_IsAggregatedInFinalResponseAsync()
+    {
+        // Arrange
+        var function = KernelFunctionFactory.CreateFromMethod((string parameter) => "function-result", "Function1");
+        var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function]);
+        var kernel = this.CreateKernel(plugin);
+
+        // First response: LLM generates text AND a tool call
+        // Second response: LLM generates final text after function execution
+        this._messageHandlerStub.ResponsesToReturn = GetTextWithToolCallResponses();
+
+        var chatHistory = new ChatHistory();
+        chatHistory.AddUserMessage("What is the answer?");
+
+        var settings = new OpenAIPromptExecutionSettings
+        {
+            ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+        };
+
+        var chatCompletionService = kernel.GetRequiredService<IChatCompletionService>();
+
+        // Act
+        var result = await chatCompletionService.GetChatMessageContentAsync(chatHistory, settings, kernel);
+
+        // Assert
+        Assert.NotNull(result);
+        Assert.NotNull(result.Content);
+
+        // The final content should contain BOTH the intermediate text and the final response
+        Assert.Contains("Let me check that for you", result.Content);
+        Assert.Contains("Based on my research, the answer is 42", result.Content);
+    }
+
+    [Fact]
+    public async Task NonStreaming_TokenUsage_IsAggregatedAcrossAllIterationsAsync()
+    {
+        // Arrange
+        var function = KernelFunctionFactory.CreateFromMethod((string parameter) => "function-result", "Function1");
+        var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function]);
+        var kernel = this.CreateKernel(plugin);
+
+        this._messageHandlerStub.ResponsesToReturn = GetTextWithToolCallResponses();
+
+        var chatHistory = new ChatHistory();
+        chatHistory.AddUserMessage("What is the answer?");
+
+        var settings = new OpenAIPromptExecutionSettings
+        {
+            ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+        };
+
+        var chatCompletionService = kernel.GetRequiredService<IChatCompletionService>();
+
+        // Act
+        var result = await chatCompletionService.GetChatMessageContentAsync(chatHistory, settings, kernel);
+
+        // Assert
+        Assert.NotNull(result);
+        Assert.NotNull(result.Metadata);
+
+        // Verify aggregated usage is present
+        // First call: 100 prompt + 50 completion = 150 total
+        // Second call: 200 prompt + 25 completion = 225 total
+        // Aggregated: 300 prompt + 75 completion = 375 total
+        Assert.True(result.Metadata.ContainsKey("AggregatedUsage"), "Metadata should contain AggregatedUsage");
+
+        var aggregatedUsage = result.Metadata["AggregatedUsage"] as Dictionary<string, int>;
+        Assert.NotNull(aggregatedUsage);
+        Assert.Equal(300, aggregatedUsage["InputTokens"]);
+        Assert.Equal(75, aggregatedUsage["OutputTokens"]);
+        Assert.Equal(375, aggregatedUsage["TotalTokens"]);
+    }
+
+    [Fact]
+    public async Task NonStreaming_SingleIteration_NoAggregationMetadataAddedAsync()
+    {
+        // Arrange - No function, so only one API call
+        var kernel = this.CreateKernel(plugin: null);
+
+        this._messageHandlerStub.ResponsesToReturn =
+        [
+            new HttpResponseMessage(HttpStatusCode.OK)
+            {
+                Content = new StreamContent(File.OpenRead("TestData/chat_completion_test_response.json"))
+            }
+        ];
+
+        var chatHistory = new ChatHistory();
+        chatHistory.AddUserMessage("Hello");
+
+        var settings = new OpenAIPromptExecutionSettings();
+
+        var chatCompletionService = kernel.GetRequiredService<IChatCompletionService>();
+
+        // Act
+        var result = await chatCompletionService.GetChatMessageContentAsync(chatHistory, settings, kernel);
+
+        // Assert
+        Assert.NotNull(result);
+        Assert.Equal("Test chat response", result.Content);
+
+        // Single iteration should NOT have AggregatedUsage metadata
+        Assert.False(
+            result.Metadata?.ContainsKey("AggregatedUsage") ?? false,
+            "Single iteration should not have AggregatedUsage metadata");
+    }
+
+    [Fact]
+    public async Task NonStreaming_ToolCallWithoutIntermediateText_OnlyFinalTextReturnedAsync()
+    {
+        // Arrange
+        var function = KernelFunctionFactory.CreateFromMethod((string parameter) => "function-result", "Function1");
+        var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function]);
+        var kernel = this.CreateKernel(plugin);
+
+        // First response: tool call without text content (content: null)
+        // Second response: final text
+        this._messageHandlerStub.ResponsesToReturn = GetToolCallWithoutTextResponses();
+
+        var chatHistory = new ChatHistory();
+        chatHistory.AddUserMessage("What is the answer?");
+
+        var settings = new OpenAIPromptExecutionSettings
+        {
+            ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+        };
+
+        var chatCompletionService = kernel.GetRequiredService<IChatCompletionService>();
+
+        // Act
+        var result = await chatCompletionService.GetChatMessageContentAsync(chatHistory, settings, kernel);
+
+        // Assert
+        Assert.NotNull(result);
+        Assert.Equal("Based on my research, the answer is 42.", result.Content);
+    }
+
+    [Fact]
+    public async Task NonStreaming_FilterTerminatesEarly_AggregatedContentStillAppliedAsync()
+    {
+        // Arrange
+        var function = KernelFunctionFactory.CreateFromMethod((string parameter) => "function-result", "Function1");
+        var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function]);
+
+        var builder = Kernel.CreateBuilder();
+        builder.Plugins.Add(plugin);
+
+        // Add filter that terminates after first function call
+        builder.Services.AddSingleton<IAutoFunctionInvocationFilter>(new TerminatingFilter());
+        builder.Services.AddSingleton<IChatCompletionService>(
+            _ => new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient));
+
+        var kernel = builder.Build();
+
+        this._messageHandlerStub.ResponsesToReturn = GetTextWithToolCallResponses();
+
+        var chatHistory = new ChatHistory();
+        chatHistory.AddUserMessage("What is the answer?");
+
+        var settings = new OpenAIPromptExecutionSettings
+        {
+            ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+        };
+
+        var chatCompletionService = kernel.GetRequiredService<IChatCompletionService>();
+
+        // Act
+        var result = await chatCompletionService.GetChatMessageContentAsync(chatHistory, settings, kernel);
+
+        // Assert
+        Assert.NotNull(result);
+
+        // Even though the filter terminated early, the intermediate text from the first iteration
+        // (generated before the tool call) should be preserved in the result.
+        // The first API response contained: "Let me check that for you. I'll look up the current information."
+        Assert.NotNull(result.Content);
+        Assert.Contains("Let me check that for you", result.Content);
+
+        // The second API call should NOT have been made (filter terminated),
+        // so the final response text should NOT be present.
+        Assert.DoesNotContain("Based on my research", result.Content);
+
+        // Verify aggregated usage contains tokens from the first iteration only.
+        // First call: 100 prompt + 50 completion tokens
+        Assert.NotNull(result.Metadata);
+        Assert.True(result.Metadata.ContainsKey("AggregatedUsage"), "Metadata should contain AggregatedUsage even when filter terminates");
+
+        var aggregatedUsage = result.Metadata["AggregatedUsage"] as Dictionary<string, int>;
+        Assert.NotNull(aggregatedUsage);
+        Assert.Equal(100, aggregatedUsage["InputTokens"]);
+        Assert.Equal(50, aggregatedUsage["OutputTokens"]);
+        Assert.Equal(150, aggregatedUsage["TotalTokens"]);
+    }
+
+    public void Dispose()
+    {
+        this._httpClient.Dispose();
+        this._messageHandlerStub.Dispose();
+    }
+
+    #region Private Helpers
+
+    private Kernel CreateKernel(KernelPlugin? plugin)
+    {
+        var builder = Kernel.CreateBuilder();
+
+        if (plugin is not null)
+        {
+            builder.Plugins.Add(plugin);
+        }
+
+        builder.Services.AddSingleton<IChatCompletionService>(
+            _ => new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient));
+
+        return builder.Build();
+    }
+
+#pragma warning disable CA2000 // Dispose objects before losing scope
+    private static List<HttpResponseMessage> GetTextWithToolCallResponses()
+    {
+        return
+        [
+            // First response: text + tool call
+            new HttpResponseMessage(HttpStatusCode.OK)
+            {
+                Content = new StreamContent(File.OpenRead("TestData/aggregation_function_call_with_text_response.json"))
+            },
+            // Second response: final text only
+            new HttpResponseMessage(HttpStatusCode.OK)
+            {
+                Content = new StreamContent(File.OpenRead("TestData/aggregation_final_response.json"))
+            }
+        ];
+    }
+
+    private static List<HttpResponseMessage> GetToolCallWithoutTextResponses()
+    {
+        return
+        [
+            // First response: tool call without text
+            new HttpResponseMessage(HttpStatusCode.OK)
+            {
+                Content = new StreamContent(File.OpenRead("TestData/filters_multiple_function_calls_test_response.json"))
+            },
+            // Second response: final text only
+            new HttpResponseMessage(HttpStatusCode.OK)
+            {
+                Content = new StreamContent(File.OpenRead("TestData/aggregation_final_response.json"))
+            }
+        ];
+    }
+#pragma warning restore CA2000
+
+    private sealed class TerminatingFilter : IAutoFunctionInvocationFilter
+    {
+        public Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func<AutoFunctionInvocationContext, Task> next)
+        {
+            context.Terminate = true;
+            return Task.CompletedTask;
+        }
+    }
+
+    #endregion
+}
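These tests lean on MultipleHttpMessageHandlerStub, the repo's shared test helper that returns one queued canned response per outgoing request. For readers without the repo at hand, a minimal equivalent could look like the following sketch (hypothetical name, not part of this diff):

    // Dequeues one pre-built HttpResponseMessage per request, so each API call
    // in the function calling loop gets the next fixture in order.
    internal sealed class QueuedHandlerStub : HttpMessageHandler
    {
        private readonly Queue<HttpResponseMessage> _responses;

        public QueuedHandlerStub(IEnumerable<HttpResponseMessage> responses)
            => this._responses = new Queue<HttpResponseMessage>(responses);

        protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
            => Task.FromResult(this._responses.Dequeue());
    }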
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/aggregation_final_response.json b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/aggregation_final_response.json
new file mode 100644
index 000000000000..6e49fb65e11d
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/aggregation_final_response.json
@@ -0,0 +1,23 @@
+{
+  "id": "response-id-2",
+  "object": "chat.completion",
+  "created": 1699896920,
+  "model": "gpt-4",
+  "choices": [
+    {
+      "index": 0,
+      "message": {
+        "role": "assistant",
+        "content": "Based on my research, the answer is 42."
+      },
+      "logprobs": null,
+      "finish_reason": "stop"
+    }
+  ],
+  "usage": {
+    "prompt_tokens": 200,
+    "completion_tokens": 25,
+    "total_tokens": 225
+  }
+}
+
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/aggregation_function_call_with_text_response.json b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/aggregation_function_call_with_text_response.json
new file mode 100644
index 000000000000..3485f4fe30aa
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/aggregation_function_call_with_text_response.json
@@ -0,0 +1,33 @@
+{
+  "id": "response-id-1",
+  "object": "chat.completion",
+  "created": 1699896916,
+  "model": "gpt-4",
+  "choices": [
+    {
+      "index": 0,
+      "message": {
+        "role": "assistant",
+        "content": "Let me check that for you. I'll look up the current information.",
+        "tool_calls": [
+          {
+            "id": "call_1",
+            "type": "function",
+            "function": {
+              "name": "MyPlugin-Function1",
+              "arguments": "{\"parameter\": \"test-value\"}"
+            }
+          }
+        ]
+      },
+      "logprobs": null,
+      "finish_reason": "tool_calls"
+    }
+  ],
+  "usage": {
+    "prompt_tokens": 100,
+    "completion_tokens": 50,
+    "total_tokens": 150
+  }
+}
+
diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs b/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs
index 3387601ed189..c8fea2093fa4 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs
@@ -4,6 +4,7 @@
 using System.ClientModel;
 using System.ClientModel.Primitives;
 using System.Collections.Generic;
+using System.Collections.ObjectModel;
 using System.Diagnostics;
 using System.Diagnostics.Metrics;
 using System.Linq;
@@ -156,6 +157,14 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
         ValidateMaxTokens(chatExecutionSettings.MaxTokens);

+        // Aggregation state for multi-iteration function calling loops.
+        // Text content from intermediate iterations (before tool calls) would otherwise be lost.
+        // Token usage must be summed across all API calls to report accurate totals.
+        StringBuilder? aggregatedContent = null;
+        int totalInputTokens = 0;
+        int totalOutputTokens = 0;
+        bool hadMultipleIterations = false;
+
         for (int requestIndex = 0; ; requestIndex++)
         {
             var chatForRequest = CreateChatCompletionMessages(chatExecutionSettings, chatHistory);
@@ -194,12 +203,33 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
                 activity?.SetCompletionResponse([chatMessageContent], chatCompletion.Usage.InputTokenCount, chatCompletion.Usage.OutputTokenCount);
             }

+            // Aggregate usage across all iterations.
+            totalInputTokens += chatCompletion.Usage.InputTokenCount;
+            totalOutputTokens += chatCompletion.Usage.OutputTokenCount;
+
             // If we don't want to attempt to invoke any functions or there is nothing to call, just return the result.
             if (!functionCallingConfig.AutoInvoke || chatCompletion.ToolCalls.Count == 0)
             {
+                // Apply aggregated content and usage to the final message.
+                this.ApplyAggregatedState(chatMessageContent, aggregatedContent, totalInputTokens, totalOutputTokens, hadMultipleIterations);
                 return [chatMessageContent];
             }

+            // We're about to process tool calls and continue the loop, so mark that we have multiple iterations.
+            hadMultipleIterations = true;
+
+            // Accumulate text content from this iteration before processing tool calls.
+            // The LLM may generate text (e.g., "Let me check that for you...") before tool calls.
+            if (!string.IsNullOrEmpty(chatMessageContent.Content))
+            {
+                aggregatedContent ??= new StringBuilder();
+                if (aggregatedContent.Length > 0)
+                {
+                    aggregatedContent.Append("\n\n");
+                }
+                aggregatedContent.Append(chatMessageContent.Content);
+            }
+
             // Process function calls by invoking the functions and adding the results to the chat history.
             // Each function call will trigger auto-function-invocation filters, which can terminate the process.
             // In such cases, we'll return the last message in the chat history.
@@ -216,6 +246,8 @@
             if (lastMessage != null)
             {
+                // Apply aggregated content and usage to the filter-terminated message.
+                this.ApplyAggregatedState(lastMessage, aggregatedContent, totalInputTokens, totalOutputTokens, hadMultipleIterations);
                 return [lastMessage];
             }

@@ -1030,6 +1062,45 @@ private OpenAIChatMessageContent CreateChatMessageContent(ChatMessageRole chatRo
         return message;
     }

+    /// <summary>
+    /// Applies aggregated text content and token usage from multi-iteration function calling loops to the final message.
+    /// </summary>
+    /// <param name="message">The final message to update.</param>
+    /// <param name="aggregatedContent">Accumulated text content from previous iterations, or null if none.</param>
+    /// <param name="totalInputTokens">Total input tokens across all iterations.</param>
+    /// <param name="totalOutputTokens">Total output tokens across all iterations.</param>
+    /// <param name="hadMultipleIterations">Whether the function calling loop had multiple iterations.</param>
+    private void ApplyAggregatedState(ChatMessageContent message, StringBuilder? aggregatedContent, int totalInputTokens, int totalOutputTokens, bool hadMultipleIterations)
+    {
+        // Prepend aggregated content from previous iterations if any.
+        if (aggregatedContent is { Length: > 0 })
+        {
+            if (!string.IsNullOrEmpty(message.Content))
+            {
+                aggregatedContent.Append("\n\n");
+                aggregatedContent.Append(message.Content);
+            }
+            message.Content = aggregatedContent.ToString();
+        }
+
+        // Update metadata with aggregated usage if we had multiple iterations.
+        // This ensures token counts are accurate even when intermediate iterations had no text content.
+        if (hadMultipleIterations)
+        {
+            var updatedMetadata = message.Metadata is not null
+                ? new Dictionary<string, object?>(message.Metadata)
+                : new Dictionary<string, object?>();
+
+            updatedMetadata["AggregatedUsage"] = new Dictionary<string, int>
+            {
+                ["InputTokens"] = totalInputTokens,
+                ["OutputTokens"] = totalOutputTokens,
+                ["TotalTokens"] = totalInputTokens + totalOutputTokens
+            };
+            message.Metadata = new ReadOnlyDictionary<string, object?>(updatedMetadata);
+        }
+    }
+
     private List<FunctionCallContent> GetFunctionCallContents(IEnumerable<ChatToolCall> toolCalls, bool retainArgumentTypes)
     {
         List<FunctionCallContent> result = [];
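Putting the OpenAI change together, an end-to-end sketch of the observable behavior after this patch. Setup (a kernel with an OpenAI chat service and an auto-invokable plugin, `chatService`, `chatHistory`) is assumed and names are illustrative; the key names mirror what this diff writes.

    // After a multi-iteration run, Content carries the intermediate text joined
    // to the final answer with "\n\n", and Metadata carries the summed usage.
    var result = await chatService.GetChatMessageContentAsync(
        chatHistory,
        new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions },
        kernel);

    Console.WriteLine(result.Content); // e.g. "Let me check that for you...\n\nBased on my research, the answer is 42."
    if (result.Metadata?.TryGetValue("AggregatedUsage", out var usage) == true && usage is Dictionary<string, int> tokens)
    {
        Console.WriteLine($"Input: {tokens["InputTokens"]}, Output: {tokens["OutputTokens"]}, Total: {tokens["TotalTokens"]}");
    }

One design note: the three connectors report the totals differently. Gemini folds them into its typed GeminiMetadata token-count properties, while Mistral and OpenAI add a separate "AggregatedUsage" dictionary and leave the last iteration's native usage metadata untouched.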