Skip to content

Commit c984bc8

Browse files
committed
async tools
1 parent 5ce1d02 commit c984bc8

File tree

6 files changed

+246
-34
lines changed

6 files changed

+246
-34
lines changed

src/Utility/ChatTools.cs

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ namespace OpenAI.Chat;
1515
/// <summary>
1616
/// Provides functionality to manage and execute OpenAI function tools for chat completions.
1717
/// </summary>
18+
//[Experimental("OPENAIMCP001")]
1819
public class ChatTools
1920
{
2021
private readonly Dictionary<string, MethodInfo> _methods = [];
@@ -223,26 +224,17 @@ await ToolsUtility.GetEmbeddingAsync(_client, prompt).ConfigureAwait(false) :
223224
}
224225
}
225226

226-
internal string CallLocal(ChatToolCall call)
227+
internal async Task<string> CallFunctionToolAsync(ChatToolCall call)
227228
{
228229
var arguments = new List<object>();
229230
if (call.FunctionArguments != null)
230231
{
231232
if (!_methods.TryGetValue(call.FunctionName, out MethodInfo method))
232-
return $"I don't have a tool called {call.FunctionName}";
233+
throw new InvalidOperationException($"Tool not found: {call.FunctionName}");
233234

234235
ToolsUtility.ParseFunctionCallArgs(method, call.FunctionArguments, out arguments);
235236
}
236-
return CallLocal(call.FunctionName, [.. arguments]);
237-
}
238-
239-
private string CallLocal(string name, object[] arguments)
240-
{
241-
if (!_methods.TryGetValue(name, out MethodInfo method))
242-
return $"I don't have a tool called {name}";
243-
244-
object result = method.Invoke(null, arguments);
245-
return result?.ToString() ?? string.Empty;
237+
return await ToolsUtility.CallFunctionToolAsync(_methods, call.FunctionName, [.. arguments]);
246238
}
247239

248240
internal async Task<string> CallMcpAsync(ChatToolCall call)
@@ -285,7 +277,7 @@ public async Task<IEnumerable<ToolChatMessage>> CallAsync(IEnumerable<ChatToolCa
285277
}
286278
}
287279

288-
var result = isMcpTool ? await CallMcpAsync(toolCall).ConfigureAwait(false) : CallLocal(toolCall);
280+
var result = isMcpTool ? await CallMcpAsync(toolCall).ConfigureAwait(false) : await CallFunctionToolAsync(toolCall).ConfigureAwait(false);
289281
messages.Add(new ToolChatMessage(toolCall.Id, result));
290282
}
291283

src/Utility/MCP/McpClient.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ namespace OpenAI.Agents;
77
/// <summary>
88
/// Client for interacting with a Model Context Protocol (MCP) server.
99
/// </summary>
10+
//[Experimental("OPENAIMCP001")]
1011
public class McpClient
1112
{
1213
private readonly McpSession _session;

src/Utility/ResponseTools.cs

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ namespace OpenAI.Responses;
1515
/// <summary>
1616
/// Provides functionality to manage and execute OpenAI function tools for responses.
1717
/// </summary>
18-
public class ResponseTools
18+
//[Experimental("OPENAIMCP001")
19+
public class CallLocalAsync
1920
{
2021
private readonly Dictionary<string, MethodInfo> _methods = [];
2122
private readonly Dictionary<string, Func<string, BinaryData, Task<BinaryData>>> _mcpMethods = [];
@@ -29,7 +30,7 @@ public class ResponseTools
2930
/// Initializes a new instance of the ResponseTools class with an optional embedding client.
3031
/// </summary>
3132
/// <param name="client">The embedding client used for tool vectorization, or null to disable vectorization.</param>
32-
public ResponseTools(EmbeddingClient client = null)
33+
public CallLocalAsync(EmbeddingClient client = null)
3334
{
3435
_client = client;
3536
}
@@ -38,7 +39,7 @@ public ResponseTools(EmbeddingClient client = null)
3839
/// Initializes a new instance of the ResponseTools class with the specified tool types.
3940
/// </summary>
4041
/// <param name="tools">Additional tool types to add.</param>
41-
public ResponseTools(params Type[] tools) : this((EmbeddingClient)null)
42+
public CallLocalAsync(params Type[] tools) : this((EmbeddingClient)null)
4243
{
4344
foreach (var t in tools)
4445
AddFunctionTool(t);
@@ -225,7 +226,7 @@ await ToolsUtility.GetEmbeddingAsync(_client, prompt).ConfigureAwait(false) :
225226
}
226227
}
227228

228-
internal string CallLocal(FunctionCallResponseItem call)
229+
internal async Task<string> CallFunctionToolAsync(FunctionCallResponseItem call)
229230
{
230231
List<object> arguments = new();
231232
if (call.FunctionArguments != null)
@@ -235,16 +236,8 @@ internal string CallLocal(FunctionCallResponseItem call)
235236

236237
ToolsUtility.ParseFunctionCallArgs(method, call.FunctionArguments, out arguments);
237238
}
238-
return CallLocal(call.FunctionName, [.. arguments]);
239-
}
240-
241-
private string CallLocal(string name, object[] arguments)
242-
{
243-
if (!_methods.TryGetValue(name, out MethodInfo method))
244-
return $"I don't have a tool called {name}";
245239

246-
object result = method.Invoke(null, arguments);
247-
return result?.ToString() ?? string.Empty;
240+
return await ToolsUtility.CallFunctionToolAsync(_methods, call.FunctionName, [.. arguments]);
248241
}
249242

250243
internal async Task<string> CallMcpAsync(FunctionCallResponseItem call)
@@ -282,7 +275,7 @@ public async Task<FunctionCallOutputResponseItem> CallAsync(FunctionCallResponse
282275
}
283276
}
284277

285-
var result = isMcpTool ? await CallMcpAsync(toolCall).ConfigureAwait(false) : CallLocal(toolCall);
278+
var result = isMcpTool ? await CallMcpAsync(toolCall).ConfigureAwait(false) : await CallFunctionToolAsync(toolCall);
286279
return new FunctionCallOutputResponseItem(toolCall.CallId, result);
287280
}
288281
}

src/Utility/ToolsUtility.cs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,4 +210,44 @@ internal static BinaryData SerializeTool(string name, string description, Binary
210210
stream.Position = 0;
211211
return BinaryData.FromStream(stream);
212212
}
213+
214+
internal static async Task<string> CallFunctionToolAsync(Dictionary<string, MethodInfo> methods, string name, object[] arguments)
215+
{
216+
if (!methods.TryGetValue(name, out MethodInfo method))
217+
throw new InvalidOperationException($"Tool not found: {name}");
218+
219+
object result;
220+
if (IsGenericTask(method.ReturnType, out Type taskResultType))
221+
{
222+
// Method is async, invoke and await
223+
var task = (Task)method.Invoke(null, arguments);
224+
await task.ConfigureAwait(false);
225+
// Get the Result property from the Task
226+
result = taskResultType.GetProperty("Result").GetValue(task);
227+
}
228+
else
229+
{
230+
// Method is synchronous
231+
result = method.Invoke(null, arguments);
232+
}
233+
234+
return result?.ToString() ?? string.Empty;
235+
}
236+
237+
private static bool IsGenericTask(Type type, out Type taskResultType)
238+
{
239+
while (type != null && type != typeof(object))
240+
{
241+
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Task<>))
242+
{
243+
taskResultType = type;//type.GetGenericArguments()[0];
244+
return true;
245+
}
246+
247+
type = type.BaseType!;
248+
}
249+
250+
taskResultType = null;
251+
return false;
252+
}
213253
}

tests/Utility/ChatToolsTests.cs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,45 @@ private class TestTools
2929
public static string ConcatWithBool(string text, bool flag) => $"{text}:{flag}";
3030
}
3131

32+
private class TestToolsAsync
33+
{
34+
public static async Task<string> EchoAsync(string message)
35+
{
36+
await Task.Delay(1); // Simulate async work
37+
return message;
38+
}
39+
40+
public static async Task<int> AddAsync(int a, int b)
41+
{
42+
await Task.Delay(1); // Simulate async work
43+
return a + b;
44+
}
45+
46+
public static async Task<double> MultiplyAsync(double x, double y)
47+
{
48+
await Task.Delay(1); // Simulate async work
49+
return x * y;
50+
}
51+
52+
public static async Task<bool> IsGreaterThanAsync(long value1, long value2)
53+
{
54+
await Task.Delay(1); // Simulate async work
55+
return value1 > value2;
56+
}
57+
58+
public static async Task<float> DivideAsync(float numerator, float denominator)
59+
{
60+
await Task.Delay(1); // Simulate async work
61+
return numerator / denominator;
62+
}
63+
64+
public static async Task<string> ConcatWithBoolAsync(string text, bool flag)
65+
{
66+
await Task.Delay(1); // Simulate async work
67+
return $"{text}:{flag}";
68+
}
69+
}
70+
3271
private Mock<EmbeddingClient> mockEmbeddingClient;
3372

3473
[SetUp]
@@ -52,6 +91,21 @@ public void CanAddLocalTools()
5291
Assert.That(tools.Tools.Any(t => t.FunctionName == "ConcatWithBool"));
5392
}
5493

94+
[Test]
95+
public void CanAddAsyncLocalTools()
96+
{
97+
var tools = new ChatTools();
98+
tools.AddFunctionTools(typeof(TestToolsAsync));
99+
100+
Assert.That(tools.Tools, Has.Count.EqualTo(6));
101+
Assert.That(tools.Tools.Any(t => t.FunctionName == "EchoAsync"));
102+
Assert.That(tools.Tools.Any(t => t.FunctionName == "AddAsync"));
103+
Assert.That(tools.Tools.Any(t => t.FunctionName == "MultiplyAsync"));
104+
Assert.That(tools.Tools.Any(t => t.FunctionName == "IsGreaterThanAsync"));
105+
Assert.That(tools.Tools.Any(t => t.FunctionName == "DivideAsync"));
106+
Assert.That(tools.Tools.Any(t => t.FunctionName == "ConcatWithBoolAsync"));
107+
}
108+
55109
[Test]
56110
public async Task CanCallToolsAsync()
57111
{
@@ -86,6 +140,40 @@ public async Task CanCallToolsAsync()
86140
Assert.That(resultsList[5].Content[0].Text, Is.EqualTo("Test:True"));
87141
}
88142

143+
[Test]
144+
public async Task CanCallAsyncToolsAsync()
145+
{
146+
var tools = new ChatTools();
147+
tools.AddFunctionTools(typeof(TestToolsAsync));
148+
149+
var toolCalls = new[]
150+
{
151+
ChatToolCall.CreateFunctionToolCall("call1", "EchoAsync", BinaryData.FromString(@"{""message"": ""Hello""}")),
152+
ChatToolCall.CreateFunctionToolCall("call2", "AddAsync", BinaryData.FromString(@"{""a"": 2, ""b"": 3}")),
153+
ChatToolCall.CreateFunctionToolCall("call3", "MultiplyAsync", BinaryData.FromString(@"{""x"": 2.5, ""y"": 3.0}")),
154+
ChatToolCall.CreateFunctionToolCall("call4", "IsGreaterThanAsync", BinaryData.FromString(@"{""value1"": 100, ""value2"": 50}")),
155+
ChatToolCall.CreateFunctionToolCall("call5", "DivideAsync", BinaryData.FromString(@"{""numerator"": 10.0, ""denominator"": 2.0}")),
156+
ChatToolCall.CreateFunctionToolCall("call6", "ConcatWithBoolAsync", BinaryData.FromString(@"{""text"": ""Test"", ""flag"": true}"))
157+
};
158+
159+
var results = await tools.CallAsync(toolCalls);
160+
var resultsList = results.ToList();
161+
162+
Assert.That(resultsList, Has.Count.EqualTo(6));
163+
Assert.That(resultsList[0].ToolCallId, Is.EqualTo("call1"));
164+
Assert.That(resultsList[0].Content[0].Text, Is.EqualTo("Hello"));
165+
Assert.That(resultsList[1].ToolCallId, Is.EqualTo("call2"));
166+
Assert.That(resultsList[1].Content[0].Text, Is.EqualTo("5"));
167+
Assert.That(resultsList[2].ToolCallId, Is.EqualTo("call3"));
168+
Assert.That(resultsList[2].Content[0].Text, Is.EqualTo("7.5"));
169+
Assert.That(resultsList[3].ToolCallId, Is.EqualTo("call4"));
170+
Assert.That(resultsList[3].Content[0].Text, Is.EqualTo("True"));
171+
Assert.That(resultsList[4].ToolCallId, Is.EqualTo("call5"));
172+
Assert.That(resultsList[4].Content[0].Text, Is.EqualTo("5"));
173+
Assert.That(resultsList[5].ToolCallId, Is.EqualTo("call6"));
174+
Assert.That(resultsList[5].Content[0].Text, Is.EqualTo("Test:True"));
175+
}
176+
89177
[Test]
90178
public void CreatesCompletionOptionsWithTools()
91179
{

0 commit comments

Comments
 (0)