Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,15 @@ public override async Task<ChatResponse> GetResponseAsync(
{
functionCallContents?.Clear();

// On the last iteration, we won't be processing any function calls, so we should not
// include AIFunctionDeclaration tools in the request to prevent the inner client from
// returning tool call requests that won't be handled.
if (iteration >= MaximumIterationsPerRequest)
{
LogMaximumIterationsReached(MaximumIterationsPerRequest);
PrepareOptionsForLastIteration(ref options);
}

// Make the call to the inner client.
response = await base.GetResponseAsync(messages, options, cancellationToken);
if (response is null)
Expand Down Expand Up @@ -486,6 +495,15 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
updates.Clear();
functionCallContents?.Clear();

// On the last iteration, we won't be processing any function calls, so we should not
// include AIFunctionDeclaration tools in the request to prevent the inner client from
// returning tool call requests that won't be handled.
if (iteration >= MaximumIterationsPerRequest)
{
LogMaximumIterationsReached(MaximumIterationsPerRequest);
PrepareOptionsForLastIteration(ref options);
}

bool hasApprovalRequiringFcc = false;
int lastApprovalCheckedFCCIndex = 0;
int lastYieldedUpdateIndex = 0;
Expand Down Expand Up @@ -824,6 +842,48 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions? options, stri
}
}

/// <summary>
/// Prepares options for the last iteration by removing AIFunctionDeclaration tools.
/// </summary>
/// <param name="options">The chat options to prepare.</param>
/// <remarks>
/// On the last iteration, we won't be processing any function calls, so we should not
/// include AIFunctionDeclaration tools in the request. This prevents the inner client
/// from returning tool call requests that won't be handled.
/// </remarks>
private static void PrepareOptionsForLastIteration(ref ChatOptions? options)
{
if (options?.Tools is not { Count: > 0 })
{
return;
}

// Filter out AIFunctionDeclaration tools, keeping only non-function tools
List<AITool>? remainingTools = null;
foreach (var tool in options.Tools)
{
if (tool is not AIFunctionDeclaration)
{
remainingTools ??= [];
remainingTools.Add(tool);
}
}

// If we removed any tools (including removing all of them), clone and update options
int remainingCount = remainingTools?.Count ?? 0;
if (remainingCount < options.Tools.Count)
{
options = options.Clone();
options.Tools = remainingTools;

// If no tools remain, clear the ToolMode as well
if (remainingCount == 0)
{
options.ToolMode = null;
}
}
}

/// <summary>Gets whether the function calling loop should exit based on the function call requests.</summary>
/// <param name="functionCalls">The call requests.</param>
/// <param name="toolMap">The map from tool names to tools.</param>
Expand Down Expand Up @@ -1693,6 +1753,9 @@ private static TimeSpan GetElapsedTime(long startingTimestamp) =>
[LoggerMessage(LogLevel.Error, "{MethodName} invocation failed.")]
private partial void LogInvocationFailed(string methodName, Exception error);

[LoggerMessage(LogLevel.Debug, "Reached maximum iteration count of {MaximumIterationsPerRequest}. Stopping function invocation loop.")]
private partial void LogMaximumIterationsReached(int maximumIterationsPerRequest);

/// <summary>Provides information about the invocation of a function call.</summary>
public sealed class FunctionInvocationResult
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,179 @@ public async Task ContinuesWithSuccessfulCallsUntilMaximumIterations()
Assert.Equal(maxIterations, actualCallCount);
}

[Fact]
public async Task LastIteration_RemovesFunctionDeclarationTools_NonStreaming()
{
List<ChatOptions?> capturedOptions = [];
var maxIterations = 2;

using var innerClient = new TestChatClient
{
GetResponseAsyncCallback = (contents, options, cancellationToken) =>
{
capturedOptions.Add(options?.Clone());

var message = new ChatMessage(ChatRole.Assistant, [new FunctionCallContent($"callId{capturedOptions.Count}", "Func1")]);
return Task.FromResult(new ChatResponse(message));
}
};

using var client = new FunctionInvokingChatClient(innerClient)
{
MaximumIterationsPerRequest = maxIterations
};

var options = new ChatOptions
{
Tools = [AIFunctionFactory.Create(() => "Result", "Func1")],
ToolMode = ChatToolMode.Auto
};

await client.GetResponseAsync("hello", options);

Assert.Equal(maxIterations + 1, capturedOptions.Count);

for (int i = 0; i < maxIterations; i++)
{
Assert.NotNull(capturedOptions[i]?.Tools);
Assert.Single(capturedOptions[i]!.Tools!);
}

var lastOptions = capturedOptions[maxIterations];
Assert.NotNull(lastOptions);
Assert.Null(lastOptions!.Tools);
Assert.Null(lastOptions.ToolMode);
}

[Fact]
public async Task LastIteration_RemovesFunctionDeclarationTools_Streaming()
{
List<ChatOptions?> capturedOptions = [];
var maxIterations = 2;

using var innerClient = new TestChatClient
{
GetStreamingResponseAsyncCallback = (contents, options, cancellationToken) =>
{
capturedOptions.Add(options?.Clone());

var message = new ChatMessage(ChatRole.Assistant, [new FunctionCallContent($"callId{capturedOptions.Count}", "Func1")]);
return YieldAsync(new ChatResponse(message).ToChatResponseUpdates());
}
};

using var client = new FunctionInvokingChatClient(innerClient)
{
MaximumIterationsPerRequest = maxIterations
};

var options = new ChatOptions
{
Tools = [AIFunctionFactory.Create(() => "Result", "Func1")],
ToolMode = ChatToolMode.Auto
};

await client.GetStreamingResponseAsync("hello", options).ToChatResponseAsync();

Assert.Equal(maxIterations + 1, capturedOptions.Count);

for (int i = 0; i < maxIterations; i++)
{
Assert.NotNull(capturedOptions[i]?.Tools);
Assert.Single(capturedOptions[i]!.Tools!);
}

var lastOptions = capturedOptions[maxIterations];
Assert.NotNull(lastOptions);
Assert.Null(lastOptions!.Tools);
Assert.Null(lastOptions.ToolMode);
}

[Fact]
public async Task LastIteration_PreservesNonFunctionDeclarationTools()
{
var hostedTool = new HostedWebSearchTool();
List<ChatOptions?> capturedOptions = [];
var maxIterations = 1;

using var innerClient = new TestChatClient
{
GetResponseAsyncCallback = (contents, options, cancellationToken) =>
{
capturedOptions.Add(options?.Clone());

if (capturedOptions.Count == 1)
{
var message = new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]);
return Task.FromResult(new ChatResponse(message));
}
else
{
var message = new ChatMessage(ChatRole.Assistant, "Done");
return Task.FromResult(new ChatResponse(message));
}
}
};

using var client = new FunctionInvokingChatClient(innerClient)
{
MaximumIterationsPerRequest = maxIterations
};

var options = new ChatOptions
{
Tools = [AIFunctionFactory.Create(() => "Result", "Func1"), hostedTool],
ToolMode = ChatToolMode.Auto
};

await client.GetResponseAsync("hello", options);

Assert.Equal(2, capturedOptions.Count);
Assert.NotNull(capturedOptions[0]?.Tools);
Assert.Equal(2, capturedOptions[0]!.Tools!.Count);

Assert.NotNull(capturedOptions[1]?.Tools);
Assert.Single(capturedOptions[1]!.Tools!);
Assert.IsType<HostedWebSearchTool>(capturedOptions[1]!.Tools![0]);
Assert.NotNull(capturedOptions[1]?.ToolMode);
}

[Fact]
public async Task LastIteration_DoesNotModifyOriginalOptions()
{
List<ChatOptions?> capturedOptions = [];
var maxIterations = 1;

using var innerClient = new TestChatClient
{
GetResponseAsyncCallback = (contents, options, cancellationToken) =>
{
capturedOptions.Add(options);
var message = new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]);
return Task.FromResult(new ChatResponse(message));
}
};

using var client = new FunctionInvokingChatClient(innerClient)
{
MaximumIterationsPerRequest = maxIterations
};

var originalTool = AIFunctionFactory.Create(() => "Result", "Func1");
var originalOptions = new ChatOptions
{
Tools = [originalTool],
ToolMode = ChatToolMode.Auto
};

await client.GetResponseAsync("hello", originalOptions);

Assert.NotNull(originalOptions.Tools);
Assert.Single(originalOptions.Tools);
Assert.Same(originalTool, originalOptions.Tools[0]);
Assert.NotNull(originalOptions.ToolMode);
}

[Theory]
[InlineData(false)]
[InlineData(true)]
Expand Down
Loading