Skip to content

Commit 0a4a4d1

Browse files
committed
Allow finding tool calls by the result type only
Allows ignoring the tool name for cases where the result type is distinct enough for the lookup.
1 parent 8d068f1 commit 0a4a4d1

File tree

2 files changed

+84
-10
lines changed

2 files changed

+84
-10
lines changed

src/AI.Tests/ToolsTests.cs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,39 @@ public async Task RunToolResult()
5151
Assert.Null(result.Outcome.Exception);
5252
}
5353

54+
[SecretsFact("OPENAI_API_KEY")]
55+
public async Task FindToolResultByTypeOnly()
56+
{
57+
var chat = new Chat()
58+
{
59+
{ "system", "You make up a tool run by making up a name, description and content based on whatever the user says." },
60+
{ "user", "I want to create an order for a dozen eggs" },
61+
};
62+
63+
var client = new OpenAIChatClient(Configuration["OPENAI_API_KEY"]!, "gpt-4.1",
64+
global::OpenAI.OpenAIClientOptions.WriteTo(output))
65+
.AsBuilder()
66+
.UseFunctionInvocation()
67+
.Build();
68+
69+
var tool = ToolFactory.Create(RunTool);
70+
var options = new ChatOptions
71+
{
72+
ToolMode = ChatToolMode.RequireSpecific(tool.Name),
73+
Tools = [tool]
74+
};
75+
76+
var response = await client.GetResponseAsync(chat, options);
77+
var result = response.FindCalls<ToolResult>().FirstOrDefault();
78+
79+
Assert.NotNull(result);
80+
Assert.NotNull(result.Call);
81+
Assert.Equal(tool.Name, result.Call.Name);
82+
Assert.NotNull(result.Outcome);
83+
Assert.Null(result.Outcome.Exception);
84+
}
85+
86+
5487
[SecretsFact("OPENAI_API_KEY")]
5588
public async Task RunToolTerminateResult()
5689
{

src/AI/ToolExtensions.cs

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,29 +39,32 @@ public static IEnumerable<ToolCall> FindCalls(this IEnumerable<ChatMessage> mess
3939
=> FindCalls(messages, tool.Name);
4040

4141
/// <summary>
42-
/// Looks for calls to a tool and their outcome.
42+
/// Looks for calls to a tool and their outcome, optionally filtering by tool name.
4343
/// </summary>
44-
public static IEnumerable<ToolCall> FindCalls(this IEnumerable<ChatMessage> messages, string tool)
44+
public static IEnumerable<ToolCall> FindCalls(this IEnumerable<ChatMessage> messages, string? tool = default)
4545
{
46-
var calls = messages
46+
var filtered = messages
4747
.Where(x => x.Role == ChatRole.Assistant)
4848
.SelectMany(x => x.Contents)
49-
.OfType<FunctionCallContent>()
50-
.Where(x => x.Name == tool)
51-
.ToDictionary(x => x.CallId);
49+
.OfType<FunctionCallContent>();
50+
51+
if (!string.IsNullOrEmpty(tool))
52+
filtered = filtered.Where(x => x.Name == tool);
53+
54+
var calls = filtered.ToDictionary(x => x.CallId);
5255

5356
var results = messages
5457
.Where(x => x.Role == ChatRole.Tool)
5558
.SelectMany(x => x.Contents)
5659
.OfType<FunctionResultContent>()
57-
.Where(x => calls.TryGetValue(x.CallId, out var call) && call.Name == tool)
60+
.Where(x => calls.ContainsKey(x.CallId))
5861
.Select(x => new ToolCall(calls[x.CallId], x));
5962

6063
return results;
6164
}
6265

6366
/// <summary>
64-
/// Looks for a user prompt in the chat response messages.
67+
/// Looks for calls to a tool where the result is of a given type <typeparamref name="TResult"/>
6568
/// </summary>
6669
/// <remarks>
6770
/// In order for this to work, the <see cref="AIFunctionFactory"/> must have been invoked using
@@ -73,7 +76,19 @@ public static IEnumerable<ToolCall<TResult>> FindCalls<TResult>(this ChatRespons
7376
=> FindCalls<TResult>(response.Messages, tool.Name);
7477

7578
/// <summary>
76-
/// Looks for a user prompt in the chat response messages.
79+
/// Looks for calls where the result is of a given type <typeparamref name="TResult"/> regadless of the tool.
80+
/// </summary>
81+
/// <remarks>
82+
/// In order for this to work, the <see cref="AIFunctionFactory"/> must have been invoked using
83+
/// the <see cref="ToolJsonOptions.Default"/> or with a <see cref="JsonSerializerOptions"/> configured
84+
/// with <see cref="TypeInjectingResolverExtensions.WithTypeInjection(JsonSerializerOptions)"/> so
85+
/// that the tool result type can be properly inspected.
86+
/// </remarks>
87+
public static IEnumerable<ToolCall<TResult>> FindCalls<TResult>(this ChatResponse response)
88+
=> FindCalls<TResult>(response.Messages);
89+
90+
/// <summary>
91+
/// Looks for calls to a tool where the result is of a given type <typeparamref name="TResult"/>
7792
/// </summary>
7893
/// <remarks>
7994
/// In order for this to work, the <see cref="AIFunctionFactory"/> must have been invoked using
@@ -85,7 +100,7 @@ public static IEnumerable<ToolCall<TResult>> FindCalls<TResult>(this IEnumerable
85100
=> FindCalls<TResult>(messages, tool.Name);
86101

87102
/// <summary>
88-
/// Looks for a user prompt in the chat response messages.
103+
/// Looks for calls to a tool where the result is of a given type <typeparamref name="TResult"/>
89104
/// </summary>
90105
/// <remarks>
91106
/// In order for this to work, the <see cref="AIFunctionFactory"/> must have been invoked using
@@ -108,4 +123,30 @@ public static IEnumerable<ToolCall<TResult>> FindCalls<TResult>(this IEnumerable
108123

109124
return calls;
110125
}
126+
127+
/// <summary>
128+
/// Looks for calls to a tool where the result is of a given type <typeparamref name="TResult"/>
129+
/// regardless of the tool name.
130+
/// </summary>
131+
/// <remarks>
132+
/// In order for this to work, the <see cref="AIFunctionFactory"/> must have been invoked using
133+
/// the <see cref="ToolJsonOptions.Default"/> or with a <see cref="JsonSerializerOptions"/> configured
134+
/// with <see cref="TypeInjectingResolverExtensions.WithTypeInjection(JsonSerializerOptions)"/> so
135+
/// that the tool result type can be properly inspected.
136+
/// </remarks>
137+
public static IEnumerable<ToolCall<TResult>> FindCalls<TResult>(this IEnumerable<ChatMessage> messages)
138+
{
139+
var calls = FindCalls(messages)
140+
.Where(x => x.Outcome.Result is JsonElement element &&
141+
element.ValueKind == JsonValueKind.Object &&
142+
element.TryGetProperty("$type", out var type) &&
143+
type.GetString() == typeof(TResult).FullName)
144+
.Select(x => new ToolCall<TResult>(
145+
Call: x.Call,
146+
Outcome: x.Outcome,
147+
Result: JsonSerializer.Deserialize<TResult>((JsonElement)x.Outcome.Result!, ToolJsonOptions.Default) ??
148+
throw new InvalidOperationException($"Failed to deserialize result for tool '{x.Call.Name}' to {typeof(TResult).FullName}.")));
149+
150+
return calls;
151+
}
111152
}

0 commit comments

Comments
 (0)