diff --git a/src/OpenAI.csproj b/src/OpenAI.csproj
index ce4029cd4..aaa57a364 100644
--- a/src/OpenAI.csproj
+++ b/src/OpenAI.csproj
@@ -60,6 +60,10 @@
0024000004800000940000000602000000240000525341310004000001000100097ad52abbeaa2e1a1982747cc0106534f65cfea6707eaed696a3a63daea80de2512746801a7e47f88e7781e71af960d89ba2e25561f70b0e2dbc93319e0af1961a719ccf5a4d28709b2b57a5d29b7c09dc8d269a490ebe2651c4b6e6738c27c5fb2c02469fe9757f0a3479ac310d6588a50a28d7dd431b907fd325e18b9e8ed
+
+ 0024000004800000940000000602000000240000525341310004000001000100b197326f2e5bfe2e2a49eb2a05bee871c55cc894325b3775159732ad816c4f304916e7f154295486f8ccabefa3c19b059d51cd19987cc2d31a3195d6203ad0948662f51cc61cc3eb535fc852dfe5159318c734b163f7d1387f1112e1ffe10f83aae7b809c4e36cf2025da5d1aed6b67e1556883d8778eeb63131c029555166de
+
+
@@ -83,5 +87,6 @@
+
diff --git a/src/Utility/ChatTools.cs b/src/Utility/ChatTools.cs
new file mode 100644
index 000000000..026920705
--- /dev/null
+++ b/src/Utility/ChatTools.cs
@@ -0,0 +1,287 @@
+using System;
+using System.ClientModel.Primitives;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Reflection;
+using System.Text.Json;
+using System.Threading.Tasks;
+using OpenAI.Agents;
+using OpenAI.Chat;
+using OpenAI.Embeddings;
+
+namespace OpenAI.Chat;
+
+///
+/// Provides functionality to manage and execute OpenAI function tools for chat completions.
+///
+//[Experimental("OPENAIMCP001")]
+public class ChatTools
+{
+ private readonly Dictionary _methods = [];
+ private readonly Dictionary>> _mcpMethods = [];
+ private readonly List _tools = [];
+ private readonly EmbeddingClient _client;
+ private readonly List _entries = [];
+ private readonly List _mcpClients = [];
+ private readonly Dictionary _mcpClientsByEndpoint = [];
+
+ ///
+ /// Initializes a new instance of the ChatTools class with an optional embedding client.
+ ///
+ /// The embedding client used for tool vectorization, or null to disable vectorization.
+ public ChatTools(EmbeddingClient client = null)
+ {
+ _client = client;
+ }
+
+ ///
+ /// Initializes a new instance of the ChatTools class with the specified tool types.
+ ///
+ /// Additional tool types to add.
+ public ChatTools(params Type[] tools) : this((EmbeddingClient)null)
+ {
+ foreach (var t in tools)
+ AddFunctionTool(t);
+ }
+
+ ///
+ /// Gets the list of defined tools.
+ ///
+ public IList Tools => _tools;
+
+ ///
+ /// Gets whether tools can be filtered using embeddings provided by the provided .
+ ///
+ public bool CanFilterTools => _client != null;
+
+ ///
+ /// Adds local tool implementations from the provided types.
+ ///
+ /// Types containing static methods to be used as tools.
+ public void AddFunctionTools(params Type[] tools)
+ {
+ foreach (Type functionHolder in tools)
+ AddFunctionTool(functionHolder);
+ }
+
+ ///
+ /// Adds all public static methods from the specified type as tools.
+ ///
+ /// The type containing tool methods.
+ internal void AddFunctionTool(Type tool)
+ {
+#pragma warning disable IL2070
+ foreach (MethodInfo function in tool.GetMethods(BindingFlags.Public | BindingFlags.Static))
+ {
+ AddFunctionTool(function);
+ }
+#pragma warning restore IL2070
+ }
+
+ internal void AddFunctionTool(MethodInfo function)
+ {
+ string name = function.Name;
+ var tool = ChatTool.CreateFunctionTool(name, ToolsUtility.GetMethodDescription(function), ToolsUtility.BuildParametersJson(function.GetParameters()));
+ _tools.Add(tool);
+ _methods[name] = function;
+ }
+
+ ///
+ /// Adds a remote MCP server as a tool provider.
+ ///
+ /// The MCP client instance.
+ /// A task representing the asynchronous operation.
+ internal async Task AddMcpToolsAsync(McpClient client)
+ {
+ if (client == null) throw new ArgumentNullException(nameof(client));
+ _mcpClientsByEndpoint[client.Endpoint.AbsoluteUri] = client;
+ await client.StartAsync().ConfigureAwait(false);
+ BinaryData tools = await client.ListToolsAsync().ConfigureAwait(false);
+ await AddMcpToolsAsync(tools, client).ConfigureAwait(false);
+ _mcpClients.Add(client);
+ }
+
+ ///
+ /// Adds a remote MCP server as a tool provider.
+ ///
+ /// The URI endpoint of the MCP server.
+ /// A task representing the asynchronous operation.
+ public async Task AddMcpToolsAsync(Uri mcpEndpoint)
+ {
+ var client = new McpClient(mcpEndpoint);
+ await AddMcpToolsAsync(client).ConfigureAwait(false);
+ }
+
+ private async Task AddMcpToolsAsync(BinaryData toolDefinitions, McpClient client)
+ {
+ List toolsToVectorize = new();
+ var parsedTools = ToolsUtility.ParseMcpToolDefinitions(toolDefinitions, client);
+
+ foreach (var (name, description, inputSchema) in parsedTools)
+ {
+ var chatTool = ChatTool.CreateFunctionTool(name, description, BinaryData.FromString(inputSchema));
+ _tools.Add(chatTool);
+ toolsToVectorize.Add(chatTool);
+ _mcpMethods[name] = client.CallToolAsync;
+ }
+
+ if (_client != null)
+ {
+ var embeddings = await _client.GenerateEmbeddingsAsync(toolsToVectorize.Select(t => t.FunctionDescription).ToList()).ConfigureAwait(false);
+ foreach (var embedding in embeddings.Value)
+ {
+ var vector = embedding.ToFloats();
+ var item = toolsToVectorize[embedding.Index];
+ var toolDefinition = SerializeTool(item);
+ _entries.Add(new VectorDatabaseEntry(vector, toolDefinition));
+ }
+ }
+ }
+
+ private BinaryData SerializeTool(ChatTool tool)
+ {
+ return ToolsUtility.SerializeTool(tool.FunctionName, tool.FunctionDescription, tool.FunctionParameters);
+ }
+
+ private ChatTool ParseToolDefinition(BinaryData data)
+ {
+ using var document = JsonDocument.Parse(data);
+ var root = document.RootElement;
+
+ return ChatTool.CreateFunctionTool(
+ root.GetProperty("name").GetString()!,
+ root.GetProperty("description").GetString()!,
+ BinaryData.FromString(root.GetProperty("inputSchema").GetRawText()));
+ }
+
+ ///
+ /// Converts the tools collection to chat completion options.
+ ///
+ /// A new ChatCompletionOptions containing all defined tools.
+ public ChatCompletionOptions ToChatCompletionOptions()
+ {
+ var options = new ChatCompletionOptions();
+ foreach (var tool in _tools)
+ options.Tools.Add(tool);
+ return options;
+ }
+
+ ///
+ /// Converts the tools collection to , filtered by relevance to the given prompt.
+ ///
+ /// The prompt to find relevant tools for.
+ /// The maximum number of tools to return. Default is 3.
+ /// The similarity threshold for including tools. Default is 0.29.
+ /// A new containing the most relevant tools.
+ public ChatCompletionOptions CreateCompletionOptions(string prompt, int maxTools = 5, float minVectorDistance = 0.29f)
+ {
+ if (!CanFilterTools)
+ return ToChatCompletionOptions();
+
+ var completionOptions = new ChatCompletionOptions();
+ foreach (var tool in FindRelatedTools(false, prompt, maxTools, minVectorDistance).GetAwaiter().GetResult())
+ completionOptions.Tools.Add(tool);
+ return completionOptions;
+ }
+
+ ///
+ /// Converts the tools collection to , filtered by relevance to the given prompt.
+ ///
+ /// The prompt to find relevant tools for.
+ /// The maximum number of tools to return. Default is 3.
+ /// The similarity threshold for including tools. Default is 0.29.
+ /// A new containing the most relevant tools.
+ public async Task ToChatCompletionOptions(string prompt, int maxTools = 5, float minVectorDistance = 0.29f)
+ {
+ if (!CanFilterTools)
+ return ToChatCompletionOptions();
+
+ var completionOptions = new ChatCompletionOptions();
+ foreach (var tool in await FindRelatedTools(true, prompt, maxTools, minVectorDistance).ConfigureAwait(false))
+ completionOptions.Tools.Add(tool);
+ return completionOptions;
+ }
+
+ private async Task> FindRelatedTools(bool async, string prompt, int maxTools, float minVectorDistance)
+ {
+ if (!CanFilterTools)
+ return _tools;
+
+ return (await FindVectorMatches(async, prompt, maxTools, minVectorDistance).ConfigureAwait(false))
+ .Select(e => ParseToolDefinition(e.Data));
+ }
+
+ private async Task> FindVectorMatches(bool async, string prompt, int maxTools, float minVectorDistance)
+ {
+ var vector = async ?
+ await ToolsUtility.GetEmbeddingAsync(_client, prompt).ConfigureAwait(false) :
+ ToolsUtility.GetEmbedding(_client, prompt);
+
+ lock (_entries)
+ {
+ return ToolsUtility.GetClosestEntries(_entries, maxTools, minVectorDistance, vector);
+ }
+ }
+
+ internal async Task CallFunctionToolAsync(ChatToolCall call)
+ {
+ var arguments = new List
+
+
+
+ true
+ ..\src\OpenAI.snk
+
+
@@ -22,8 +29,4 @@
-
-
-
-
\ No newline at end of file
diff --git a/tests/Utility/ChatToolsTests.cs b/tests/Utility/ChatToolsTests.cs
new file mode 100644
index 000000000..4fb2508cd
--- /dev/null
+++ b/tests/Utility/ChatToolsTests.cs
@@ -0,0 +1,372 @@
+using System;
+using System.ClientModel;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Moq;
+using NUnit.Framework;
+using OpenAI.Agents;
+using OpenAI.Chat;
+using OpenAI.Embeddings;
+
+namespace OpenAI.Tests.Utility;
+
+[TestFixture]
+[Category("Utility")]
+public class ChatToolsTests : ToolsTestsBase
+{
+ private Mock mockEmbeddingClient;
+
+ [SetUp]
+ public void Setup()
+ {
+ mockEmbeddingClient = new Mock("text-embedding-ada-002", new ApiKeyCredential("test-key"));
+ }
+
+ [Test]
+ public void CanAddLocalTools()
+ {
+ var tools = new ChatTools();
+ tools.AddFunctionTools(typeof(TestTools));
+
+ Assert.That(tools.Tools, Has.Count.EqualTo(6));
+ Assert.That(tools.Tools.Any(t => t.FunctionName == "Echo"));
+ Assert.That(tools.Tools.Any(t => t.FunctionName == "Add"));
+ Assert.That(tools.Tools.Any(t => t.FunctionName == "Multiply"));
+ Assert.That(tools.Tools.Any(t => t.FunctionName == "IsGreaterThan"));
+ Assert.That(tools.Tools.Any(t => t.FunctionName == "Divide"));
+ Assert.That(tools.Tools.Any(t => t.FunctionName == "ConcatWithBool"));
+ }
+
+ [Test]
+ public void CanAddAsyncLocalTools()
+ {
+ var tools = new ChatTools();
+ tools.AddFunctionTools(typeof(TestToolsAsync));
+
+ Assert.That(tools.Tools, Has.Count.EqualTo(6));
+ Assert.That(tools.Tools.Any(t => t.FunctionName == "EchoAsync"));
+ Assert.That(tools.Tools.Any(t => t.FunctionName == "AddAsync"));
+ Assert.That(tools.Tools.Any(t => t.FunctionName == "MultiplyAsync"));
+ Assert.That(tools.Tools.Any(t => t.FunctionName == "IsGreaterThanAsync"));
+ Assert.That(tools.Tools.Any(t => t.FunctionName == "DivideAsync"));
+ Assert.That(tools.Tools.Any(t => t.FunctionName == "ConcatWithBoolAsync"));
+ }
+
+ [Test]
+ public async Task CanCallToolsAsync()
+ {
+ var tools = new ChatTools();
+ tools.AddFunctionTools(typeof(TestTools));
+
+ var toolCalls = new[]
+ {
+ ChatToolCall.CreateFunctionToolCall("call1", "Echo", BinaryData.FromString(@"{""message"": ""Hello""}")),
+ ChatToolCall.CreateFunctionToolCall("call2", "Add", BinaryData.FromString(@"{""a"": 2, ""b"": 3}")),
+ ChatToolCall.CreateFunctionToolCall("call3", "Multiply", BinaryData.FromString(@"{""x"": 2.5, ""y"": 3.0}")),
+ ChatToolCall.CreateFunctionToolCall("call4", "IsGreaterThan", BinaryData.FromString(@"{""value1"": 100, ""value2"": 50}")),
+ ChatToolCall.CreateFunctionToolCall("call5", "Divide", BinaryData.FromString(@"{""numerator"": 10.0, ""denominator"": 2.0}")),
+ ChatToolCall.CreateFunctionToolCall("call6", "ConcatWithBool", BinaryData.FromString(@"{""text"": ""Test"", ""flag"": true}"))
+ };
+
+ var results = await tools.CallAsync(toolCalls);
+ var resultsList = results.ToList();
+
+ Assert.That(resultsList, Has.Count.EqualTo(6));
+ Assert.That(resultsList[0].ToolCallId, Is.EqualTo("call1"));
+ Assert.That(resultsList[0].Content[0].Text, Is.EqualTo("Hello"));
+ Assert.That(resultsList[1].ToolCallId, Is.EqualTo("call2"));
+ Assert.That(resultsList[1].Content[0].Text, Is.EqualTo("5"));
+ Assert.That(resultsList[2].ToolCallId, Is.EqualTo("call3"));
+ Assert.That(resultsList[2].Content[0].Text, Is.EqualTo("7.5"));
+ Assert.That(resultsList[3].ToolCallId, Is.EqualTo("call4"));
+ Assert.That(resultsList[3].Content[0].Text, Is.EqualTo("True"));
+ Assert.That(resultsList[4].ToolCallId, Is.EqualTo("call5"));
+ Assert.That(resultsList[4].Content[0].Text, Is.EqualTo("5"));
+ Assert.That(resultsList[5].ToolCallId, Is.EqualTo("call6"));
+ Assert.That(resultsList[5].Content[0].Text, Is.EqualTo("Test:True"));
+ }
+
+ [Test]
+ public async Task CanCallAsyncToolsAsync()
+ {
+ var tools = new ChatTools();
+ tools.AddFunctionTools(typeof(TestToolsAsync));
+
+ var toolCalls = new[]
+ {
+ ChatToolCall.CreateFunctionToolCall("call1", "EchoAsync", BinaryData.FromString(@"{""message"": ""Hello""}")),
+ ChatToolCall.CreateFunctionToolCall("call2", "AddAsync", BinaryData.FromString(@"{""a"": 2, ""b"": 3}")),
+ ChatToolCall.CreateFunctionToolCall("call3", "MultiplyAsync", BinaryData.FromString(@"{""x"": 2.5, ""y"": 3.0}")),
+ ChatToolCall.CreateFunctionToolCall("call4", "IsGreaterThanAsync", BinaryData.FromString(@"{""value1"": 100, ""value2"": 50}")),
+ ChatToolCall.CreateFunctionToolCall("call5", "DivideAsync", BinaryData.FromString(@"{""numerator"": 10.0, ""denominator"": 2.0}")),
+ ChatToolCall.CreateFunctionToolCall("call6", "ConcatWithBoolAsync", BinaryData.FromString(@"{""text"": ""Test"", ""flag"": true}"))
+ };
+
+ var results = await tools.CallAsync(toolCalls);
+ var resultsList = results.ToList();
+
+ Assert.That(resultsList, Has.Count.EqualTo(6));
+ Assert.That(resultsList[0].ToolCallId, Is.EqualTo("call1"));
+ Assert.That(resultsList[0].Content[0].Text, Is.EqualTo("Hello"));
+ Assert.That(resultsList[1].ToolCallId, Is.EqualTo("call2"));
+ Assert.That(resultsList[1].Content[0].Text, Is.EqualTo("5"));
+ Assert.That(resultsList[2].ToolCallId, Is.EqualTo("call3"));
+ Assert.That(resultsList[2].Content[0].Text, Is.EqualTo("7.5"));
+ Assert.That(resultsList[3].ToolCallId, Is.EqualTo("call4"));
+ Assert.That(resultsList[3].Content[0].Text, Is.EqualTo("True"));
+ Assert.That(resultsList[4].ToolCallId, Is.EqualTo("call5"));
+ Assert.That(resultsList[4].Content[0].Text, Is.EqualTo("5"));
+ Assert.That(resultsList[5].ToolCallId, Is.EqualTo("call6"));
+ Assert.That(resultsList[5].Content[0].Text, Is.EqualTo("Test:True"));
+ }
+
+ [Test]
+ public void CreatesCompletionOptionsWithTools()
+ {
+ var tools = new ChatTools();
+ tools.AddFunctionTools(typeof(TestTools));
+
+ var options = tools.ToChatCompletionOptions();
+
+ Assert.That(options.Tools, Has.Count.EqualTo(6));
+ Assert.That(options.Tools.Any(t => t.FunctionName == "Echo"));
+ Assert.That(options.Tools.Any(t => t.FunctionName == "Add"));
+ Assert.That(options.Tools.Any(t => t.FunctionName == "Multiply"));
+ Assert.That(options.Tools.Any(t => t.FunctionName == "IsGreaterThan"));
+ Assert.That(options.Tools.Any(t => t.FunctionName == "Divide"));
+ Assert.That(options.Tools.Any(t => t.FunctionName == "ConcatWithBool"));
+ }
+
+ [Test]
+ public async Task CanFilterToolsByRelevance()
+ {
+ // Setup mock embedding client to return a mock response
+ var embedding = OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.5f, 0.5f });
+ var embeddingCollection = OpenAIEmbeddingsModelFactory.OpenAIEmbeddingCollection(
+ items: new[] { embedding },
+ model: "text-embedding-ada-002",
+ usage: OpenAIEmbeddingsModelFactory.EmbeddingTokenUsage(10, 10));
+ var mockResponse = new MockPipelineResponse(200);
+
+ mockEmbeddingClient
+ .Setup(c => c.GenerateEmbeddingAsync(
+ It.IsAny(),
+ It.IsAny(),
+ It.IsAny()))
+ .ReturnsAsync(ClientResult.FromValue(embedding, mockResponse));
+
+ var tools = new ChatTools(mockEmbeddingClient.Object);
+ tools.AddFunctionTools(typeof(TestTools));
+
+ var options = await tools.ToChatCompletionOptions("Need to add two numbers", 1, 0.5f);
+
+ Assert.That(options.Tools, Has.Count.LessThanOrEqualTo(1));
+ }
+
+ [Test]
+ public void ThrowsWhenCallingNonExistentTool()
+ {
+ var tools = new ChatTools();
+ var toolCalls = new[]
+ {
+ ChatToolCall.CreateFunctionToolCall("call1", "NonExistentTool", BinaryData.FromString("{}"))
+ };
+
+ Assert.ThrowsAsync(() => tools.CallAsync(toolCalls));
+ }
+
+ [Test]
+ public async Task AddMcpToolsAsync_AddsToolsCorrectly()
+ {
+ // Arrange
+ var mcpEndpoint = new Uri("http://localhost:1234");
+ var mockToolsResponse = BinaryData.FromString(@"
+ {
+ ""tools"": [
+ {
+ ""name"": ""mcp-tool-1"",
+ ""description"": ""This is the first MCP tool."",
+ ""inputSchema"": {
+ ""type"": ""object"",
+ ""properties"": {
+ ""param1"": {
+ ""type"": ""string"",
+ ""description"": ""The first param.""
+ },
+ ""param2"": {
+ ""type"": ""string"",
+ ""description"": ""The second param.""
+ }
+ },
+ ""required"": [""param1""]
+ }
+ },
+ {
+ ""name"": ""mcp-tool-2"",
+ ""description"": ""This is the second MCP tool."",
+ ""inputSchema"": {
+ ""type"": ""object"",
+ ""properties"": {
+ ""param1"": {
+ ""type"": ""string"",
+ ""description"": ""The first param.""
+ },
+ ""param2"": {
+ ""type"": ""string"",
+ ""description"": ""The second param.""
+ }
+ },
+ ""required"": []
+ }
+ }
+ ]
+ }");
+
+ var responsesByTool = new Dictionary
+ {
+ ["mcp-tool-1"] = "\"tool1 result\"",
+ ["mcp-tool-2"] = "\"tool2 result\""
+ };
+
+ var testClient = new TestMcpClient(
+ mcpEndpoint,
+ mockToolsResponse,
+ toolName => BinaryData.FromString(responsesByTool[toolName.Split('_').Last()]));
+ var tools = new ChatTools();
+
+ // Act
+ await tools.AddMcpToolsAsync(testClient);
+
+ // Assert
+ Assert.That(tools.Tools, Has.Count.EqualTo(2));
+ var toolNames = tools.Tools.Select(t => t.FunctionName).ToList();
+ Assert.That(toolNames, Contains.Item("localhost1234_-_mcp-tool-1"));
+ Assert.That(toolNames, Contains.Item("localhost1234_-_mcp-tool-2"));
+
+ // Verify we can call the tools with different responses
+ var toolCalls = new[]
+ {
+ ChatToolCall.CreateFunctionToolCall("call1", "localhost1234_-_mcp-tool-1", BinaryData.FromString(@"{""param1"": ""test""}")),
+ ChatToolCall.CreateFunctionToolCall("call2", "localhost1234_-_mcp-tool-2", BinaryData.FromString(@"{""param2"": ""test""}"))
+ };
+ var results = await tools.CallAsync(toolCalls);
+ var resultsList = results.ToList();
+
+ Assert.That(resultsList, Has.Count.EqualTo(2));
+ Assert.That(resultsList[0].ToolCallId, Is.EqualTo("call1"));
+ Assert.That(resultsList[0].Content[0].Text, Is.EqualTo("\"tool1 result\""));
+ Assert.That(resultsList[1].ToolCallId, Is.EqualTo("call2"));
+ Assert.That(resultsList[1].Content[0].Text, Is.EqualTo("\"tool2 result\""));
+ }
+
+ [Test]
+ public async Task CreateCompletionOptions_WithMaxToolsParameter_FiltersTools()
+ {
+ // Arrange
+ var mcpEndpoint = new Uri("http://localhost:1234");
+ var mockToolsResponse = BinaryData.FromString(@"
+ {
+ ""tools"": [
+ {
+ ""name"": ""math-tool"",
+ ""description"": ""Tool for performing mathematical calculations"",
+ ""inputSchema"": {
+ ""type"": ""object"",
+ ""properties"": {
+ ""expression"": {
+ ""type"": ""string"",
+ ""description"": ""The mathematical expression to evaluate""
+ }
+ }
+ }
+ },
+ {
+ ""name"": ""weather-tool"",
+ ""description"": ""Tool for getting weather information"",
+ ""inputSchema"": {
+ ""type"": ""object"",
+ ""properties"": {
+ ""location"": {
+ ""type"": ""string"",
+ ""description"": ""The location to get weather for""
+ }
+ }
+ }
+ },
+ {
+ ""name"": ""translate-tool"",
+ ""description"": ""Tool for translating text between languages"",
+ ""inputSchema"": {
+ ""type"": ""object"",
+ ""properties"": {
+ ""text"": {
+ ""type"": ""string"",
+ ""description"": ""Text to translate""
+ },
+ ""targetLanguage"": {
+ ""type"": ""string"",
+ ""description"": ""Target language code""
+ }
+ }
+ }
+ }
+ ]
+ }");
+
+ // Setup mock embedding responses
+ var embeddings = new[]
+ {
+ OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.8f, 0.5f }),
+ OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.6f, 0.4f }),
+ OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.3f, 0.2f })
+ };
+ var embeddingCollection = OpenAIEmbeddingsModelFactory.OpenAIEmbeddingCollection(
+ items: embeddings,
+ model: "text-embedding-ada-002",
+ usage: OpenAIEmbeddingsModelFactory.EmbeddingTokenUsage(30, 30));
+ var mockResponse = new MockPipelineResponse(200);
+
+ mockEmbeddingClient
+ .Setup(c => c.GenerateEmbeddingAsync(
+ It.IsAny(),
+ It.IsAny(),
+ It.IsAny()))
+ .ReturnsAsync(ClientResult.FromValue(embeddings[0], mockResponse));
+
+ mockEmbeddingClient
+ .Setup(c => c.GenerateEmbeddingsAsync(
+ It.IsAny>(),
+ It.IsAny(),
+ It.IsAny()))
+ .ReturnsAsync(ClientResult.FromValue(embeddingCollection, mockResponse));
+
+ var testClient = new TestMcpClient(
+ mcpEndpoint,
+ mockToolsResponse,
+ toolName => BinaryData.FromString($"\"{toolName} result\""));
+ var tools = new ChatTools(mockEmbeddingClient.Object);
+
+ // Add the tools
+ await tools.AddMcpToolsAsync(testClient);
+
+ // Act & Assert
+ // Test with maxTools = 1
+ var options1 = await tools.ToChatCompletionOptions("calculate 2+2", 1, 0.5f);
+ Assert.That(options1.Tools, Has.Count.EqualTo(1));
+
+ // Test with maxTools = 2
+ var options2 = await tools.ToChatCompletionOptions("calculate 2+2", 2, 0.5f);
+ Assert.That(options2.Tools, Has.Count.EqualTo(2));
+
+ // Test that we can call the tools after filtering
+ var toolCall = ChatToolCall.CreateFunctionToolCall(
+ "call1",
+ "localhost1234_-_math-tool",
+ BinaryData.FromString(@"{""expression"": ""2+2""}"));
+ var result = await tools.CallAsync(new[] { toolCall });
+ Assert.That(result.First().ToolCallId, Is.EqualTo("call1"));
+ Assert.That(result.First().Content[0].Text, Is.EqualTo("\"math-tool result\""));
+ }
+}
\ No newline at end of file
diff --git a/tests/Utility/MultipartFormDataBinaryContent.cs b/tests/Utility/MultipartFormDataBinaryContent.cs
deleted file mode 100644
index 335253450..000000000
--- a/tests/Utility/MultipartFormDataBinaryContent.cs
+++ /dev/null
@@ -1,195 +0,0 @@
-//
-
-#nullable disable
-
-using System;
-using System.ClientModel;
-using System.Globalization;
-using System.IO;
-using System.Net.Http;
-using System.Net.Http.Headers;
-using System.Threading;
-using System.Threading.Tasks;
-
-namespace OpenAI
-{
- internal partial class MultiPartFormDataBinaryContent : BinaryContent
- {
- private readonly MultipartFormDataContent _multipartContent;
- private static readonly Random _random = new Random();
- private static readonly char[] _boundaryValues = "0123456789=ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz".ToCharArray();
-
- public MultiPartFormDataBinaryContent()
- {
- _multipartContent = new MultipartFormDataContent(CreateBoundary());
- }
-
- public string ContentType
- {
- get
- {
- return _multipartContent.Headers.ContentType.ToString();
- }
- }
-
- internal HttpContent HttpContent => _multipartContent;
-
- private static string CreateBoundary()
- {
- Span chars = new char[70];
- byte[] random = new byte[70];
- _random.NextBytes(random);
- int mask = 255 >> 2;
- int i = 0;
- for (; i < 70; i++)
- {
- chars[i] = _boundaryValues[random[i] & mask];
- }
- return chars.ToString();
- }
-
- public void Add(string content, string name, string filename = default, string contentType = default)
- {
- //Argument.AssertNotNull(content, nameof(content));
- //Argument.AssertNotNullOrEmpty(name, nameof(name));
-
- Add(new StringContent(content), name, filename, contentType);
- }
-
- public void Add(int content, string name, string filename = default, string contentType = default)
- {
- //Argument.AssertNotNull(content, nameof(content));
- //Argument.AssertNotNullOrEmpty(name, nameof(name));
-
- string value = content.ToString("G", CultureInfo.InvariantCulture);
- Add(new StringContent(value), name, filename, contentType);
- }
-
- public void Add(long content, string name, string filename = default, string contentType = default)
- {
- //Argument.AssertNotNull(content, nameof(content));
- //Argument.AssertNotNullOrEmpty(name, nameof(name));
-
- string value = content.ToString("G", CultureInfo.InvariantCulture);
- Add(new StringContent(value), name, filename, contentType);
- }
-
- public void Add(float content, string name, string filename = default, string contentType = default)
- {
- //Argument.AssertNotNull(content, nameof(content));
- //Argument.AssertNotNullOrEmpty(name, nameof(name));
-
- string value = content.ToString("G", CultureInfo.InvariantCulture);
- Add(new StringContent(value), name, filename, contentType);
- }
-
- public void Add(double content, string name, string filename = default, string contentType = default)
- {
- //Argument.AssertNotNull(content, nameof(content));
- //Argument.AssertNotNullOrEmpty(name, nameof(name));
-
- string value = content.ToString("G", CultureInfo.InvariantCulture);
- Add(new StringContent(value), name, filename, contentType);
- }
-
- public void Add(decimal content, string name, string filename = default, string contentType = default)
- {
- //Argument.AssertNotNull(content, nameof(content));
- //Argument.AssertNotNullOrEmpty(name, nameof(name));
-
- string value = content.ToString("G", CultureInfo.InvariantCulture);
- Add(new StringContent(value), name, filename, contentType);
- }
-
- public void Add(bool content, string name, string filename = default, string contentType = default)
- {
- //Argument.AssertNotNull(content, nameof(content));
- //Argument.AssertNotNullOrEmpty(name, nameof(name));
-
- string value = content ? "true" : "false";
- Add(new StringContent(value), name, filename, contentType);
- }
-
- public void Add(Stream content, string name, string filename = default, string contentType = default)
- {
- //Argument.AssertNotNull(content, nameof(content));
- //Argument.AssertNotNullOrEmpty(name, nameof(name));
-
- Add(new StreamContent(content), name, filename, contentType);
- }
-
- public void Add(byte[] content, string name, string filename = default, string contentType = default)
- {
- //Argument.AssertNotNull(content, nameof(content));
- //Argument.AssertNotNullOrEmpty(name, nameof(name));
-
- Add(new ByteArrayContent(content), name, filename, contentType);
- }
-
- public void Add(BinaryData content, string name, string filename = default, string contentType = default)
- {
- //Argument.AssertNotNull(content, nameof(content));
- //Argument.AssertNotNullOrEmpty(name, nameof(name));
-
- Add(new ByteArrayContent(content.ToArray()), name, filename, contentType);
- }
-
- private void Add(HttpContent content, string name, string filename, string contentType)
- {
- if (contentType != null)
- {
- //Argument.AssertNotNullOrEmpty(contentType, nameof(contentType));
- AddContentTypeHeader(content, contentType);
- }
- if (filename != null)
- {
- //Argument.AssertNotNullOrEmpty(filename, nameof(filename));
- _multipartContent.Add(content, name, filename);
- }
- else
- {
- _multipartContent.Add(content, name);
- }
- }
-
- public static void AddContentTypeHeader(HttpContent content, string contentType)
- {
- MediaTypeHeaderValue header = new MediaTypeHeaderValue(contentType);
- content.Headers.ContentType = header;
- }
-
- public override bool TryComputeLength(out long length)
- {
- if (_multipartContent.Headers.ContentLength is long contentLength)
- {
- length = contentLength;
- return true;
- }
- length = 0;
- return false;
- }
-
- public override void WriteTo(Stream stream, CancellationToken cancellationToken = default)
- {
-#if NET6_0_OR_GREATER
- _multipartContent.CopyTo(stream, default, cancellationToken);
-#else
- _multipartContent.CopyToAsync(stream).GetAwaiter().GetResult();
-#endif
- }
-
- public override async Task WriteToAsync(Stream stream, CancellationToken cancellationToken = default)
- {
-#if NET6_0_OR_GREATER
- await _multipartContent.CopyToAsync(stream).ConfigureAwait(false);
-#else
- await _multipartContent.CopyToAsync(stream).ConfigureAwait(false);
-#endif
- }
-
- public override void Dispose()
- {
- _multipartContent.Dispose();
- }
- }
-}
diff --git a/tests/Utility/ResponseToolsTests.cs b/tests/Utility/ResponseToolsTests.cs
new file mode 100644
index 000000000..3f1fb8fe8
--- /dev/null
+++ b/tests/Utility/ResponseToolsTests.cs
@@ -0,0 +1,411 @@
+using System;
+using System.ClientModel;
+using System.ClientModel.Primitives;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Moq;
+using NUnit.Framework;
+using OpenAI.Agents;
+using OpenAI.Embeddings;
+using OpenAI.Responses;
+
+namespace OpenAI.Tests.Utility;
+
+[TestFixture]
+[Category("Utility")]
+public class ResponseToolsTests : ToolsTestsBase
+{
+ private Mock mockEmbeddingClient;
+
+ [SetUp]
+ public void Setup()
+ {
+ mockEmbeddingClient = new Mock("text-embedding-ada-002", new ApiKeyCredential("test-key"));
+ }
+
+ [Test]
+ public void CanAddLocalTools()
+ {
+ var tools = new ResponseTools();
+ tools.AddFunctionTools(typeof(TestTools));
+
+ Assert.That(tools.Tools, Has.Count.EqualTo(6));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Echo")));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Add")));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Multiply")));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("IsGreaterThan")));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Divide")));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("ConcatWithBool")));
+ }
+
+ [Test]
+ public void CanAddLocalAsyncTools()
+ {
+ var tools = new ResponseTools();
+ tools.AddFunctionTools(typeof(TestToolsAsync));
+
+ Assert.That(tools.Tools, Has.Count.EqualTo(6));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("EchoAsync")));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("AddAsync")));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("MultiplyAsync")));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("IsGreaterThanAsync")));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("DivideAsync")));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("ConcatWithBoolAsync")));
+ }
+
+ [Test]
+ public async Task CanCallToolAsync()
+ {
+ var tools = new ResponseTools();
+ tools.AddFunctionTools(typeof(TestTools));
+
+ var toolCalls = new[]
+ {
+ new FunctionCallResponseItem("call1", "Echo", BinaryData.FromString(@"{""message"": ""Hello""}")),
+ new FunctionCallResponseItem("call2", "Add", BinaryData.FromString(@"{""a"": 2, ""b"": 3}")),
+ new FunctionCallResponseItem("call3", "Multiply", BinaryData.FromString(@"{""x"": 2.5, ""y"": 3.0}")),
+ new FunctionCallResponseItem("call4", "IsGreaterThan", BinaryData.FromString(@"{""value1"": 100, ""value2"": 50}")),
+ new FunctionCallResponseItem("call5", "Divide", BinaryData.FromString(@"{""numerator"": 10.0, ""denominator"": 2.0}")),
+ new FunctionCallResponseItem("call6", "ConcatWithBool", BinaryData.FromString(@"{""text"": ""Test"", ""flag"": true}"))
+ };
+
+ foreach (var toolCall in toolCalls)
+ {
+ var result = await tools.CallAsync(toolCall);
+ Assert.That(result.CallId, Is.EqualTo(toolCall.CallId));
+ switch (toolCall.CallId)
+ {
+ case "call1":
+ Assert.That(result.FunctionOutput, Is.EqualTo("Hello"));
+ break;
+ case "call2":
+ Assert.That(result.FunctionOutput, Is.EqualTo("5"));
+ break;
+ case "call3":
+ Assert.That(result.FunctionOutput, Is.EqualTo("7.5"));
+ break;
+ case "call4":
+ Assert.That(result.FunctionOutput, Is.EqualTo("True"));
+ break;
+ case "call5":
+ Assert.That(result.FunctionOutput, Is.EqualTo("5"));
+ break;
+ case "call6":
+ Assert.That(result.FunctionOutput, Is.EqualTo("Test:True"));
+ break;
+ }
+ }
+ }
+
+ [Test]
+ public async Task CanCallAsyncToolsAsync()
+ {
+ var tools = new ResponseTools();
+ tools.AddFunctionTools(typeof(TestToolsAsync));
+
+ var toolCalls = new[]
+ {
+ new FunctionCallResponseItem("call1", "EchoAsync", BinaryData.FromString(@"{""message"": ""Hello""}")),
+ new FunctionCallResponseItem("call2", "AddAsync", BinaryData.FromString(@"{""a"": 2, ""b"": 3}")),
+ new FunctionCallResponseItem("call3", "MultiplyAsync", BinaryData.FromString(@"{""x"": 2.5, ""y"": 3.0}")),
+ new FunctionCallResponseItem("call4", "IsGreaterThanAsync", BinaryData.FromString(@"{""value1"": 100, ""value2"": 50}")),
+ new FunctionCallResponseItem("call5", "DivideAsync", BinaryData.FromString(@"{""numerator"": 10.0, ""denominator"": 2.0}")),
+ new FunctionCallResponseItem("call6", "ConcatWithBoolAsync", BinaryData.FromString(@"{""text"": ""Test"", ""flag"": true}"))
+ };
+
+ foreach (var toolCall in toolCalls)
+ {
+ var result = await tools.CallAsync(toolCall);
+ Assert.That(result.CallId, Is.EqualTo(toolCall.CallId));
+ switch (toolCall.CallId)
+ {
+ case "call1":
+ Assert.That(result.FunctionOutput, Is.EqualTo("Hello"));
+ break;
+ case "call2":
+ Assert.That(result.FunctionOutput, Is.EqualTo("5"));
+ break;
+ case "call3":
+ Assert.That(result.FunctionOutput, Is.EqualTo("7.5"));
+ break;
+ case "call4":
+ Assert.That(result.FunctionOutput, Is.EqualTo("True"));
+ break;
+ case "call5":
+ Assert.That(result.FunctionOutput, Is.EqualTo("5"));
+ break;
+ case "call6":
+ Assert.That(result.FunctionOutput, Is.EqualTo("Test:True"));
+ break;
+ }
+ }
+ }
+
+ [Test]
+ public void CreatesResponseOptionsWithTools()
+ {
+ var tools = new ResponseTools();
+ tools.AddFunctionTools(typeof(TestTools));
+
+ var options = tools.ToResponseCreationOptions();
+
+ Assert.That(options.Tools, Has.Count.EqualTo(6));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Echo")));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Add")));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Multiply")));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("IsGreaterThan")));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Divide")));
+ Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("ConcatWithBool")));
+ }
+
+ [Test]
+ public async Task CanFilterToolsByRelevance()
+ {
+ // Setup mock embedding client to return a mock response
+ var embeddings = new[]
+ {
+ OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.8f, 0.5f }),
+ OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.6f, 0.4f }),
+ OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.3f, 0.2f })
+ };
+ var embeddingCollection = OpenAIEmbeddingsModelFactory.OpenAIEmbeddingCollection(
+ items: embeddings,
+ model: "text-embedding-ada-002",
+ usage: OpenAIEmbeddingsModelFactory.EmbeddingTokenUsage(30, 30));
+ var mockResponse = new MockPipelineResponse(200);
+
+ mockEmbeddingClient
+ .Setup(c => c.GenerateEmbeddingAsync(
+ It.IsAny(),
+ It.IsAny(),
+ It.IsAny()))
+ .ReturnsAsync(ClientResult.FromValue(embeddings[0], mockResponse));
+
+ mockEmbeddingClient
+ .Setup(c => c.GenerateEmbeddingsAsync(
+ It.IsAny>(),
+ It.IsAny(),
+ It.IsAny()))
+ .ReturnsAsync(ClientResult.FromValue(embeddingCollection, mockResponse));
+
+ var tools = new ResponseTools(mockEmbeddingClient.Object);
+ tools.AddFunctionTools(typeof(TestTools));
+
+ var options = await tools.ToResponseCreationOptionsAsync("Need to add two numbers", 1, 0.5f);
+
+ Assert.That(options.Tools, Has.Count.LessThanOrEqualTo(1));
+ }
+
+ [Test]
+ public async Task ReturnsErrorForNonExistentTool()
+ {
+ var tools = new ResponseTools();
+ var toolCall = new FunctionCallResponseItem("call1", "NonExistentTool", BinaryData.FromString("{}"));
+
+ var result = await tools.CallAsync(toolCall);
+ Assert.That(result.FunctionOutput, Does.StartWith("I don't have a tool called"));
+ }
+
+ [Test]
+ public async Task AddMcpToolsAsync_AddsToolsCorrectly()
+ {
+ // Arrange
+ var mcpEndpoint = new Uri("http://localhost:1234");
+ var mockToolsResponse = BinaryData.FromString(@"
+ {
+ ""tools"": [
+ {
+ ""name"": ""mcp-tool-1"",
+ ""description"": ""This is the first MCP tool."",
+ ""inputSchema"": {
+ ""type"": ""object"",
+ ""properties"": {
+ ""param1"": {
+ ""type"": ""string"",
+ ""description"": ""The first param.""
+ },
+ ""param2"": {
+ ""type"": ""string"",
+ ""description"": ""The second param.""
+ }
+ },
+ ""required"": [""param1""]
+ }
+ },
+ {
+ ""name"": ""mcp-tool-2"",
+ ""description"": ""This is the second MCP tool."",
+ ""inputSchema"": {
+ ""type"": ""object"",
+ ""properties"": {
+ ""param1"": {
+ ""type"": ""string"",
+ ""description"": ""The first param.""
+ },
+ ""param2"": {
+ ""type"": ""string"",
+ ""description"": ""The second param.""
+ }
+ },
+ ""required"": []
+ }
+ }
+ ]
+ }");
+
+ var responsesByTool = new Dictionary
+ {
+ ["mcp-tool-1"] = "\"tool1 result\"",
+ ["mcp-tool-2"] = "\"tool2 result\""
+ };
+
+ var testClient = new TestMcpClient(
+ mcpEndpoint,
+ mockToolsResponse,
+ toolName => BinaryData.FromString(responsesByTool[toolName.Split('_').Last()]));
+ var tools = new ResponseTools();
+
+ // Act
+ await tools.AddMcpToolsAsync(testClient);
+
+ // Assert
+ Assert.That(tools.Tools, Has.Count.EqualTo(2));
+ var toolNames = tools.Tools.Select(t => (string)t.GetType().GetProperty("Name").GetValue(t)).ToList();
+ Assert.That(toolNames, Contains.Item("localhost1234_-_mcp-tool-1"));
+ Assert.That(toolNames, Contains.Item("localhost1234_-_mcp-tool-2"));
+
+ // Verify we can call the tools with different responses
+ var toolCall = new FunctionCallResponseItem("call1", "localhost1234_-_mcp-tool-1", BinaryData.FromString(@"{""param1"": ""test""}"));
+ var result = await tools.CallAsync(toolCall);
+ Assert.That(result.FunctionOutput, Is.EqualTo("\"tool1 result\""));
+
+ var toolCall2 = new FunctionCallResponseItem("call2", "localhost1234_-_mcp-tool-2", BinaryData.FromString(@"{""param2"": ""test""}"));
+ var result2 = await tools.CallAsync(toolCall2);
+ Assert.That(result2.FunctionOutput, Is.EqualTo("\"tool2 result\""));
+ }
+
+ [Test]
+ public async Task CreateResponseOptions_WithMaxToolsParameter_FiltersTools()
+ {
+ // Arrange
+ var mcpEndpoint = new Uri("http://localhost:1234");
+ var mockToolsResponse = BinaryData.FromString(@"
+ {
+ ""tools"": [
+ {
+ ""name"": ""math-tool"",
+ ""description"": ""Tool for performing mathematical calculations"",
+ ""inputSchema"": {
+ ""type"": ""object"",
+ ""properties"": {
+ ""expression"": {
+ ""type"": ""string"",
+ ""description"": ""The mathematical expression to evaluate""
+ }
+ }
+ }
+ },
+ {
+ ""name"": ""weather-tool"",
+ ""description"": ""Tool for getting weather information"",
+ ""inputSchema"": {
+ ""type"": ""object"",
+ ""properties"": {
+ ""location"": {
+ ""type"": ""string"",
+ ""description"": ""The location to get weather for""
+ }
+ }
+ }
+ },
+ {
+ ""name"": ""translate-tool"",
+ ""description"": ""Tool for translating text between languages"",
+ ""inputSchema"": {
+ ""type"": ""object"",
+ ""properties"": {
+ ""text"": {
+ ""type"": ""string"",
+ ""description"": ""Text to translate""
+ },
+ ""targetLanguage"": {
+ ""type"": ""string"",
+ ""description"": ""Target language code""
+ }
+ }
+ }
+ }
+ ]
+ }");
+
+ // Setup mock embedding responses
+ var embeddings = new[]
+ {
+ OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.8f, 0.5f }),
+ OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.6f, 0.4f }),
+ OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.3f, 0.2f })
+ };
+ var embeddingCollection = OpenAIEmbeddingsModelFactory.OpenAIEmbeddingCollection(
+ items: embeddings,
+ model: "text-embedding-ada-002",
+ usage: OpenAIEmbeddingsModelFactory.EmbeddingTokenUsage(30, 30));
+ var mockResponse = new MockPipelineResponse(200);
+
+ mockEmbeddingClient
+ .Setup(c => c.GenerateEmbeddingAsync(
+ It.IsAny(),
+ It.IsAny(),
+ It.IsAny()))
+ .ReturnsAsync(ClientResult.FromValue(embeddings[0], mockResponse));
+
+ mockEmbeddingClient
+ .Setup(c => c.GenerateEmbeddingsAsync(
+ It.IsAny>(),
+ It.IsAny(),
+ It.IsAny()))
+ .ReturnsAsync(ClientResult.FromValue(embeddingCollection, mockResponse));
+
+ var responsesByTool = new Dictionary
+ {
+ ["math-tool"] = "\"math result\"",
+ ["weather-tool"] = "\"weather result\"",
+ ["translate-tool"] = "\"translate result\""
+ };
+
+ var testClient = new TestMcpClient(
+ mcpEndpoint,
+ mockToolsResponse,
+ toolName => BinaryData.FromString(responsesByTool[toolName.Split('_').Last()]));
+ var tools = new ResponseTools(mockEmbeddingClient.Object);
+
+ // Add the tools
+ await tools.AddMcpToolsAsync(testClient);
+
+ // Act & Assert
+ // Test with maxTools = 1
+ var options1 = await tools.ToResponseCreationOptionsAsync("calculate 2+2", 1, 0.5f);
+ Assert.That(options1.Tools, Has.Count.EqualTo(1));
+
+ // Test with maxTools = 2
+ var options2 = await tools.ToResponseCreationOptionsAsync("calculate 2+2", 2, 0.5f);
+ Assert.That(options2.Tools, Has.Count.EqualTo(2));
+
+ // Test that tool choice affects results
+ var optionsWithToolChoice = await tools.ToResponseCreationOptionsAsync("calculate 2+2", 1, 0.5f);
+ optionsWithToolChoice.ToolChoice = ResponseToolChoice.CreateRequiredChoice();
+
+ Assert.That(optionsWithToolChoice.ToolChoice, Is.Not.Null);
+ Assert.That(optionsWithToolChoice.Tools, Has.Count.EqualTo(1));
+
+ // Verify we can still call the filtered tools
+ var toolCall = new FunctionCallResponseItem(
+ "call1",
+ "localhost1234_-_math-tool",
+ BinaryData.FromString(@"{""expression"": ""2+2""}"));
+ var result = await tools.CallAsync(toolCall);
+ Assert.That(result.CallId, Is.EqualTo("call1"));
+ Assert.That(result.FunctionOutput, Is.EqualTo("\"math result\""));
+ }
+}
\ No newline at end of file
diff --git a/tests/Utility/ToolsTestsBase.cs b/tests/Utility/ToolsTestsBase.cs
new file mode 100644
index 000000000..dc4637989
--- /dev/null
+++ b/tests/Utility/ToolsTestsBase.cs
@@ -0,0 +1,94 @@
+
+using System;
+using System.Threading.Tasks;
+using OpenAI.Agents;
+
+namespace OpenAI.Tests.Utility;
+
+public class ToolsTestsBase
+{
+ internal class TestTools
+ {
+ public static string Echo(string message) => message;
+ public static int Add(int a, int b) => a + b;
+ public static double Multiply(double x, double y) => x * y;
+ public static bool IsGreaterThan(long value1, long value2) => value1 > value2;
+ public static float Divide(float numerator, float denominator) => numerator / denominator;
+ public static string ConcatWithBool(string text, bool flag) => $"{text}:{flag}";
+ }
+
+ internal class TestToolsAsync
+ {
+ public static async Task EchoAsync(string message)
+ {
+ await Task.Delay(1); // Simulate async work
+ return message;
+ }
+
+ public static async Task AddAsync(int a, int b)
+ {
+ await Task.Delay(1); // Simulate async work
+ return a + b;
+ }
+
+ public static async Task MultiplyAsync(double x, double y)
+ {
+ await Task.Delay(1); // Simulate async work
+ return x * y;
+ }
+
+ public static async Task IsGreaterThanAsync(long value1, long value2)
+ {
+ await Task.Delay(1); // Simulate async work
+ return value1 > value2;
+ }
+
+ public static async Task DivideAsync(float numerator, float denominator)
+ {
+ await Task.Delay(1); // Simulate async work
+ return numerator / denominator;
+ }
+
+ public static async Task ConcatWithBoolAsync(string text, bool flag)
+ {
+ await Task.Delay(1); // Simulate async work
+ return $"{text}:{flag}";
+ }
+ }
+
+ internal class TestMcpClient : McpClient
+ {
+ private readonly BinaryData _toolsResponse;
+ private readonly Func _toolCallResponseFactory;
+ private bool _isStarted;
+
+ public TestMcpClient(Uri endpoint, BinaryData toolsResponse, Func toolCallResponseFactory = null)
+ : base(endpoint)
+ {
+ _toolsResponse = toolsResponse;
+ _toolCallResponseFactory = toolCallResponseFactory ?? (_ => BinaryData.FromString("\"test result\""));
+ }
+
+ public override Task StartAsync()
+ {
+ _isStarted = true;
+ return Task.CompletedTask;
+ }
+
+ public override Task ListToolsAsync()
+ {
+ if (!_isStarted)
+ throw new InvalidOperationException("Session is not initialized. Call StartAsync() first.");
+
+ return Task.FromResult(_toolsResponse);
+ }
+
+ public override Task CallToolAsync(string toolName, BinaryData parameters)
+ {
+ if (!_isStarted)
+ throw new InvalidOperationException("Session is not initialized. Call StartAsync() first.");
+
+ return Task.FromResult(_toolCallResponseFactory(toolName));
+ }
+ }
+}
\ No newline at end of file