diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatReduction/SummarizingChatReducer.cs b/src/Libraries/Microsoft.Extensions.AI/ChatReduction/SummarizingChatReducer.cs index f097c1c9a35..673a9d7ad71 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatReduction/SummarizingChatReducer.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatReduction/SummarizingChatReducer.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Threading; @@ -73,18 +74,24 @@ public async Task> ReduceAsync(IEnumerable { _ = Throw.IfNull(messages); - var summarizedConversion = SummarizedConversation.FromChatMessages(messages); - if (summarizedConversion.ShouldResummarize(_targetCount, _thresholdCount)) + var summarizedConversation = SummarizedConversation.FromChatMessages(messages); + var indexOfFirstMessageToKeep = summarizedConversation.FindIndexOfFirstMessageToKeep(_targetCount, _thresholdCount); + if (indexOfFirstMessageToKeep > 0) { - summarizedConversion = await summarizedConversion.ResummarizeAsync( - _chatClient, _targetCount, SummarizationPrompt, cancellationToken); + summarizedConversation = await summarizedConversation.ResummarizeAsync( + _chatClient, + indexOfFirstMessageToKeep, + SummarizationPrompt, + cancellationToken); } - return summarizedConversion.ToChatMessages(); + return summarizedConversation.ToChatMessages(); } + /// Represents a conversation with an optional summary. private readonly struct SummarizedConversation(string? summary, ChatMessage? systemMessage, IList unsummarizedMessages) { + /// Creates a from a list of chat messages. public static SummarizedConversation FromChatMessages(IEnumerable messages) { string? summary = null; @@ -102,7 +109,7 @@ public static SummarizedConversation FromChatMessages(IEnumerable m unsummarizedMessages.Clear(); summary = summaryValue; } - else if (!message.Contents.Any(m => m is FunctionCallContent or FunctionResultContent)) + else { unsummarizedMessages.Add(message); } @@ -111,31 +118,68 @@ public static SummarizedConversation FromChatMessages(IEnumerable m return new(summary, systemMessage, unsummarizedMessages); } - public bool ShouldResummarize(int targetCount, int thresholdCount) - => unsummarizedMessages.Count > targetCount + thresholdCount; - - public async Task ResummarizeAsync( - IChatClient chatClient, int targetCount, string summarizationPrompt, CancellationToken cancellationToken) + /// Performs summarization by calling the chat client and updating the conversation state. + public async ValueTask ResummarizeAsync( + IChatClient chatClient, int indexOfFirstMessageToKeep, string summarizationPrompt, CancellationToken cancellationToken) { - var messagesToResummarize = unsummarizedMessages.Count - targetCount; - if (messagesToResummarize <= 0) - { - // We're at or below the target count - no need to resummarize. - return this; - } + Debug.Assert(indexOfFirstMessageToKeep > 0, "Expected positive index for first message to keep."); - var summarizerChatMessages = ToSummarizerChatMessages(messagesToResummarize, summarizationPrompt); + // Generate the summary by sending unsummarized messages to the chat client + var summarizerChatMessages = ToSummarizerChatMessages(indexOfFirstMessageToKeep, summarizationPrompt); var response = await chatClient.GetResponseAsync(summarizerChatMessages, cancellationToken: cancellationToken); var newSummary = response.Text; - var lastSummarizedMessage = unsummarizedMessages[messagesToResummarize - 1]; + // Attach the summary metadata to the last message being summarized + // This is what allows us to build on previously-generated summaries + var lastSummarizedMessage = unsummarizedMessages[indexOfFirstMessageToKeep - 1]; var additionalProperties = lastSummarizedMessage.AdditionalProperties ??= []; additionalProperties[SummaryKey] = newSummary; - var newUnsummarizedMessages = unsummarizedMessages.Skip(messagesToResummarize).ToList(); + // Compute the new list of unsummarized messages + var newUnsummarizedMessages = unsummarizedMessages.Skip(indexOfFirstMessageToKeep).ToList(); return new SummarizedConversation(newSummary, systemMessage, newUnsummarizedMessages); } + /// Determines the index of the first message to keep (not summarize) based on target and threshold counts. + public int FindIndexOfFirstMessageToKeep(int targetCount, int thresholdCount) + { + var earliestAllowedIndex = unsummarizedMessages.Count - thresholdCount - targetCount; + if (earliestAllowedIndex <= 0) + { + // Not enough messages to warrant summarization + return 0; + } + + // Start at the ideal cut point (keeping exactly targetCount messages) + var indexOfFirstMessageToKeep = unsummarizedMessages.Count - targetCount; + + // Move backward to skip over function call/result content at the boundary + // We want to keep complete function call sequences together with their responses + while (indexOfFirstMessageToKeep > 0) + { + if (!unsummarizedMessages[indexOfFirstMessageToKeep - 1].Contents.Any(c => c is FunctionCallContent or FunctionResultContent)) + { + break; + } + + indexOfFirstMessageToKeep--; + } + + // Search backward within the threshold window to find a User message + // If found, cut right before it to avoid orphaning user questions from responses + for (var i = indexOfFirstMessageToKeep; i >= earliestAllowedIndex; i--) + { + if (unsummarizedMessages[i].Role == ChatRole.User) + { + return i; + } + } + + // No User message found within threshold - use the adjusted cut point + return indexOfFirstMessageToKeep; + } + + /// Converts the summarized conversation back into a collection of chat messages. public IEnumerable ToChatMessages() { if (systemMessage is not null) @@ -154,16 +198,21 @@ public IEnumerable ToChatMessages() } } - private IEnumerable ToSummarizerChatMessages(int messagesToResummarize, string summarizationPrompt) + /// Builds the list of messages to send to the chat client for summarization. + private IEnumerable ToSummarizerChatMessages(int indexOfFirstMessageToKeep, string summarizationPrompt) { if (summary is not null) { yield return new ChatMessage(ChatRole.Assistant, summary); } - for (var i = 0; i < messagesToResummarize; i++) + for (var i = 0; i < indexOfFirstMessageToKeep; i++) { - yield return unsummarizedMessages[i]; + var message = unsummarizedMessages[i]; + if (!message.Contents.Any(c => c is FunctionCallContent or FunctionResultContent)) + { + yield return message; + } } yield return new ChatMessage(ChatRole.System, summarizationPrompt); diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatReduction/SummarizingChatReducerTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatReduction/SummarizingChatReducerTests.cs index 985b097ece8..258beeb5f7c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatReduction/SummarizingChatReducerTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatReduction/SummarizingChatReducerTests.cs @@ -84,27 +84,87 @@ public async Task ReduceAsync_PreservesSystemMessage() } [Fact] - public async Task ReduceAsync_IgnoresFunctionCallsAndResults() + public async Task ReduceAsync_PreservesCompleteToolCallSequence() { using var chatClient = new TestChatClient(); - var reducer = new SummarizingChatReducer(chatClient, targetCount: 3, threshold: 0); + + // Target 2 messages, but this would split a function call sequence + var reducer = new SummarizingChatReducer(chatClient, targetCount: 2, threshold: 0); List messages = [ + new ChatMessage(ChatRole.User, "What's the time?"), + new ChatMessage(ChatRole.Assistant, "Let me check"), new ChatMessage(ChatRole.User, "What's the weather?"), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call1", "get_weather", new Dictionary { ["location"] = "Seattle" })]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("call1", "Sunny, 72°F")]), - new ChatMessage(ChatRole.Assistant, "The weather in Seattle is sunny and 72°F."), - new ChatMessage(ChatRole.User, "Thanks!"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call1", "get_weather")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("call1", "Sunny")]), + new ChatMessage(ChatRole.Assistant, "It's sunny"), ]; + chatClient.GetResponseAsyncCallback = (msgs, _, _) => + { + Assert.DoesNotContain(msgs, m => m.Contents.Any(c => c is FunctionCallContent or FunctionResultContent)); + return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "Asked about time"))); + }; + var result = await reducer.ReduceAsync(messages, CancellationToken.None); + var resultList = result.ToList(); + + // Should have: summary + function content + last reply + Assert.Equal(4, resultList.Count); + + // Verify the complete sequence is preserved + Assert.Collection(resultList, + m => Assert.Contains("Asked about time", m.Text), + m => Assert.Contains(m.Contents, c => c is FunctionCallContent), + m => Assert.Contains(m.Contents, c => c is FunctionResultContent), + m => Assert.Contains("sunny", m.Text)); + } + + [Fact] + public async Task ReduceAsync_PreservesUserMessageWhenWithinThreshold() + { + using var chatClient = new TestChatClient(); + + // Target 3 messages with threshold of 2 + // This allows us to keep anywhere from 3 to 5 messages + var reducer = new SummarizingChatReducer(chatClient, targetCount: 3, threshold: 2); - // Function calls/results should be ignored, which means there aren't enough messages to generate a summary. + List messages = + [ + new ChatMessage(ChatRole.User, "First question"), + new ChatMessage(ChatRole.Assistant, "First answer"), + new ChatMessage(ChatRole.User, "Second question"), + new ChatMessage(ChatRole.Assistant, "Second answer"), + new ChatMessage(ChatRole.User, "Third question"), + new ChatMessage(ChatRole.Assistant, "Third answer"), + ]; + + chatClient.GetResponseAsyncCallback = (msgs, _, _) => + { + var msgList = msgs.ToList(); + + // Should summarize messages 0-1 (First question and answer) + // The reducer should find the User message at index 2 within the threshold + Assert.Equal(3, msgList.Count); // 2 messages to summarize + system prompt + return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "Summary of first exchange"))); + }; + + var result = await reducer.ReduceAsync(messages, CancellationToken.None); var resultList = result.ToList(); - Assert.Equal(3, resultList.Count); // Function calls get removed in the summarized chat. - Assert.DoesNotContain(resultList, m => m.Contents.Any(c => c is FunctionCallContent)); - Assert.DoesNotContain(resultList, m => m.Contents.Any(c => c is FunctionResultContent)); + + // Should have: summary + 4 kept messages (from "Second question" onward) + Assert.Equal(5, resultList.Count); + + // Verify the summary is first + Assert.Contains("Summary", resultList[0].Text); + + // Verify we kept the User message at index 2 and everything after + Assert.Collection(resultList.Skip(1), + m => Assert.Contains("Second question", m.Text), + m => Assert.Contains("Second answer", m.Text), + m => Assert.Contains("Third question", m.Text), + m => Assert.Contains("Third answer", m.Text)); } [Theory] @@ -121,7 +181,7 @@ public async Task ReduceAsync_RespectsTargetAndThresholdCounts(int targetCount, var messages = new List(); for (int i = 0; i < messageCount; i++) { - messages.Add(new ChatMessage(i % 2 == 0 ? ChatRole.User : ChatRole.Assistant, $"Message {i}")); + messages.Add(new ChatMessage(ChatRole.Assistant, $"Message {i}")); } var summarizationCalled = false;