Skip to content

Commit 9308b2b

Browse files
authored
Add ChatTools and ResponseTools helper classes (#422)
This PR adds functionality for managing OpenAI function tools in chat completions and responses, implementing two utility classes: `ChatTools` and `ResponseTools`. ### Key Features - Unified Tool Management - Both classes handle local and remote MCP (Model Context Protocol) tools - Support for both static method tools and remote MCP server tools - Vectorized tool lookup using embeddings for smart tool filtering - Tool Registration - Local tool registration via reflection from static methods - Remote tool registration from MCP servers - Tool Filtering - Vector-based tool filtering when an `EmbeddingClient` is provided - Configurable similarity thresholds - Customizable maximum tool limits
1 parent 06191fa commit 9308b2b

File tree

12 files changed

+2323
-199
lines changed

12 files changed

+2323
-199
lines changed

src/OpenAI.csproj

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@
6060
<PublicKey>0024000004800000940000000602000000240000525341310004000001000100097ad52abbeaa2e1a1982747cc0106534f65cfea6707eaed696a3a63daea80de2512746801a7e47f88e7781e71af960d89ba2e25561f70b0e2dbc93319e0af1961a719ccf5a4d28709b2b57a5d29b7c09dc8d269a490ebe2651c4b6e6738c27c5fb2c02469fe9757f0a3479ac310d6588a50a28d7dd431b907fd325e18b9e8ed</PublicKey>
6161
</InternalsVisibleTo>
6262
<InternalsVisibleTo Include="Azure.AI.OpenAI" Condition="'$(Configuration)' == 'Unsigned'" />
63+
<InternalsVisibleTo Include="OpenAI.Tests" Condition="'$(Configuration)' != 'Unsigned'">
64+
<PublicKey>0024000004800000940000000602000000240000525341310004000001000100b197326f2e5bfe2e2a49eb2a05bee871c55cc894325b3775159732ad816c4f304916e7f154295486f8ccabefa3c19b059d51cd19987cc2d31a3195d6203ad0948662f51cc61cc3eb535fc852dfe5159318c734b163f7d1387f1112e1ffe10f83aae7b809c4e36cf2025da5d1aed6b67e1556883d8778eeb63131c029555166de</PublicKey>
65+
</InternalsVisibleTo>
66+
<InternalsVisibleTo Include="OpenAI.Tests" Condition="'$(Configuration)' == 'Unsigned'" />
6367
</ItemGroup>
6468

6569
<PropertyGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
@@ -83,5 +87,6 @@
8387
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="8.0.0" PrivateAssets="All" />
8488
<PackageReference Include="System.ClientModel" Version="1.2.1" />
8589
<PackageReference Include="System.Diagnostics.DiagnosticSource" Version="6.0.1" />
90+
<PackageReference Update="Microsoft.Bcl.Numerics" Version="8.0.0" />
8691
</ItemGroup>
8792
</Project>

src/Utility/ChatTools.cs

Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
using System;
2+
using System.ClientModel.Primitives;
3+
using System.Collections.Generic;
4+
using System.IO;
5+
using System.Linq;
6+
using System.Reflection;
7+
using System.Text.Json;
8+
using System.Threading.Tasks;
9+
using OpenAI.Agents;
10+
using OpenAI.Chat;
11+
using OpenAI.Embeddings;
12+
13+
namespace OpenAI.Chat;
14+
15+
/// <summary>
16+
/// Provides functionality to manage and execute OpenAI function tools for chat completions.
17+
/// </summary>
18+
//[Experimental("OPENAIMCP001")]
19+
public class ChatTools
20+
{
21+
private readonly Dictionary<string, MethodInfo> _methods = [];
22+
private readonly Dictionary<string, Func<string, BinaryData, Task<BinaryData>>> _mcpMethods = [];
23+
private readonly List<ChatTool> _tools = [];
24+
private readonly EmbeddingClient _client;
25+
private readonly List<VectorDatabaseEntry> _entries = [];
26+
private readonly List<McpClient> _mcpClients = [];
27+
private readonly Dictionary<string, McpClient> _mcpClientsByEndpoint = [];
28+
29+
/// <summary>
30+
/// Initializes a new instance of the ChatTools class with an optional embedding client.
31+
/// </summary>
32+
/// <param name="client">The embedding client used for tool vectorization, or null to disable vectorization.</param>
33+
public ChatTools(EmbeddingClient client = null)
34+
{
35+
_client = client;
36+
}
37+
38+
/// <summary>
39+
/// Initializes a new instance of the ChatTools class with the specified tool types.
40+
/// </summary>
41+
/// <param name="tools">Additional tool types to add.</param>
42+
public ChatTools(params Type[] tools) : this((EmbeddingClient)null)
43+
{
44+
foreach (var t in tools)
45+
AddFunctionTool(t);
46+
}
47+
48+
/// <summary>
49+
/// Gets the list of defined tools.
50+
/// </summary>
51+
public IList<ChatTool> Tools => _tools;
52+
53+
/// <summary>
54+
/// Gets whether tools can be filtered using embeddings provided by the provided <see cref="EmbeddingClient"/> .
55+
/// </summary>
56+
public bool CanFilterTools => _client != null;
57+
58+
/// <summary>
59+
/// Adds local tool implementations from the provided types.
60+
/// </summary>
61+
/// <param name="tools">Types containing static methods to be used as tools.</param>
62+
public void AddFunctionTools(params Type[] tools)
63+
{
64+
foreach (Type functionHolder in tools)
65+
AddFunctionTool(functionHolder);
66+
}
67+
68+
/// <summary>
69+
/// Adds all public static methods from the specified type as tools.
70+
/// </summary>
71+
/// <param name="tool">The type containing tool methods.</param>
72+
internal void AddFunctionTool(Type tool)
73+
{
74+
#pragma warning disable IL2070
75+
foreach (MethodInfo function in tool.GetMethods(BindingFlags.Public | BindingFlags.Static))
76+
{
77+
AddFunctionTool(function);
78+
}
79+
#pragma warning restore IL2070
80+
}
81+
82+
internal void AddFunctionTool(MethodInfo function)
83+
{
84+
string name = function.Name;
85+
var tool = ChatTool.CreateFunctionTool(name, ToolsUtility.GetMethodDescription(function), ToolsUtility.BuildParametersJson(function.GetParameters()));
86+
_tools.Add(tool);
87+
_methods[name] = function;
88+
}
89+
90+
/// <summary>
91+
/// Adds a remote MCP server as a tool provider.
92+
/// </summary>
93+
/// <param name="client">The MCP client instance.</param>
94+
/// <returns>A task representing the asynchronous operation.</returns>
95+
internal async Task AddMcpToolsAsync(McpClient client)
96+
{
97+
if (client == null) throw new ArgumentNullException(nameof(client));
98+
_mcpClientsByEndpoint[client.Endpoint.AbsoluteUri] = client;
99+
await client.StartAsync().ConfigureAwait(false);
100+
BinaryData tools = await client.ListToolsAsync().ConfigureAwait(false);
101+
await AddMcpToolsAsync(tools, client).ConfigureAwait(false);
102+
_mcpClients.Add(client);
103+
}
104+
105+
/// <summary>
106+
/// Adds a remote MCP server as a tool provider.
107+
/// </summary>
108+
/// <param name="mcpEndpoint">The URI endpoint of the MCP server.</param>
109+
/// <returns>A task representing the asynchronous operation.</returns>
110+
public async Task AddMcpToolsAsync(Uri mcpEndpoint)
111+
{
112+
var client = new McpClient(mcpEndpoint);
113+
await AddMcpToolsAsync(client).ConfigureAwait(false);
114+
}
115+
116+
private async Task AddMcpToolsAsync(BinaryData toolDefinitions, McpClient client)
117+
{
118+
List<ChatTool> toolsToVectorize = new();
119+
var parsedTools = ToolsUtility.ParseMcpToolDefinitions(toolDefinitions, client);
120+
121+
foreach (var (name, description, inputSchema) in parsedTools)
122+
{
123+
var chatTool = ChatTool.CreateFunctionTool(name, description, BinaryData.FromString(inputSchema));
124+
_tools.Add(chatTool);
125+
toolsToVectorize.Add(chatTool);
126+
_mcpMethods[name] = client.CallToolAsync;
127+
}
128+
129+
if (_client != null)
130+
{
131+
var embeddings = await _client.GenerateEmbeddingsAsync(toolsToVectorize.Select(t => t.FunctionDescription).ToList()).ConfigureAwait(false);
132+
foreach (var embedding in embeddings.Value)
133+
{
134+
var vector = embedding.ToFloats();
135+
var item = toolsToVectorize[embedding.Index];
136+
var toolDefinition = SerializeTool(item);
137+
_entries.Add(new VectorDatabaseEntry(vector, toolDefinition));
138+
}
139+
}
140+
}
141+
142+
private BinaryData SerializeTool(ChatTool tool)
143+
{
144+
return ToolsUtility.SerializeTool(tool.FunctionName, tool.FunctionDescription, tool.FunctionParameters);
145+
}
146+
147+
private ChatTool ParseToolDefinition(BinaryData data)
148+
{
149+
using var document = JsonDocument.Parse(data);
150+
var root = document.RootElement;
151+
152+
return ChatTool.CreateFunctionTool(
153+
root.GetProperty("name").GetString()!,
154+
root.GetProperty("description").GetString()!,
155+
BinaryData.FromString(root.GetProperty("inputSchema").GetRawText()));
156+
}
157+
158+
/// <summary>
159+
/// Converts the tools collection to chat completion options.
160+
/// </summary>
161+
/// <returns>A new ChatCompletionOptions containing all defined tools.</returns>
162+
public ChatCompletionOptions ToChatCompletionOptions()
163+
{
164+
var options = new ChatCompletionOptions();
165+
foreach (var tool in _tools)
166+
options.Tools.Add(tool);
167+
return options;
168+
}
169+
170+
/// <summary>
171+
/// Converts the tools collection to <see cref="ChatCompletionOptions"/>, filtered by relevance to the given prompt.
172+
/// </summary>
173+
/// <param name="prompt">The prompt to find relevant tools for.</param>
174+
/// <param name="maxTools">The maximum number of tools to return. Default is 3.</param>
175+
/// <param name="minVectorDistance">The similarity threshold for including tools. Default is 0.29.</param>
176+
/// <returns>A new <see cref="ChatCompletionOptions"/> containing the most relevant tools.</returns>
177+
public ChatCompletionOptions CreateCompletionOptions(string prompt, int maxTools = 5, float minVectorDistance = 0.29f)
178+
{
179+
if (!CanFilterTools)
180+
return ToChatCompletionOptions();
181+
182+
var completionOptions = new ChatCompletionOptions();
183+
foreach (var tool in FindRelatedTools(false, prompt, maxTools, minVectorDistance).GetAwaiter().GetResult())
184+
completionOptions.Tools.Add(tool);
185+
return completionOptions;
186+
}
187+
188+
/// <summary>
189+
/// Converts the tools collection to <see cref="ChatCompletionOptions"/>, filtered by relevance to the given prompt.
190+
/// </summary>
191+
/// <param name="prompt">The prompt to find relevant tools for.</param>
192+
/// <param name="maxTools">The maximum number of tools to return. Default is 3.</param>
193+
/// <param name="minVectorDistance">The similarity threshold for including tools. Default is 0.29.</param>
194+
/// <returns>A new <see cref="ChatCompletionOptions"/> containing the most relevant tools.</returns>
195+
public async Task<ChatCompletionOptions> ToChatCompletionOptions(string prompt, int maxTools = 5, float minVectorDistance = 0.29f)
196+
{
197+
if (!CanFilterTools)
198+
return ToChatCompletionOptions();
199+
200+
var completionOptions = new ChatCompletionOptions();
201+
foreach (var tool in await FindRelatedTools(true, prompt, maxTools, minVectorDistance).ConfigureAwait(false))
202+
completionOptions.Tools.Add(tool);
203+
return completionOptions;
204+
}
205+
206+
private async Task<IEnumerable<ChatTool>> FindRelatedTools(bool async, string prompt, int maxTools, float minVectorDistance)
207+
{
208+
if (!CanFilterTools)
209+
return _tools;
210+
211+
return (await FindVectorMatches(async, prompt, maxTools, minVectorDistance).ConfigureAwait(false))
212+
.Select(e => ParseToolDefinition(e.Data));
213+
}
214+
215+
private async Task<IEnumerable<VectorDatabaseEntry>> FindVectorMatches(bool async, string prompt, int maxTools, float minVectorDistance)
216+
{
217+
var vector = async ?
218+
await ToolsUtility.GetEmbeddingAsync(_client, prompt).ConfigureAwait(false) :
219+
ToolsUtility.GetEmbedding(_client, prompt);
220+
221+
lock (_entries)
222+
{
223+
return ToolsUtility.GetClosestEntries(_entries, maxTools, minVectorDistance, vector);
224+
}
225+
}
226+
227+
internal async Task<string> CallFunctionToolAsync(ChatToolCall call)
228+
{
229+
var arguments = new List<object>();
230+
if (call.FunctionArguments != null)
231+
{
232+
if (!_methods.TryGetValue(call.FunctionName, out MethodInfo method))
233+
throw new InvalidOperationException($"Tool not found: {call.FunctionName}");
234+
235+
ToolsUtility.ParseFunctionCallArgs(method, call.FunctionArguments, out arguments);
236+
}
237+
return await ToolsUtility.CallFunctionToolAsync(_methods, call.FunctionName, [.. arguments]);
238+
}
239+
240+
internal async Task<string> CallMcpAsync(ChatToolCall call)
241+
{
242+
if (!_mcpMethods.TryGetValue(call.FunctionName, out var method))
243+
throw new NotImplementedException($"MCP tool {call.FunctionName} not found.");
244+
245+
#if !NETSTANDARD2_0
246+
var actualFunctionName = call.FunctionName.Split(ToolsUtility.McpToolSeparator, 2)[1];
247+
#else
248+
var index = call.FunctionName.IndexOf(ToolsUtility.McpToolSeparator);
249+
var actualFunctionName = call.FunctionName.Substring(index + ToolsUtility.McpToolSeparator.Length);
250+
#endif
251+
var result = await method(actualFunctionName, call.FunctionArguments).ConfigureAwait(false);
252+
if (result == null)
253+
throw new InvalidOperationException($"MCP tool {call.FunctionName} returned null. Function tools should always return a value.");
254+
return result.ToString();
255+
}
256+
257+
/// <summary>
258+
/// Executes all tool calls and returns their results.
259+
/// </summary>
260+
/// <param name="toolCalls">The collection of tool calls to execute.</param>
261+
/// <returns>A collection of tool chat messages containing the results.</returns>
262+
public async Task<IEnumerable<ToolChatMessage>> CallAsync(IEnumerable<ChatToolCall> toolCalls)
263+
{
264+
var messages = new List<ToolChatMessage>();
265+
foreach (ChatToolCall toolCall in toolCalls)
266+
{
267+
bool isMcpTool = false;
268+
if (!_methods.ContainsKey(toolCall.FunctionName))
269+
{
270+
if (_mcpMethods.ContainsKey(toolCall.FunctionName))
271+
{
272+
isMcpTool = true;
273+
}
274+
else
275+
{
276+
throw new InvalidOperationException("Tool not found: " + toolCall.FunctionName);
277+
}
278+
}
279+
280+
var result = isMcpTool ? await CallMcpAsync(toolCall).ConfigureAwait(false) : await CallFunctionToolAsync(toolCall).ConfigureAwait(false);
281+
messages.Add(new ToolChatMessage(toolCall.Id, result));
282+
}
283+
284+
return messages;
285+
}
286+
}
287+

src/Utility/MCP/McpClient.cs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
using System;
2+
using System.ClientModel.Primitives;
3+
using System.Threading.Tasks;
4+
5+
namespace OpenAI.Agents;
6+
7+
/// <summary>
8+
/// Client for interacting with a Model Context Protocol (MCP) server.
9+
/// </summary>
10+
//[Experimental("OPENAIMCP001")]
11+
internal class McpClient
12+
{
13+
private readonly McpSession _session;
14+
private readonly ClientPipeline _pipeline;
15+
16+
/// <summary>
17+
/// Gets the endpoint URI of the MCP server.
18+
/// </summary>
19+
public virtual Uri Endpoint { get; }
20+
21+
/// <summary>
22+
/// Initializes a new instance of the <see cref="McpClient"/> class.
23+
/// </summary>
24+
/// <param name="endpoint">The URI endpoint of the MCP server.</param>
25+
public McpClient(Uri endpoint)
26+
{
27+
_pipeline = ClientPipeline.Create();
28+
_session = new McpSession(endpoint, _pipeline);
29+
Endpoint = endpoint;
30+
}
31+
32+
/// <summary>
33+
/// Starts the MCP client session by initializing the connection to the server.
34+
/// </summary>
35+
/// <returns>A task that represents the asynchronous operation.</returns>
36+
public virtual async Task StartAsync()
37+
{
38+
await _session.EnsureInitializedAsync().ConfigureAwait(false);
39+
}
40+
41+
/// <summary>
42+
/// Lists all available tools from the MCP server.
43+
/// </summary>
44+
/// <returns>A task that represents the asynchronous operation. The task result contains the binary data representing the tools list.</returns>
45+
/// <exception cref="InvalidOperationException">Thrown when the session is not initialized.</exception>
46+
public virtual async Task<BinaryData> ListToolsAsync()
47+
{
48+
if (_session == null)
49+
throw new InvalidOperationException("Session is not initialized. Call StartAsync() first.");
50+
51+
return await _session.SendMethod("tools/list").ConfigureAwait(false);
52+
}
53+
54+
/// <summary>
55+
/// Calls a specific tool on the MCP server.
56+
/// </summary>
57+
/// <param name="toolName">The name of the tool to call.</param>
58+
/// <param name="parameters">The parameters to pass to the tool as binary data.</param>
59+
/// <returns>A task that represents the asynchronous operation. The task result contains the binary data representing the tool's response.</returns>
60+
/// <exception cref="InvalidOperationException">Thrown when the session is not initialized.</exception>
61+
public virtual async Task<BinaryData> CallToolAsync(string toolName, BinaryData parameters)
62+
{
63+
if (_session == null)
64+
throw new InvalidOperationException("Session is not initialized. Call StartAsync() first.");
65+
66+
return await _session.CallTool(toolName, parameters).ConfigureAwait(false);
67+
}
68+
}

0 commit comments

Comments
 (0)