Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
Expand Down Expand Up @@ -73,18 +74,24 @@ public async Task<IEnumerable<ChatMessage>> ReduceAsync(IEnumerable<ChatMessage>
{
_ = Throw.IfNull(messages);

var summarizedConversion = SummarizedConversation.FromChatMessages(messages);
if (summarizedConversion.ShouldResummarize(_targetCount, _thresholdCount))
var summarizedConversation = SummarizedConversation.FromChatMessages(messages);
var indexOfFirstMessageToKeep = summarizedConversation.FindIndexOfFirstMessageToKeep(_targetCount, _thresholdCount);
if (indexOfFirstMessageToKeep > 0)
{
summarizedConversion = await summarizedConversion.ResummarizeAsync(
_chatClient, _targetCount, SummarizationPrompt, cancellationToken);
summarizedConversation = await summarizedConversation.ResummarizeAsync(
_chatClient,
indexOfFirstMessageToKeep,
SummarizationPrompt,
cancellationToken);
}

return summarizedConversion.ToChatMessages();
return summarizedConversation.ToChatMessages();
}

/// <summary>Represents a conversation with an optional summary.</summary>
private readonly struct SummarizedConversation(string? summary, ChatMessage? systemMessage, IList<ChatMessage> unsummarizedMessages)
{
/// <summary>Creates a <see cref="SummarizedConversation"/> from a list of chat messages.</summary>
public static SummarizedConversation FromChatMessages(IEnumerable<ChatMessage> messages)
{
string? summary = null;
Expand All @@ -102,7 +109,7 @@ public static SummarizedConversation FromChatMessages(IEnumerable<ChatMessage> m
unsummarizedMessages.Clear();
summary = summaryValue;
}
else if (!message.Contents.Any(m => m is FunctionCallContent or FunctionResultContent))
else
{
unsummarizedMessages.Add(message);
}
Expand All @@ -111,31 +118,68 @@ public static SummarizedConversation FromChatMessages(IEnumerable<ChatMessage> m
return new(summary, systemMessage, unsummarizedMessages);
}

public bool ShouldResummarize(int targetCount, int thresholdCount)
=> unsummarizedMessages.Count > targetCount + thresholdCount;

public async Task<SummarizedConversation> ResummarizeAsync(
IChatClient chatClient, int targetCount, string summarizationPrompt, CancellationToken cancellationToken)
/// <summary>Performs summarization by calling the chat client and updating the conversation state.</summary>
public async ValueTask<SummarizedConversation> ResummarizeAsync(
IChatClient chatClient, int indexOfFirstMessageToKeep, string summarizationPrompt, CancellationToken cancellationToken)
{
var messagesToResummarize = unsummarizedMessages.Count - targetCount;
if (messagesToResummarize <= 0)
{
// We're at or below the target count - no need to resummarize.
return this;
}
Debug.Assert(indexOfFirstMessageToKeep > 0, "Expected positive index for first message to keep.");

var summarizerChatMessages = ToSummarizerChatMessages(messagesToResummarize, summarizationPrompt);
// Generate the summary by sending unsummarized messages to the chat client
var summarizerChatMessages = ToSummarizerChatMessages(indexOfFirstMessageToKeep, summarizationPrompt);
var response = await chatClient.GetResponseAsync(summarizerChatMessages, cancellationToken: cancellationToken);
var newSummary = response.Text;

var lastSummarizedMessage = unsummarizedMessages[messagesToResummarize - 1];
// Attach the summary metadata to the last message being summarized
// This is what allows us to build on previously-generated summaries
var lastSummarizedMessage = unsummarizedMessages[indexOfFirstMessageToKeep - 1];
var additionalProperties = lastSummarizedMessage.AdditionalProperties ??= [];
additionalProperties[SummaryKey] = newSummary;

var newUnsummarizedMessages = unsummarizedMessages.Skip(messagesToResummarize).ToList();
// Compute the new list of unsummarized messages
var newUnsummarizedMessages = unsummarizedMessages.Skip(indexOfFirstMessageToKeep).ToList();
return new SummarizedConversation(newSummary, systemMessage, newUnsummarizedMessages);
}

/// <summary>Determines the index of the first message to keep (not summarize) based on target and threshold counts.</summary>
public int FindIndexOfFirstMessageToKeep(int targetCount, int thresholdCount)
{
var earliestAllowedIndex = unsummarizedMessages.Count - thresholdCount - targetCount;
if (earliestAllowedIndex <= 0)
{
// Not enough messages to warrant summarization
return 0;
}

// Start at the ideal cut point (keeping exactly targetCount messages)
var indexOfFirstMessageToKeep = unsummarizedMessages.Count - targetCount;

// Move backward to skip over function call/result content at the boundary
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So is this never getting rid of any FunctionCall/ResultContent? Does that mean in a really long conversation we could end up with lots and lots of these, and eventually still overflow context windows?

// We want to keep complete function call sequences together with their responses
while (indexOfFirstMessageToKeep > 0)
{
if (!unsummarizedMessages[indexOfFirstMessageToKeep - 1].Contents.Any(c => c is FunctionCallContent or FunctionResultContent))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's also UserInputRequestContent / UserInputResponseContent. Does this need to pay attention to those as well?

{
break;
}

indexOfFirstMessageToKeep--;
}

// Search backward within the threshold window to find a User message
// If found, cut right before it to avoid orphaning user questions from responses
for (var i = indexOfFirstMessageToKeep; i >= earliestAllowedIndex; i--)
{
if (unsummarizedMessages[i].Role == ChatRole.User)
{
return i;
}
}

// No User message found within threshold - use the adjusted cut point
return indexOfFirstMessageToKeep;
}

/// <summary>Converts the summarized conversation back into a collection of chat messages.</summary>
public IEnumerable<ChatMessage> ToChatMessages()
{
if (systemMessage is not null)
Expand All @@ -154,16 +198,21 @@ public IEnumerable<ChatMessage> ToChatMessages()
}
}

private IEnumerable<ChatMessage> ToSummarizerChatMessages(int messagesToResummarize, string summarizationPrompt)
/// <summary>Builds the list of messages to send to the chat client for summarization.</summary>
private IEnumerable<ChatMessage> ToSummarizerChatMessages(int indexOfFirstMessageToKeep, string summarizationPrompt)
{
if (summary is not null)
{
yield return new ChatMessage(ChatRole.Assistant, summary);
}

for (var i = 0; i < messagesToResummarize; i++)
for (var i = 0; i < indexOfFirstMessageToKeep; i++)
{
yield return unsummarizedMessages[i];
var message = unsummarizedMessages[i];
if (!message.Contents.Any(c => c is FunctionCallContent or FunctionResultContent))
{
yield return message;
}
}

yield return new ChatMessage(ChatRole.System, summarizationPrompt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,27 +84,87 @@ public async Task ReduceAsync_PreservesSystemMessage()
}

[Fact]
public async Task ReduceAsync_IgnoresFunctionCallsAndResults()
public async Task ReduceAsync_PreservesCompleteToolCallSequence()
{
using var chatClient = new TestChatClient();
var reducer = new SummarizingChatReducer(chatClient, targetCount: 3, threshold: 0);

// Target 2 messages, but this would split a function call sequence
var reducer = new SummarizingChatReducer(chatClient, targetCount: 2, threshold: 0);

List<ChatMessage> messages =
[
new ChatMessage(ChatRole.User, "What's the time?"),
new ChatMessage(ChatRole.Assistant, "Let me check"),
new ChatMessage(ChatRole.User, "What's the weather?"),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call1", "get_weather", new Dictionary<string, object?> { ["location"] = "Seattle" })]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("call1", "Sunny, 72°F")]),
new ChatMessage(ChatRole.Assistant, "The weather in Seattle is sunny and 72°F."),
new ChatMessage(ChatRole.User, "Thanks!"),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call1", "get_weather")]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("call1", "Sunny")]),
new ChatMessage(ChatRole.Assistant, "It's sunny"),
];

chatClient.GetResponseAsyncCallback = (msgs, _, _) =>
{
Assert.DoesNotContain(msgs, m => m.Contents.Any(c => c is FunctionCallContent or FunctionResultContent));
return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "Asked about time")));
};

var result = await reducer.ReduceAsync(messages, CancellationToken.None);
var resultList = result.ToList();

// Should have: summary + function content + last reply
Assert.Equal(4, resultList.Count);

// Verify the complete sequence is preserved
Assert.Collection(resultList,
m => Assert.Contains("Asked about time", m.Text),
m => Assert.Contains(m.Contents, c => c is FunctionCallContent),
m => Assert.Contains(m.Contents, c => c is FunctionResultContent),
m => Assert.Contains("sunny", m.Text));
}

[Fact]
public async Task ReduceAsync_PreservesUserMessageWhenWithinThreshold()
{
using var chatClient = new TestChatClient();

// Target 3 messages with threshold of 2
// This allows us to keep anywhere from 3 to 5 messages
var reducer = new SummarizingChatReducer(chatClient, targetCount: 3, threshold: 2);

// Function calls/results should be ignored, which means there aren't enough messages to generate a summary.
List<ChatMessage> messages =
[
new ChatMessage(ChatRole.User, "First question"),
new ChatMessage(ChatRole.Assistant, "First answer"),
new ChatMessage(ChatRole.User, "Second question"),
new ChatMessage(ChatRole.Assistant, "Second answer"),
new ChatMessage(ChatRole.User, "Third question"),
new ChatMessage(ChatRole.Assistant, "Third answer"),
];

chatClient.GetResponseAsyncCallback = (msgs, _, _) =>
{
var msgList = msgs.ToList();

// Should summarize messages 0-1 (First question and answer)
// The reducer should find the User message at index 2 within the threshold
Assert.Equal(3, msgList.Count); // 2 messages to summarize + system prompt
return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "Summary of first exchange")));
};

var result = await reducer.ReduceAsync(messages, CancellationToken.None);
var resultList = result.ToList();
Assert.Equal(3, resultList.Count); // Function calls get removed in the summarized chat.
Assert.DoesNotContain(resultList, m => m.Contents.Any(c => c is FunctionCallContent));
Assert.DoesNotContain(resultList, m => m.Contents.Any(c => c is FunctionResultContent));

// Should have: summary + 4 kept messages (from "Second question" onward)
Assert.Equal(5, resultList.Count);

// Verify the summary is first
Assert.Contains("Summary", resultList[0].Text);

// Verify we kept the User message at index 2 and everything after
Assert.Collection(resultList.Skip(1),
m => Assert.Contains("Second question", m.Text),
m => Assert.Contains("Second answer", m.Text),
m => Assert.Contains("Third question", m.Text),
m => Assert.Contains("Third answer", m.Text));
}

[Theory]
Expand All @@ -121,7 +181,7 @@ public async Task ReduceAsync_RespectsTargetAndThresholdCounts(int targetCount,
var messages = new List<ChatMessage>();
for (int i = 0; i < messageCount; i++)
{
messages.Add(new ChatMessage(i % 2 == 0 ? ChatRole.User : ChatRole.Assistant, $"Message {i}"));
messages.Add(new ChatMessage(ChatRole.Assistant, $"Message {i}"));
}

var summarizationCalled = false;
Expand Down
Loading