diff --git a/src/OpenAI.csproj b/src/OpenAI.csproj index ce4029cd4..aaa57a364 100644 --- a/src/OpenAI.csproj +++ b/src/OpenAI.csproj @@ -60,6 +60,10 @@ 0024000004800000940000000602000000240000525341310004000001000100097ad52abbeaa2e1a1982747cc0106534f65cfea6707eaed696a3a63daea80de2512746801a7e47f88e7781e71af960d89ba2e25561f70b0e2dbc93319e0af1961a719ccf5a4d28709b2b57a5d29b7c09dc8d269a490ebe2651c4b6e6738c27c5fb2c02469fe9757f0a3479ac310d6588a50a28d7dd431b907fd325e18b9e8ed + + 0024000004800000940000000602000000240000525341310004000001000100b197326f2e5bfe2e2a49eb2a05bee871c55cc894325b3775159732ad816c4f304916e7f154295486f8ccabefa3c19b059d51cd19987cc2d31a3195d6203ad0948662f51cc61cc3eb535fc852dfe5159318c734b163f7d1387f1112e1ffe10f83aae7b809c4e36cf2025da5d1aed6b67e1556883d8778eeb63131c029555166de + + @@ -83,5 +87,6 @@ + diff --git a/src/Utility/ChatTools.cs b/src/Utility/ChatTools.cs new file mode 100644 index 000000000..026920705 --- /dev/null +++ b/src/Utility/ChatTools.cs @@ -0,0 +1,287 @@ +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Text.Json; +using System.Threading.Tasks; +using OpenAI.Agents; +using OpenAI.Chat; +using OpenAI.Embeddings; + +namespace OpenAI.Chat; + +/// +/// Provides functionality to manage and execute OpenAI function tools for chat completions. +/// +//[Experimental("OPENAIMCP001")] +public class ChatTools +{ + private readonly Dictionary _methods = []; + private readonly Dictionary>> _mcpMethods = []; + private readonly List _tools = []; + private readonly EmbeddingClient _client; + private readonly List _entries = []; + private readonly List _mcpClients = []; + private readonly Dictionary _mcpClientsByEndpoint = []; + + /// + /// Initializes a new instance of the ChatTools class with an optional embedding client. + /// + /// The embedding client used for tool vectorization, or null to disable vectorization. + public ChatTools(EmbeddingClient client = null) + { + _client = client; + } + + /// + /// Initializes a new instance of the ChatTools class with the specified tool types. + /// + /// Additional tool types to add. + public ChatTools(params Type[] tools) : this((EmbeddingClient)null) + { + foreach (var t in tools) + AddFunctionTool(t); + } + + /// + /// Gets the list of defined tools. + /// + public IList Tools => _tools; + + /// + /// Gets whether tools can be filtered using embeddings provided by the provided . + /// + public bool CanFilterTools => _client != null; + + /// + /// Adds local tool implementations from the provided types. + /// + /// Types containing static methods to be used as tools. + public void AddFunctionTools(params Type[] tools) + { + foreach (Type functionHolder in tools) + AddFunctionTool(functionHolder); + } + + /// + /// Adds all public static methods from the specified type as tools. + /// + /// The type containing tool methods. + internal void AddFunctionTool(Type tool) + { +#pragma warning disable IL2070 + foreach (MethodInfo function in tool.GetMethods(BindingFlags.Public | BindingFlags.Static)) + { + AddFunctionTool(function); + } +#pragma warning restore IL2070 + } + + internal void AddFunctionTool(MethodInfo function) + { + string name = function.Name; + var tool = ChatTool.CreateFunctionTool(name, ToolsUtility.GetMethodDescription(function), ToolsUtility.BuildParametersJson(function.GetParameters())); + _tools.Add(tool); + _methods[name] = function; + } + + /// + /// Adds a remote MCP server as a tool provider. + /// + /// The MCP client instance. + /// A task representing the asynchronous operation. + internal async Task AddMcpToolsAsync(McpClient client) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + _mcpClientsByEndpoint[client.Endpoint.AbsoluteUri] = client; + await client.StartAsync().ConfigureAwait(false); + BinaryData tools = await client.ListToolsAsync().ConfigureAwait(false); + await AddMcpToolsAsync(tools, client).ConfigureAwait(false); + _mcpClients.Add(client); + } + + /// + /// Adds a remote MCP server as a tool provider. + /// + /// The URI endpoint of the MCP server. + /// A task representing the asynchronous operation. + public async Task AddMcpToolsAsync(Uri mcpEndpoint) + { + var client = new McpClient(mcpEndpoint); + await AddMcpToolsAsync(client).ConfigureAwait(false); + } + + private async Task AddMcpToolsAsync(BinaryData toolDefinitions, McpClient client) + { + List toolsToVectorize = new(); + var parsedTools = ToolsUtility.ParseMcpToolDefinitions(toolDefinitions, client); + + foreach (var (name, description, inputSchema) in parsedTools) + { + var chatTool = ChatTool.CreateFunctionTool(name, description, BinaryData.FromString(inputSchema)); + _tools.Add(chatTool); + toolsToVectorize.Add(chatTool); + _mcpMethods[name] = client.CallToolAsync; + } + + if (_client != null) + { + var embeddings = await _client.GenerateEmbeddingsAsync(toolsToVectorize.Select(t => t.FunctionDescription).ToList()).ConfigureAwait(false); + foreach (var embedding in embeddings.Value) + { + var vector = embedding.ToFloats(); + var item = toolsToVectorize[embedding.Index]; + var toolDefinition = SerializeTool(item); + _entries.Add(new VectorDatabaseEntry(vector, toolDefinition)); + } + } + } + + private BinaryData SerializeTool(ChatTool tool) + { + return ToolsUtility.SerializeTool(tool.FunctionName, tool.FunctionDescription, tool.FunctionParameters); + } + + private ChatTool ParseToolDefinition(BinaryData data) + { + using var document = JsonDocument.Parse(data); + var root = document.RootElement; + + return ChatTool.CreateFunctionTool( + root.GetProperty("name").GetString()!, + root.GetProperty("description").GetString()!, + BinaryData.FromString(root.GetProperty("inputSchema").GetRawText())); + } + + /// + /// Converts the tools collection to chat completion options. + /// + /// A new ChatCompletionOptions containing all defined tools. + public ChatCompletionOptions ToChatCompletionOptions() + { + var options = new ChatCompletionOptions(); + foreach (var tool in _tools) + options.Tools.Add(tool); + return options; + } + + /// + /// Converts the tools collection to , filtered by relevance to the given prompt. + /// + /// The prompt to find relevant tools for. + /// The maximum number of tools to return. Default is 3. + /// The similarity threshold for including tools. Default is 0.29. + /// A new containing the most relevant tools. + public ChatCompletionOptions CreateCompletionOptions(string prompt, int maxTools = 5, float minVectorDistance = 0.29f) + { + if (!CanFilterTools) + return ToChatCompletionOptions(); + + var completionOptions = new ChatCompletionOptions(); + foreach (var tool in FindRelatedTools(false, prompt, maxTools, minVectorDistance).GetAwaiter().GetResult()) + completionOptions.Tools.Add(tool); + return completionOptions; + } + + /// + /// Converts the tools collection to , filtered by relevance to the given prompt. + /// + /// The prompt to find relevant tools for. + /// The maximum number of tools to return. Default is 3. + /// The similarity threshold for including tools. Default is 0.29. + /// A new containing the most relevant tools. + public async Task ToChatCompletionOptions(string prompt, int maxTools = 5, float minVectorDistance = 0.29f) + { + if (!CanFilterTools) + return ToChatCompletionOptions(); + + var completionOptions = new ChatCompletionOptions(); + foreach (var tool in await FindRelatedTools(true, prompt, maxTools, minVectorDistance).ConfigureAwait(false)) + completionOptions.Tools.Add(tool); + return completionOptions; + } + + private async Task> FindRelatedTools(bool async, string prompt, int maxTools, float minVectorDistance) + { + if (!CanFilterTools) + return _tools; + + return (await FindVectorMatches(async, prompt, maxTools, minVectorDistance).ConfigureAwait(false)) + .Select(e => ParseToolDefinition(e.Data)); + } + + private async Task> FindVectorMatches(bool async, string prompt, int maxTools, float minVectorDistance) + { + var vector = async ? + await ToolsUtility.GetEmbeddingAsync(_client, prompt).ConfigureAwait(false) : + ToolsUtility.GetEmbedding(_client, prompt); + + lock (_entries) + { + return ToolsUtility.GetClosestEntries(_entries, maxTools, minVectorDistance, vector); + } + } + + internal async Task CallFunctionToolAsync(ChatToolCall call) + { + var arguments = new List(); + if (call.FunctionArguments != null) + { + if (!_methods.TryGetValue(call.FunctionName, out MethodInfo method)) + throw new InvalidOperationException($"Tool not found: {call.FunctionName}"); + + ToolsUtility.ParseFunctionCallArgs(method, call.FunctionArguments, out arguments); + } + return await ToolsUtility.CallFunctionToolAsync(_methods, call.FunctionName, [.. arguments]); + } + + internal async Task CallMcpAsync(ChatToolCall call) + { + if (!_mcpMethods.TryGetValue(call.FunctionName, out var method)) + throw new NotImplementedException($"MCP tool {call.FunctionName} not found."); + +#if !NETSTANDARD2_0 + var actualFunctionName = call.FunctionName.Split(ToolsUtility.McpToolSeparator, 2)[1]; +#else + var index = call.FunctionName.IndexOf(ToolsUtility.McpToolSeparator); + var actualFunctionName = call.FunctionName.Substring(index + ToolsUtility.McpToolSeparator.Length); +#endif + var result = await method(actualFunctionName, call.FunctionArguments).ConfigureAwait(false); + if (result == null) + throw new InvalidOperationException($"MCP tool {call.FunctionName} returned null. Function tools should always return a value."); + return result.ToString(); + } + + /// + /// Executes all tool calls and returns their results. + /// + /// The collection of tool calls to execute. + /// A collection of tool chat messages containing the results. + public async Task> CallAsync(IEnumerable toolCalls) + { + var messages = new List(); + foreach (ChatToolCall toolCall in toolCalls) + { + bool isMcpTool = false; + if (!_methods.ContainsKey(toolCall.FunctionName)) + { + if (_mcpMethods.ContainsKey(toolCall.FunctionName)) + { + isMcpTool = true; + } + else + { + throw new InvalidOperationException("Tool not found: " + toolCall.FunctionName); + } + } + + var result = isMcpTool ? await CallMcpAsync(toolCall).ConfigureAwait(false) : await CallFunctionToolAsync(toolCall).ConfigureAwait(false); + messages.Add(new ToolChatMessage(toolCall.Id, result)); + } + + return messages; + } +} + diff --git a/src/Utility/MCP/McpClient.cs b/src/Utility/MCP/McpClient.cs new file mode 100644 index 000000000..5517af852 --- /dev/null +++ b/src/Utility/MCP/McpClient.cs @@ -0,0 +1,68 @@ +using System; +using System.ClientModel.Primitives; +using System.Threading.Tasks; + +namespace OpenAI.Agents; + +/// +/// Client for interacting with a Model Context Protocol (MCP) server. +/// +//[Experimental("OPENAIMCP001")] +internal class McpClient +{ + private readonly McpSession _session; + private readonly ClientPipeline _pipeline; + + /// + /// Gets the endpoint URI of the MCP server. + /// + public virtual Uri Endpoint { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The URI endpoint of the MCP server. + public McpClient(Uri endpoint) + { + _pipeline = ClientPipeline.Create(); + _session = new McpSession(endpoint, _pipeline); + Endpoint = endpoint; + } + + /// + /// Starts the MCP client session by initializing the connection to the server. + /// + /// A task that represents the asynchronous operation. + public virtual async Task StartAsync() + { + await _session.EnsureInitializedAsync().ConfigureAwait(false); + } + + /// + /// Lists all available tools from the MCP server. + /// + /// A task that represents the asynchronous operation. The task result contains the binary data representing the tools list. + /// Thrown when the session is not initialized. + public virtual async Task ListToolsAsync() + { + if (_session == null) + throw new InvalidOperationException("Session is not initialized. Call StartAsync() first."); + + return await _session.SendMethod("tools/list").ConfigureAwait(false); + } + + /// + /// Calls a specific tool on the MCP server. + /// + /// The name of the tool to call. + /// The parameters to pass to the tool as binary data. + /// A task that represents the asynchronous operation. The task result contains the binary data representing the tool's response. + /// Thrown when the session is not initialized. + public virtual async Task CallToolAsync(string toolName, BinaryData parameters) + { + if (_session == null) + throw new InvalidOperationException("Session is not initialized. Call StartAsync() first."); + + return await _session.CallTool(toolName, parameters).ConfigureAwait(false); + } +} diff --git a/src/Utility/MCP/McpSession.cs b/src/Utility/MCP/McpSession.cs new file mode 100644 index 000000000..036673540 --- /dev/null +++ b/src/Utility/MCP/McpSession.cs @@ -0,0 +1,517 @@ +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Concurrent; +using System.IO; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +namespace OpenAI; + +internal class McpSession : IDisposable +{ + private readonly Uri _serverEndpoint; + private readonly ClientPipeline _pipeline; + private readonly MessageRouter _messageRouter = new MessageRouter(); + private CancellationTokenSource _cancellationSource; + private PipelineMessage _activeHandshake; + private string _messageEndpoint = string.Empty; + private bool _isInitialized; + private int _nextId = 0; + private readonly SemaphoreSlim _initializationLock = new(1, 1); + private TaskCompletionSource _endpointTcs; + + public McpSession(Uri serverEndpoint, ClientPipeline pipeline) + { + _serverEndpoint = serverEndpoint; + _pipeline = pipeline; + _cancellationSource = new CancellationTokenSource(); + } + + public async Task EnsureInitializedAsync() + { + DebugPrint("Ensuring session is initialized..."); + if (_isInitialized && _activeHandshake != null) + { + DebugPrint("Session is already initialized."); + return; + } + + await _initializationLock.WaitAsync().ConfigureAwait(false); + try + { + if (_isInitialized && _activeHandshake != null) + { + DebugPrint("Session is already initialized."); + return; + } + + await InitializeSessionAsync().ConfigureAwait(false); + } + finally + { + _initializationLock.Release(); + } + } + + public void Stop() + { + _cancellationSource.Cancel(); + DebugPrint("Stopping session..."); + CleanupCurrentSession(); + _isInitialized = false; + _cancellationSource = new CancellationTokenSource(); + DebugPrint("Session stopped."); + } + + private async Task InitializeSessionAsync() + { + DebugPrint("Initializing session..."); + CleanupCurrentSession(); + + _activeHandshake = CreateHandshakeMessage(); + try + { + _pipeline.Send(_activeHandshake); + var response = _activeHandshake.Response; + + if (response?.IsError == true) + { + throw new InvalidOperationException($"Failed to initialize SSE connection: {response.Status}"); + } + + StartSseProcessing(response!.ContentStream!); + + // Get the message endpoint from the server + _messageEndpoint = await GetMessageEndpointAsync().ConfigureAwait(false); + DebugPrint($"Message endpoint: {_messageEndpoint}"); + await SendInitializeAsync().ConfigureAwait(false); + _isInitialized = true; + } + catch + { + CleanupCurrentSession(); + throw; + } + } + + private void StartSseProcessing(Stream responseStream) + { + var streamReader = new StreamReader(responseStream); + _ = Task.Run(async () => + { + try + { + await ProcessSseStreamAsync(streamReader).ConfigureAwait(false); + } + catch (Exception ex) + { + DebugPrint($"SSE processing failed: {ex.Message}"); + _isInitialized = false; + } + }, _cancellationSource.Token); + } + + private async Task ProcessSseStreamAsync(StreamReader streamReader) + { + string eventName = string.Empty; + var dataBuilder = new StringBuilder(); + + try + { + while (!_cancellationSource.Token.IsCancellationRequested) + { + DebugPrint("Reading line from SSE stream..."); + string line = await streamReader.ReadLineAsync().ConfigureAwait(false); + if (line == null) + { + throw new IOException("SSE stream closed unexpectedly"); + } + + DebugPrint($"Received line: '{line}'"); + + if (line.StartsWith("event:", StringComparison.OrdinalIgnoreCase)) + { + eventName = line.AsSpan(6).Trim().ToString(); + } + else if (line.StartsWith("data:")) + { +#if !NETSTANDARD2_0 + dataBuilder.Append(line.AsSpan(5)); +#else + dataBuilder.Append(line.AsSpan(5).ToString()); +#endif + } + else if (string.IsNullOrEmpty(line) && dataBuilder.Length > 0) + { + var sseEvent = new SseEvent(eventName, dataBuilder.ToString().TrimEnd()); + await ProcessEventAsync(sseEvent).ConfigureAwait(false); + eventName = string.Empty; + dataBuilder.Clear(); + } + } + DebugPrint("SSE stream processing stopped."); + } + catch (Exception ex) + { + DebugPrint($"Error processing SSE stream: {ex.Message}"); + } + finally + { + CleanupCurrentSession(); + } + } + + private async Task ProcessEventAsync(SseEvent sseEvent) + { + switch (sseEvent.Event) + { + case "endpoint": + if (_endpointTcs != null) + { + var serverUri = _serverEndpoint; + var endpoint = serverUri.GetLeftPart(UriPartial.Authority); + var path = sseEvent.Data.Trim(); + string trailingSlash = endpoint.EndsWith("/") || path.StartsWith("/") ? "" : "/"; + var messageEndpoint = $"{endpoint}{trailingSlash}{path}"; + _endpointTcs.TrySetResult(messageEndpoint); + } + break; + + case "message": + case "": // Handle empty event name as a message + DebugPrint($"Received message: {sseEvent.Data}"); + await _messageRouter.RouteMessageAsync(sseEvent).ConfigureAwait(false); + break; + + default: + DebugPrint($"Unknown event: {sseEvent.Event}"); + break; + } + } + + private async Task GetMessageEndpointAsync() + { + if (!string.IsNullOrEmpty(_messageEndpoint)) + { + return _messageEndpoint; + } + + _endpointTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + using var registration = cts.Token.Register(() => + _endpointTcs.TrySetException(new TimeoutException("Timeout waiting for endpoint event"))); + + try + { + return await _endpointTcs.Task.ConfigureAwait(false); + } + finally + { + _endpointTcs = null; + } + } + + private async Task SendInitializeAsync() + { + var id = GetNextId(); + var initializeTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + string json = $$""" + { + "jsonrpc": "2.0", + "id": {{id}}, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": { + "roots": { + "listChanged": true + }, + "sampling": {} + }, + "clientInfo": { + "name": "ExampleClient", + "version": "1.0.0" + } + } + } + """; + + // Register handler for initialize response + _messageRouter.RegisterHandler(id, async (sseEvent) => + { + try + { + // Send initialized notification after receiving initialize response + string initialized = $$""" + { + "jsonrpc": "2.0", + "method": "notifications/initialized" + } + """; + + await SendAsync(initialized).ConfigureAwait(false); + initializeTcs.SetResult(Task.CompletedTask); + } + catch (Exception ex) + { + initializeTcs.SetException(ex); + } + }); + + await SendAsync(json).ConfigureAwait(false); + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + using var registration = cts.Token.Register(() => initializeTcs.TrySetCanceled()); + + await initializeTcs.Task.ConfigureAwait(false); + } + + private PipelineMessage CreateHandshakeMessage() + { + var message = _pipeline.CreateMessage(); + var request = message.Request; + request.Uri = _serverEndpoint; + request.Method = "GET"; + request.Headers.Add("Accept", "text/event-stream"); + message.BufferResponse = false; + return message; + } + + private int GetNextId() + { + return Interlocked.Increment(ref _nextId); + } + + public async Task SendMethod(string methodName) + { + await EnsureInitializedAsync().ConfigureAwait(false); + var id = GetNextId(); + var completionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + string json = $$""" + { + "jsonrpc" : "2.0", + "id" : {{id}}, + "method" : "{{methodName}}" + } + """; + + _messageRouter.RegisterHandler(id, (sseEvent) => + { + try + { + using JsonDocument doc = JsonDocument.Parse(sseEvent.Data); + var messageResult = doc.RootElement + .GetProperty("result"); + + var resultJson = messageResult.GetRawText(); + completionSource.SetResult(BinaryData.FromString(resultJson)); + } + catch (Exception ex) + { + completionSource.SetException(ex); + } + return Task.CompletedTask; + }); + + await SendAsync(json).ConfigureAwait(false); + return await completionSource.Task.ConfigureAwait(false); + } + + public async Task CallTool(string toolName, BinaryData arguments) + { + await EnsureInitializedAsync().ConfigureAwait(false); + var id = GetNextId(); + var completionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + string json = $$""" + { + "jsonrpc": "2.0", + "id": {{id}}, + "method": "tools/call", + "params": { + "name": "{{toolName}}", + "arguments": {{arguments}} + } + } + """; + + _messageRouter.RegisterHandler(id, (sseEvent) => + { + try + { + using JsonDocument doc = JsonDocument.Parse(sseEvent.Data); + var messageResult = doc.RootElement + .GetProperty("result"); + + var resultJson = messageResult.GetRawText(); + completionSource.SetResult(BinaryData.FromString(resultJson)); + } + catch (Exception ex) + { + completionSource.SetException(ex); + } + return Task.CompletedTask; + }); + + await SendAsync(json).ConfigureAwait(false); + return await completionSource.Task.ConfigureAwait(false); + } + + public async Task SendAsync(string json) + { + DebugPrint($"sending:\n {json}\n"); + using PipelineMessage message = _pipeline.CreateMessage(); + using PipelineRequest request = message.Request; + request.Uri = new Uri(_messageEndpoint); + request.Method = "POST"; + request.Headers.Add("Content-Type", "application/json"); + request.Headers.Add("Accept", "text/event-stream"); + message.BufferResponse = false; + + var jsonBytes = BinaryData.FromString(json); + request.Content = BinaryContent.Create(jsonBytes); + await _pipeline.SendAsync(message).ConfigureAwait(false); + var response = message.Response; + } + + private void CleanupCurrentSession() + { + if (_activeHandshake != null) + { + _activeHandshake.Dispose(); + _activeHandshake = null; + } + _isInitialized = false; + } + + public void Dispose() + { + _cancellationSource.Cancel(); + CleanupCurrentSession(); + _initializationLock.Dispose(); + _cancellationSource.Dispose(); + } + + internal struct SseEvent + { + public string Event { get; set; } + public string Data { get; set; } + + public SseEvent(string eventName, string data) + { + Event = eventName; + Data = data; + } + + public override string ToString() + { + return $"====== SSE Event =====\nEvent: {Event}, Data: {Data}\n======================\n"; + } + } + + private class MessageRouter + { + // Dictionary to store message handlers and completion actions by ID + private readonly ConcurrentDictionary> _handlers = new(); + + // Register a handler with optional completion action + public void RegisterHandler(int id, Func handler) + { + if (!_handlers.TryAdd(id, handler)) + { + throw new InvalidOperationException($"Handler for message ID {id} already registered"); + } + } + + // Route a message to its registered handler + public async Task RouteMessageAsync(SseEvent sseEvent) + { + if (string.IsNullOrEmpty(sseEvent.Data)) + { + return; + } + + // Extract ID from the message payload + int id = ExtractId(sseEvent.Data); + if (id == -1) + { + return; + } + + // Try to get and remove the handler + if (_handlers.TryRemove(id, out var handler)) + { + try + { + // Execute handler + await handler(sseEvent).ConfigureAwait(false); + } + catch (Exception ex) + { + DebugPrint($"Error in message handler for ID {id}: {ex.Message}"); + } + } + } + + private static int ExtractId(string jsonPayload) + { + if (string.IsNullOrEmpty(jsonPayload)) + { + throw new ArgumentException("JSON payload cannot be null or empty", nameof(jsonPayload)); + } + + byte[] bytes = Encoding.UTF8.GetBytes(jsonPayload); + var reader = new Utf8JsonReader(bytes); + + try + { + // Look for the top-level "id" property + while (reader.Read()) + { + // We only care about property names at the top level + if (reader.TokenType == JsonTokenType.PropertyName && + reader.GetString() == "id" && + reader.CurrentDepth == 1) + { + // Move to the value + reader.Read(); + + // Handle number or string format + if (reader.TokenType == JsonTokenType.Number) + { + if (reader.TryGetInt32(out int id)) + { + return id; + } + } + + // Found "id" but it's not a valid integer + return -1; + } + } + + // No "id" property found + return -1; + } + catch (JsonException ex) + { + DebugPrint($"Error parsing JSON: {ex.Message}"); + return -1; + } + } + } + + private static void DebugPrint(string message) + { +#if DEBUGPRINT + var color = Console.ForegroundColor; + Console.ForegroundColor = ConsoleColor.DarkGray; + Console.WriteLine(message); + Console.ForegroundColor = color; +#endif + } +} diff --git a/src/Utility/ResponseTools.cs b/src/Utility/ResponseTools.cs new file mode 100644 index 000000000..281dc8011 --- /dev/null +++ b/src/Utility/ResponseTools.cs @@ -0,0 +1,282 @@ +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Text.Json; +using System.Threading.Tasks; +using OpenAI.Agents; +using OpenAI.Embeddings; +using OpenAI.Responses; + +namespace OpenAI.Responses; + +/// +/// Provides functionality to manage and execute OpenAI function tools for responses. +/// +//[Experimental("OPENAIMCP001") +public class ResponseTools +{ + private readonly Dictionary _methods = []; + private readonly Dictionary>> _mcpMethods = []; + private readonly List _tools = []; + private readonly EmbeddingClient _client; + private readonly List _entries = []; + private readonly List _mcpClients = []; + private readonly Dictionary _mcpClientsByEndpoint = []; + + /// + /// Initializes a new instance of the ResponseTools class with an optional embedding client. + /// + /// The embedding client used for tool vectorization, or null to disable vectorization. + public ResponseTools(EmbeddingClient client = null) + { + _client = client; + } + + /// + /// Initializes a new instance of the ResponseTools class with the specified tool types. + /// + /// Additional tool types to add. + public ResponseTools(params Type[] tools) : this((EmbeddingClient)null) + { + foreach (var t in tools) + AddFunctionTool(t); + } + + /// + /// Gets the list of defined tools. + /// + public IList Tools => _tools; + + /// + /// Gets whether tools can be filtered using embeddings provided by the provided . + /// + public bool CanFilterTools => _client != null; + + /// + /// Adds local tool implementations from the provided types. + /// + /// Types containing static methods to be used as tools. + public void AddFunctionTools(params Type[] tools) + { + foreach (Type functionHolder in tools) + AddFunctionTool(functionHolder); + } + + /// + /// Adds all public static methods from the specified type as tools. + /// + /// The type containing tool methods. + internal void AddFunctionTool(Type tool) + { +#pragma warning disable IL2070 + foreach (MethodInfo function in tool.GetMethods(BindingFlags.Public | BindingFlags.Static)) + { + AddFunctionTool(function); + } +#pragma warning restore IL2070 + } + + internal void AddFunctionTool(MethodInfo function) + { + string name = function.Name; + var tool = ResponseTool.CreateFunctionTool(name, ToolsUtility.GetMethodDescription(function), ToolsUtility.BuildParametersJson(function.GetParameters())); + _tools.Add(tool); + _methods[name] = function; + } + + /// + /// Adds a remote MCP server as a tool provider. + /// + /// The MCP client instance. + /// A task representing the asynchronous operation. + internal async Task AddMcpToolsAsync(McpClient client) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + _mcpClientsByEndpoint[client.Endpoint.AbsoluteUri] = client; + await client.StartAsync().ConfigureAwait(false); + BinaryData tools = await client.ListToolsAsync().ConfigureAwait(false); + await AddMcpToolsAsync(tools, client).ConfigureAwait(false); + _mcpClients.Add(client); + } + + /// + /// Adds a remote MCP server as a tool provider. + /// + /// The URI endpoint of the MCP server. + /// A task representing the asynchronous operation. + public async Task AddMcpToolsAsync(Uri mcpEndpoint) + { + var client = new McpClient(mcpEndpoint); + await AddMcpToolsAsync(client).ConfigureAwait(false); + } + + private async Task AddMcpToolsAsync(BinaryData toolDefinitions, McpClient client) + { + List toolsToVectorize = new(); + var parsedTools = ToolsUtility.ParseMcpToolDefinitions(toolDefinitions, client); + + foreach (var (name, description, inputSchema) in parsedTools) + { + var responseTool = ResponseTool.CreateFunctionTool(name, description, BinaryData.FromString(inputSchema)); + _tools.Add(responseTool); + toolsToVectorize.Add(responseTool); + _mcpMethods[name] = client.CallToolAsync; + } + + if (_client != null) + { + var embeddings = await _client.GenerateEmbeddingsAsync(toolsToVectorize.ConvertAll(GetDescription)).ConfigureAwait(false); + foreach (var embedding in embeddings.Value) + { + var vector = embedding.ToFloats(); + var item = toolsToVectorize[embedding.Index]; + var toolDefinition = SerializeTool(item); + _entries.Add(new VectorDatabaseEntry(vector, toolDefinition)); + } + } + } + + private string GetDescription(ResponseTool tool) => (tool as InternalResponsesFunctionTool)?.Description ?? ""; + + private BinaryData SerializeTool(ResponseTool tool) + { + var functionTool = tool as InternalResponsesFunctionTool; + return ToolsUtility.SerializeTool(functionTool?.Name, functionTool?.Description, functionTool?.Parameters ?? BinaryData.FromString("{}")); + } + + private ResponseTool ParseToolDefinition(BinaryData data) + { + using var document = JsonDocument.Parse(data); + var root = document.RootElement; + + return ResponseTool.CreateFunctionTool( + root.GetProperty("name").GetString()!, + root.GetProperty("description").GetString()!, + BinaryData.FromString(root.GetProperty("inputSchema").GetRawText())); + } + + /// + /// Converts the tools collection to configured with the tools contained in this instance.. + /// + /// A new ResponseCreationOptions containing all defined tools. + public ResponseCreationOptions ToResponseCreationOptions() + { + var options = new ResponseCreationOptions(); + foreach (var tool in _tools) + options.Tools.Add(tool); + return options; + } + + /// + /// Converts the tools collection to , filtered by relevance to the given prompt. + /// + /// The prompt to find relevant tools for. + /// The maximum number of tools to return. Default is 5. + /// The similarity threshold for including tools. Default is 0.29. + /// A new ResponseCreationOptions containing the most relevant tools. + public ResponseCreationOptions ToResponseCreationOptions(string prompt, int maxTools = 5, float minVectorDistance = 0.29f) + { + if (!CanFilterTools) + return ToResponseCreationOptions(); + + var completionOptions = new ResponseCreationOptions(); + foreach (var tool in FindRelatedTools(false, prompt, maxTools, minVectorDistance).GetAwaiter().GetResult()) + completionOptions.Tools.Add(tool); + return completionOptions; + } + + /// + /// Converts the tools collection to , filtered by relevance to the given prompt. + /// + /// The prompt to find relevant tools for. + /// The maximum number of tools to return. Default is 5. + /// The similarity threshold for including tools. Default is 0.29. + /// A new ResponseCreationOptions containing the most relevant tools. + public async Task ToResponseCreationOptionsAsync(string prompt, int maxTools = 5, float minVectorDistance = 0.29f) + { + if (!CanFilterTools) + return ToResponseCreationOptions(); + + var completionOptions = new ResponseCreationOptions(); + foreach (var tool in await FindRelatedTools(true, prompt, maxTools, minVectorDistance).ConfigureAwait(false)) + completionOptions.Tools.Add(tool); + return completionOptions; + } + + private async Task> FindRelatedTools(bool async, string prompt, int maxTools, float minVectorDistance) + { + if (!CanFilterTools) + return _tools; + + return (await FindVectorMatches(async, prompt, maxTools, minVectorDistance).ConfigureAwait(false)) + .Select(e => ParseToolDefinition(e.Data)); + } + + private async Task> FindVectorMatches(bool async, string prompt, int maxTools, float minVectorDistance) + { + var vector = async ? + await ToolsUtility.GetEmbeddingAsync(_client, prompt).ConfigureAwait(false) : + ToolsUtility.GetEmbedding(_client, prompt); + lock (_entries) + { + return ToolsUtility.GetClosestEntries(_entries, maxTools, minVectorDistance, vector); + } + } + + internal async Task CallFunctionToolAsync(FunctionCallResponseItem call) + { + List arguments = new(); + if (call.FunctionArguments != null) + { + if (!_methods.TryGetValue(call.FunctionName, out MethodInfo method)) + return $"I don't have a tool called {call.FunctionName}"; + + ToolsUtility.ParseFunctionCallArgs(method, call.FunctionArguments, out arguments); + } + + return await ToolsUtility.CallFunctionToolAsync(_methods, call.FunctionName, [.. arguments]); + } + + internal async Task CallMcpAsync(FunctionCallResponseItem call) + { + if (!_mcpMethods.TryGetValue(call.FunctionName, out var method)) + throw new NotImplementedException($"MCP tool {call.FunctionName} not found."); + +#if !NETSTANDARD2_0 + var actualFunctionName = call.FunctionName.Split(ToolsUtility.McpToolSeparator, 2)[1]; +#else + var index = call.FunctionName.IndexOf(ToolsUtility.McpToolSeparator); + var actualFunctionName = call.FunctionName.Substring(index + ToolsUtility.McpToolSeparator.Length); +#endif + var result = await method(actualFunctionName, call.FunctionArguments).ConfigureAwait(false); + return result.ToString(); + } + + /// + /// Executes a function call and returns its result as a FunctionCallOutputResponseItem. + /// + /// The function call to execute. + /// A task that represents the asynchronous operation and contains the function call result. + public async Task CallAsync(FunctionCallResponseItem toolCall) + { + bool isMcpTool = false; + if (!_methods.ContainsKey(toolCall.FunctionName)) + { + if (_mcpMethods.ContainsKey(toolCall.FunctionName)) + { + isMcpTool = true; + } + else + { + return new FunctionCallOutputResponseItem(toolCall.CallId, $"I don't have a tool called {toolCall.FunctionName}"); + } + } + + var result = isMcpTool ? await CallMcpAsync(toolCall).ConfigureAwait(false) : await CallFunctionToolAsync(toolCall); + return new FunctionCallOutputResponseItem(toolCall.CallId, result); + } +} + diff --git a/src/Utility/ToolsUtility.cs b/src/Utility/ToolsUtility.cs new file mode 100644 index 000000000..f9c8f0314 --- /dev/null +++ b/src/Utility/ToolsUtility.cs @@ -0,0 +1,243 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Text.Json; +using System.Threading.Tasks; +using OpenAI.Agents; +using OpenAI.Embeddings; + +namespace OpenAI; + +internal static class ToolsUtility +{ + internal static readonly BinaryData EmptyParameters = BinaryData.FromString("""{ "type" : "object", "properties" : {} }"""); + internal const string McpToolSeparator = "_-_"; + + internal static string GetMethodDescription(MethodInfo method) + { + var description = method.Name; + var attr = method.GetCustomAttribute(); + if (attr != null) + description = attr.Description; + return description; + } + + internal static string GetParameterDescription(ParameterInfo param) + { + string description = param.Name!; + var attr = param.GetCustomAttribute(); + if (attr != null) + description = attr.Description; + return description; + } + + internal static ReadOnlySpan ClrToJsonTypeUtf8(Type clrType) => + clrType switch + { + Type t when t == typeof(double) => "number"u8, + Type t when t == typeof(int) => "number"u8, + Type t when t == typeof(long) => "number"u8, + Type t when t == typeof(float) => "number"u8, + Type t when t == typeof(string) => "string"u8, + Type t when t == typeof(bool) => "bool"u8, + _ => throw new NotImplementedException() + }; + + internal static BinaryData BuildParametersJson(ParameterInfo[] parameters) + { + if (parameters.Length == 0) + return EmptyParameters; + + var required = new List(); + using var stream = new MemoryStream(); + using var writer = new Utf8JsonWriter(stream); + writer.WriteStartObject(); + writer.WriteString("type"u8, "object"u8); + writer.WriteStartObject("properties"u8); + + foreach (ParameterInfo parameter in parameters) + { + writer.WriteStartObject(parameter.Name!); + writer.WriteString("type"u8, ClrToJsonTypeUtf8(parameter.ParameterType)); + writer.WriteString("description"u8, GetParameterDescription(parameter)); + writer.WriteEndObject(); + + if (!parameter.IsOptional || (parameter.HasDefaultValue && parameter.DefaultValue is not null)) + required.Add(parameter.Name!); + } + + writer.WriteEndObject(); // properties + writer.WriteStartArray("required"); + foreach (string param in required) + writer.WriteStringValue(param); + writer.WriteEndArray(); + writer.WriteEndObject(); + writer.Flush(); + stream.Position = 0; + return BinaryData.FromStream(stream); + } + + internal static async Task> GetEmbeddingAsync(EmbeddingClient client, string text) + { + var result = await client.GenerateEmbeddingAsync(text).ConfigureAwait(false); + return result.Value.ToFloats(); + } + + internal static ReadOnlyMemory GetEmbedding(EmbeddingClient client, string text) + { + var result = client.GenerateEmbedding(text); + return result.Value.ToFloats(); + } + + internal static float CosineSimilarity(ReadOnlySpan x, ReadOnlySpan y) + { + float dot = 0, xSumSquared = 0, ySumSquared = 0; + for (int i = 0; i < x.Length; i++) + { + dot += x[i] * y[i]; + xSumSquared += x[i] * x[i]; + ySumSquared += y[i] * y[i]; + } +#if NETSTANDARD2_0 + return dot / (float)(Math.Sqrt(xSumSquared) * (float)Math.Sqrt(ySumSquared)); +#else + return dot / (MathF.Sqrt(xSumSquared) * MathF.Sqrt(ySumSquared)); +#endif + } + + internal static IEnumerable<(string name, string description, string inputSchema)> ParseMcpToolDefinitions(BinaryData toolDefinitions, McpClient client) + { + using var document = JsonDocument.Parse(toolDefinitions); + if (!document.RootElement.TryGetProperty("tools", out JsonElement toolsElement)) + throw new JsonException("The JSON document must contain a 'tools' array."); + + var serverKey = client.Endpoint.Host + client.Endpoint.Port.ToString(); + var result = new List<(string name, string description, string inputSchema)>(); + + foreach (var tool in toolsElement.EnumerateArray()) + { + var name = $"{serverKey}{McpToolSeparator}{tool.GetProperty("name").GetString()!}"; + var description = tool.GetProperty("description").GetString()!; + var inputSchemaElement = tool.GetProperty("inputSchema"); + string inputSchema; + using (var stream = new MemoryStream()) + { + using (var writer = new Utf8JsonWriter(stream)) + { + inputSchemaElement.WriteTo(writer); + } + inputSchema = System.Text.Encoding.UTF8.GetString(stream.ToArray()); + } + + result.Add((name, description, inputSchema)); + } + + return result; + } + + internal static void ParseFunctionCallArgs(MethodInfo method, BinaryData functionCallArguments, out List arguments) + { + arguments = new List(); + using var document = JsonDocument.Parse(functionCallArguments); + var parameters = method.GetParameters(); + var argumentsByName = document.RootElement.EnumerateObject().ToDictionary(p => p.Name, p => p.Value); + + foreach (var param in parameters) + { + if (!argumentsByName.TryGetValue(param.Name!, out var value)) + { + if (param.HasDefaultValue) + { + arguments.Add(param.DefaultValue!); + continue; + } + throw new JsonException($"Required parameter '{param.Name}' not found in function call arguments."); + } + + arguments.Add(value.ValueKind switch + { + JsonValueKind.String => value.GetString()!, + JsonValueKind.Number when param.ParameterType == typeof(int) => value.GetInt32(), + JsonValueKind.Number when param.ParameterType == typeof(long) => value.GetInt64(), + JsonValueKind.Number when param.ParameterType == typeof(double) => value.GetDouble(), + JsonValueKind.Number when param.ParameterType == typeof(float) => value.GetSingle(), + JsonValueKind.True => true, + JsonValueKind.False => false, + JsonValueKind.Null when param.HasDefaultValue => param.DefaultValue!, + _ => throw new NotImplementedException($"Conversion from {value.ValueKind} to {param.ParameterType.Name} is not implemented.") + }); + } + } + + internal static IEnumerable GetClosestEntries(List entries, int maxTools, float minVectorDistance, ReadOnlyMemory vector) + { + var distances = entries + .Select((e, i) => (Distance: 1f - ToolsUtility.CosineSimilarity(e.Vector.Span, vector.Span), Index: i)) + .OrderBy(t => t.Distance) + .Take(maxTools) + .Where(t => t.Distance <= minVectorDistance); + + return distances.Select(d => entries[d.Index]); + } + + internal static BinaryData SerializeTool(string name, string description, BinaryData parameters) + { + using var stream = new MemoryStream(); + using var writer = new Utf8JsonWriter(stream, new JsonWriterOptions { Indented = true }); + + writer.WriteStartObject(); + writer.WriteString("name", name); + writer.WriteString("description", description); + writer.WritePropertyName("inputSchema"); + using (var doc = JsonDocument.Parse(parameters)) + doc.RootElement.WriteTo(writer); + writer.WriteEndObject(); + writer.Flush(); + + stream.Position = 0; + return BinaryData.FromStream(stream); + } + + internal static async Task CallFunctionToolAsync(Dictionary methods, string name, object[] arguments) + { + if (!methods.TryGetValue(name, out MethodInfo method)) + throw new InvalidOperationException($"Tool not found: {name}"); + + object result; + if (IsGenericTask(method.ReturnType, out Type taskResultType)) + { + // Method is async, invoke and await + var task = (Task)method.Invoke(null, arguments); + await task.ConfigureAwait(false); + // Get the Result property from the Task + result = taskResultType.GetProperty("Result").GetValue(task); + } + else + { + // Method is synchronous + result = method.Invoke(null, arguments); + } + + return result?.ToString() ?? string.Empty; + } + + private static bool IsGenericTask(Type type, out Type taskResultType) + { + while (type != null && type != typeof(object)) + { + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Task<>)) + { + taskResultType = type;//type.GetGenericArguments()[0]; + return true; + } + + type = type.BaseType!; + } + + taskResultType = null; + return false; + } +} \ No newline at end of file diff --git a/src/Utility/VectorDatabaseEntry.cs b/src/Utility/VectorDatabaseEntry.cs new file mode 100644 index 000000000..fca987972 --- /dev/null +++ b/src/Utility/VectorDatabaseEntry.cs @@ -0,0 +1,37 @@ +using System; + +namespace OpenAI; + +/// +/// A vector database entry. +/// +internal readonly struct VectorDatabaseEntry +{ + /// + /// Initializes a new instance of the class. + /// + /// + /// + /// + public VectorDatabaseEntry(ReadOnlyMemory vector, BinaryData data, int? id = default) + { + Vector = vector; + Data = data; + Id = id; + } + + /// + /// Gets the data associated with the entry. + /// + public BinaryData Data { get; } + + /// + /// Gets the vector associated with the entry. + /// + public ReadOnlyMemory Vector { get; } + + /// + /// Gets the id associated with the entry. + /// + public int? Id { get; } +} diff --git a/tests/OpenAI.Tests.csproj b/tests/OpenAI.Tests.csproj index 797a810ea..8776cc207 100644 --- a/tests/OpenAI.Tests.csproj +++ b/tests/OpenAI.Tests.csproj @@ -10,6 +10,13 @@ latest + + + + true + ..\src\OpenAI.snk + + @@ -22,8 +29,4 @@ - - - - \ No newline at end of file diff --git a/tests/Utility/ChatToolsTests.cs b/tests/Utility/ChatToolsTests.cs new file mode 100644 index 000000000..4fb2508cd --- /dev/null +++ b/tests/Utility/ChatToolsTests.cs @@ -0,0 +1,372 @@ +using System; +using System.ClientModel; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Moq; +using NUnit.Framework; +using OpenAI.Agents; +using OpenAI.Chat; +using OpenAI.Embeddings; + +namespace OpenAI.Tests.Utility; + +[TestFixture] +[Category("Utility")] +public class ChatToolsTests : ToolsTestsBase +{ + private Mock mockEmbeddingClient; + + [SetUp] + public void Setup() + { + mockEmbeddingClient = new Mock("text-embedding-ada-002", new ApiKeyCredential("test-key")); + } + + [Test] + public void CanAddLocalTools() + { + var tools = new ChatTools(); + tools.AddFunctionTools(typeof(TestTools)); + + Assert.That(tools.Tools, Has.Count.EqualTo(6)); + Assert.That(tools.Tools.Any(t => t.FunctionName == "Echo")); + Assert.That(tools.Tools.Any(t => t.FunctionName == "Add")); + Assert.That(tools.Tools.Any(t => t.FunctionName == "Multiply")); + Assert.That(tools.Tools.Any(t => t.FunctionName == "IsGreaterThan")); + Assert.That(tools.Tools.Any(t => t.FunctionName == "Divide")); + Assert.That(tools.Tools.Any(t => t.FunctionName == "ConcatWithBool")); + } + + [Test] + public void CanAddAsyncLocalTools() + { + var tools = new ChatTools(); + tools.AddFunctionTools(typeof(TestToolsAsync)); + + Assert.That(tools.Tools, Has.Count.EqualTo(6)); + Assert.That(tools.Tools.Any(t => t.FunctionName == "EchoAsync")); + Assert.That(tools.Tools.Any(t => t.FunctionName == "AddAsync")); + Assert.That(tools.Tools.Any(t => t.FunctionName == "MultiplyAsync")); + Assert.That(tools.Tools.Any(t => t.FunctionName == "IsGreaterThanAsync")); + Assert.That(tools.Tools.Any(t => t.FunctionName == "DivideAsync")); + Assert.That(tools.Tools.Any(t => t.FunctionName == "ConcatWithBoolAsync")); + } + + [Test] + public async Task CanCallToolsAsync() + { + var tools = new ChatTools(); + tools.AddFunctionTools(typeof(TestTools)); + + var toolCalls = new[] + { + ChatToolCall.CreateFunctionToolCall("call1", "Echo", BinaryData.FromString(@"{""message"": ""Hello""}")), + ChatToolCall.CreateFunctionToolCall("call2", "Add", BinaryData.FromString(@"{""a"": 2, ""b"": 3}")), + ChatToolCall.CreateFunctionToolCall("call3", "Multiply", BinaryData.FromString(@"{""x"": 2.5, ""y"": 3.0}")), + ChatToolCall.CreateFunctionToolCall("call4", "IsGreaterThan", BinaryData.FromString(@"{""value1"": 100, ""value2"": 50}")), + ChatToolCall.CreateFunctionToolCall("call5", "Divide", BinaryData.FromString(@"{""numerator"": 10.0, ""denominator"": 2.0}")), + ChatToolCall.CreateFunctionToolCall("call6", "ConcatWithBool", BinaryData.FromString(@"{""text"": ""Test"", ""flag"": true}")) + }; + + var results = await tools.CallAsync(toolCalls); + var resultsList = results.ToList(); + + Assert.That(resultsList, Has.Count.EqualTo(6)); + Assert.That(resultsList[0].ToolCallId, Is.EqualTo("call1")); + Assert.That(resultsList[0].Content[0].Text, Is.EqualTo("Hello")); + Assert.That(resultsList[1].ToolCallId, Is.EqualTo("call2")); + Assert.That(resultsList[1].Content[0].Text, Is.EqualTo("5")); + Assert.That(resultsList[2].ToolCallId, Is.EqualTo("call3")); + Assert.That(resultsList[2].Content[0].Text, Is.EqualTo("7.5")); + Assert.That(resultsList[3].ToolCallId, Is.EqualTo("call4")); + Assert.That(resultsList[3].Content[0].Text, Is.EqualTo("True")); + Assert.That(resultsList[4].ToolCallId, Is.EqualTo("call5")); + Assert.That(resultsList[4].Content[0].Text, Is.EqualTo("5")); + Assert.That(resultsList[5].ToolCallId, Is.EqualTo("call6")); + Assert.That(resultsList[5].Content[0].Text, Is.EqualTo("Test:True")); + } + + [Test] + public async Task CanCallAsyncToolsAsync() + { + var tools = new ChatTools(); + tools.AddFunctionTools(typeof(TestToolsAsync)); + + var toolCalls = new[] + { + ChatToolCall.CreateFunctionToolCall("call1", "EchoAsync", BinaryData.FromString(@"{""message"": ""Hello""}")), + ChatToolCall.CreateFunctionToolCall("call2", "AddAsync", BinaryData.FromString(@"{""a"": 2, ""b"": 3}")), + ChatToolCall.CreateFunctionToolCall("call3", "MultiplyAsync", BinaryData.FromString(@"{""x"": 2.5, ""y"": 3.0}")), + ChatToolCall.CreateFunctionToolCall("call4", "IsGreaterThanAsync", BinaryData.FromString(@"{""value1"": 100, ""value2"": 50}")), + ChatToolCall.CreateFunctionToolCall("call5", "DivideAsync", BinaryData.FromString(@"{""numerator"": 10.0, ""denominator"": 2.0}")), + ChatToolCall.CreateFunctionToolCall("call6", "ConcatWithBoolAsync", BinaryData.FromString(@"{""text"": ""Test"", ""flag"": true}")) + }; + + var results = await tools.CallAsync(toolCalls); + var resultsList = results.ToList(); + + Assert.That(resultsList, Has.Count.EqualTo(6)); + Assert.That(resultsList[0].ToolCallId, Is.EqualTo("call1")); + Assert.That(resultsList[0].Content[0].Text, Is.EqualTo("Hello")); + Assert.That(resultsList[1].ToolCallId, Is.EqualTo("call2")); + Assert.That(resultsList[1].Content[0].Text, Is.EqualTo("5")); + Assert.That(resultsList[2].ToolCallId, Is.EqualTo("call3")); + Assert.That(resultsList[2].Content[0].Text, Is.EqualTo("7.5")); + Assert.That(resultsList[3].ToolCallId, Is.EqualTo("call4")); + Assert.That(resultsList[3].Content[0].Text, Is.EqualTo("True")); + Assert.That(resultsList[4].ToolCallId, Is.EqualTo("call5")); + Assert.That(resultsList[4].Content[0].Text, Is.EqualTo("5")); + Assert.That(resultsList[5].ToolCallId, Is.EqualTo("call6")); + Assert.That(resultsList[5].Content[0].Text, Is.EqualTo("Test:True")); + } + + [Test] + public void CreatesCompletionOptionsWithTools() + { + var tools = new ChatTools(); + tools.AddFunctionTools(typeof(TestTools)); + + var options = tools.ToChatCompletionOptions(); + + Assert.That(options.Tools, Has.Count.EqualTo(6)); + Assert.That(options.Tools.Any(t => t.FunctionName == "Echo")); + Assert.That(options.Tools.Any(t => t.FunctionName == "Add")); + Assert.That(options.Tools.Any(t => t.FunctionName == "Multiply")); + Assert.That(options.Tools.Any(t => t.FunctionName == "IsGreaterThan")); + Assert.That(options.Tools.Any(t => t.FunctionName == "Divide")); + Assert.That(options.Tools.Any(t => t.FunctionName == "ConcatWithBool")); + } + + [Test] + public async Task CanFilterToolsByRelevance() + { + // Setup mock embedding client to return a mock response + var embedding = OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.5f, 0.5f }); + var embeddingCollection = OpenAIEmbeddingsModelFactory.OpenAIEmbeddingCollection( + items: new[] { embedding }, + model: "text-embedding-ada-002", + usage: OpenAIEmbeddingsModelFactory.EmbeddingTokenUsage(10, 10)); + var mockResponse = new MockPipelineResponse(200); + + mockEmbeddingClient + .Setup(c => c.GenerateEmbeddingAsync( + It.IsAny(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(ClientResult.FromValue(embedding, mockResponse)); + + var tools = new ChatTools(mockEmbeddingClient.Object); + tools.AddFunctionTools(typeof(TestTools)); + + var options = await tools.ToChatCompletionOptions("Need to add two numbers", 1, 0.5f); + + Assert.That(options.Tools, Has.Count.LessThanOrEqualTo(1)); + } + + [Test] + public void ThrowsWhenCallingNonExistentTool() + { + var tools = new ChatTools(); + var toolCalls = new[] + { + ChatToolCall.CreateFunctionToolCall("call1", "NonExistentTool", BinaryData.FromString("{}")) + }; + + Assert.ThrowsAsync(() => tools.CallAsync(toolCalls)); + } + + [Test] + public async Task AddMcpToolsAsync_AddsToolsCorrectly() + { + // Arrange + var mcpEndpoint = new Uri("http://localhost:1234"); + var mockToolsResponse = BinaryData.FromString(@" + { + ""tools"": [ + { + ""name"": ""mcp-tool-1"", + ""description"": ""This is the first MCP tool."", + ""inputSchema"": { + ""type"": ""object"", + ""properties"": { + ""param1"": { + ""type"": ""string"", + ""description"": ""The first param."" + }, + ""param2"": { + ""type"": ""string"", + ""description"": ""The second param."" + } + }, + ""required"": [""param1""] + } + }, + { + ""name"": ""mcp-tool-2"", + ""description"": ""This is the second MCP tool."", + ""inputSchema"": { + ""type"": ""object"", + ""properties"": { + ""param1"": { + ""type"": ""string"", + ""description"": ""The first param."" + }, + ""param2"": { + ""type"": ""string"", + ""description"": ""The second param."" + } + }, + ""required"": [] + } + } + ] + }"); + + var responsesByTool = new Dictionary + { + ["mcp-tool-1"] = "\"tool1 result\"", + ["mcp-tool-2"] = "\"tool2 result\"" + }; + + var testClient = new TestMcpClient( + mcpEndpoint, + mockToolsResponse, + toolName => BinaryData.FromString(responsesByTool[toolName.Split('_').Last()])); + var tools = new ChatTools(); + + // Act + await tools.AddMcpToolsAsync(testClient); + + // Assert + Assert.That(tools.Tools, Has.Count.EqualTo(2)); + var toolNames = tools.Tools.Select(t => t.FunctionName).ToList(); + Assert.That(toolNames, Contains.Item("localhost1234_-_mcp-tool-1")); + Assert.That(toolNames, Contains.Item("localhost1234_-_mcp-tool-2")); + + // Verify we can call the tools with different responses + var toolCalls = new[] + { + ChatToolCall.CreateFunctionToolCall("call1", "localhost1234_-_mcp-tool-1", BinaryData.FromString(@"{""param1"": ""test""}")), + ChatToolCall.CreateFunctionToolCall("call2", "localhost1234_-_mcp-tool-2", BinaryData.FromString(@"{""param2"": ""test""}")) + }; + var results = await tools.CallAsync(toolCalls); + var resultsList = results.ToList(); + + Assert.That(resultsList, Has.Count.EqualTo(2)); + Assert.That(resultsList[0].ToolCallId, Is.EqualTo("call1")); + Assert.That(resultsList[0].Content[0].Text, Is.EqualTo("\"tool1 result\"")); + Assert.That(resultsList[1].ToolCallId, Is.EqualTo("call2")); + Assert.That(resultsList[1].Content[0].Text, Is.EqualTo("\"tool2 result\"")); + } + + [Test] + public async Task CreateCompletionOptions_WithMaxToolsParameter_FiltersTools() + { + // Arrange + var mcpEndpoint = new Uri("http://localhost:1234"); + var mockToolsResponse = BinaryData.FromString(@" + { + ""tools"": [ + { + ""name"": ""math-tool"", + ""description"": ""Tool for performing mathematical calculations"", + ""inputSchema"": { + ""type"": ""object"", + ""properties"": { + ""expression"": { + ""type"": ""string"", + ""description"": ""The mathematical expression to evaluate"" + } + } + } + }, + { + ""name"": ""weather-tool"", + ""description"": ""Tool for getting weather information"", + ""inputSchema"": { + ""type"": ""object"", + ""properties"": { + ""location"": { + ""type"": ""string"", + ""description"": ""The location to get weather for"" + } + } + } + }, + { + ""name"": ""translate-tool"", + ""description"": ""Tool for translating text between languages"", + ""inputSchema"": { + ""type"": ""object"", + ""properties"": { + ""text"": { + ""type"": ""string"", + ""description"": ""Text to translate"" + }, + ""targetLanguage"": { + ""type"": ""string"", + ""description"": ""Target language code"" + } + } + } + } + ] + }"); + + // Setup mock embedding responses + var embeddings = new[] + { + OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.8f, 0.5f }), + OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.6f, 0.4f }), + OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.3f, 0.2f }) + }; + var embeddingCollection = OpenAIEmbeddingsModelFactory.OpenAIEmbeddingCollection( + items: embeddings, + model: "text-embedding-ada-002", + usage: OpenAIEmbeddingsModelFactory.EmbeddingTokenUsage(30, 30)); + var mockResponse = new MockPipelineResponse(200); + + mockEmbeddingClient + .Setup(c => c.GenerateEmbeddingAsync( + It.IsAny(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(ClientResult.FromValue(embeddings[0], mockResponse)); + + mockEmbeddingClient + .Setup(c => c.GenerateEmbeddingsAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(ClientResult.FromValue(embeddingCollection, mockResponse)); + + var testClient = new TestMcpClient( + mcpEndpoint, + mockToolsResponse, + toolName => BinaryData.FromString($"\"{toolName} result\"")); + var tools = new ChatTools(mockEmbeddingClient.Object); + + // Add the tools + await tools.AddMcpToolsAsync(testClient); + + // Act & Assert + // Test with maxTools = 1 + var options1 = await tools.ToChatCompletionOptions("calculate 2+2", 1, 0.5f); + Assert.That(options1.Tools, Has.Count.EqualTo(1)); + + // Test with maxTools = 2 + var options2 = await tools.ToChatCompletionOptions("calculate 2+2", 2, 0.5f); + Assert.That(options2.Tools, Has.Count.EqualTo(2)); + + // Test that we can call the tools after filtering + var toolCall = ChatToolCall.CreateFunctionToolCall( + "call1", + "localhost1234_-_math-tool", + BinaryData.FromString(@"{""expression"": ""2+2""}")); + var result = await tools.CallAsync(new[] { toolCall }); + Assert.That(result.First().ToolCallId, Is.EqualTo("call1")); + Assert.That(result.First().Content[0].Text, Is.EqualTo("\"math-tool result\"")); + } +} \ No newline at end of file diff --git a/tests/Utility/MultipartFormDataBinaryContent.cs b/tests/Utility/MultipartFormDataBinaryContent.cs deleted file mode 100644 index 335253450..000000000 --- a/tests/Utility/MultipartFormDataBinaryContent.cs +++ /dev/null @@ -1,195 +0,0 @@ -// - -#nullable disable - -using System; -using System.ClientModel; -using System.Globalization; -using System.IO; -using System.Net.Http; -using System.Net.Http.Headers; -using System.Threading; -using System.Threading.Tasks; - -namespace OpenAI -{ - internal partial class MultiPartFormDataBinaryContent : BinaryContent - { - private readonly MultipartFormDataContent _multipartContent; - private static readonly Random _random = new Random(); - private static readonly char[] _boundaryValues = "0123456789=ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz".ToCharArray(); - - public MultiPartFormDataBinaryContent() - { - _multipartContent = new MultipartFormDataContent(CreateBoundary()); - } - - public string ContentType - { - get - { - return _multipartContent.Headers.ContentType.ToString(); - } - } - - internal HttpContent HttpContent => _multipartContent; - - private static string CreateBoundary() - { - Span chars = new char[70]; - byte[] random = new byte[70]; - _random.NextBytes(random); - int mask = 255 >> 2; - int i = 0; - for (; i < 70; i++) - { - chars[i] = _boundaryValues[random[i] & mask]; - } - return chars.ToString(); - } - - public void Add(string content, string name, string filename = default, string contentType = default) - { - //Argument.AssertNotNull(content, nameof(content)); - //Argument.AssertNotNullOrEmpty(name, nameof(name)); - - Add(new StringContent(content), name, filename, contentType); - } - - public void Add(int content, string name, string filename = default, string contentType = default) - { - //Argument.AssertNotNull(content, nameof(content)); - //Argument.AssertNotNullOrEmpty(name, nameof(name)); - - string value = content.ToString("G", CultureInfo.InvariantCulture); - Add(new StringContent(value), name, filename, contentType); - } - - public void Add(long content, string name, string filename = default, string contentType = default) - { - //Argument.AssertNotNull(content, nameof(content)); - //Argument.AssertNotNullOrEmpty(name, nameof(name)); - - string value = content.ToString("G", CultureInfo.InvariantCulture); - Add(new StringContent(value), name, filename, contentType); - } - - public void Add(float content, string name, string filename = default, string contentType = default) - { - //Argument.AssertNotNull(content, nameof(content)); - //Argument.AssertNotNullOrEmpty(name, nameof(name)); - - string value = content.ToString("G", CultureInfo.InvariantCulture); - Add(new StringContent(value), name, filename, contentType); - } - - public void Add(double content, string name, string filename = default, string contentType = default) - { - //Argument.AssertNotNull(content, nameof(content)); - //Argument.AssertNotNullOrEmpty(name, nameof(name)); - - string value = content.ToString("G", CultureInfo.InvariantCulture); - Add(new StringContent(value), name, filename, contentType); - } - - public void Add(decimal content, string name, string filename = default, string contentType = default) - { - //Argument.AssertNotNull(content, nameof(content)); - //Argument.AssertNotNullOrEmpty(name, nameof(name)); - - string value = content.ToString("G", CultureInfo.InvariantCulture); - Add(new StringContent(value), name, filename, contentType); - } - - public void Add(bool content, string name, string filename = default, string contentType = default) - { - //Argument.AssertNotNull(content, nameof(content)); - //Argument.AssertNotNullOrEmpty(name, nameof(name)); - - string value = content ? "true" : "false"; - Add(new StringContent(value), name, filename, contentType); - } - - public void Add(Stream content, string name, string filename = default, string contentType = default) - { - //Argument.AssertNotNull(content, nameof(content)); - //Argument.AssertNotNullOrEmpty(name, nameof(name)); - - Add(new StreamContent(content), name, filename, contentType); - } - - public void Add(byte[] content, string name, string filename = default, string contentType = default) - { - //Argument.AssertNotNull(content, nameof(content)); - //Argument.AssertNotNullOrEmpty(name, nameof(name)); - - Add(new ByteArrayContent(content), name, filename, contentType); - } - - public void Add(BinaryData content, string name, string filename = default, string contentType = default) - { - //Argument.AssertNotNull(content, nameof(content)); - //Argument.AssertNotNullOrEmpty(name, nameof(name)); - - Add(new ByteArrayContent(content.ToArray()), name, filename, contentType); - } - - private void Add(HttpContent content, string name, string filename, string contentType) - { - if (contentType != null) - { - //Argument.AssertNotNullOrEmpty(contentType, nameof(contentType)); - AddContentTypeHeader(content, contentType); - } - if (filename != null) - { - //Argument.AssertNotNullOrEmpty(filename, nameof(filename)); - _multipartContent.Add(content, name, filename); - } - else - { - _multipartContent.Add(content, name); - } - } - - public static void AddContentTypeHeader(HttpContent content, string contentType) - { - MediaTypeHeaderValue header = new MediaTypeHeaderValue(contentType); - content.Headers.ContentType = header; - } - - public override bool TryComputeLength(out long length) - { - if (_multipartContent.Headers.ContentLength is long contentLength) - { - length = contentLength; - return true; - } - length = 0; - return false; - } - - public override void WriteTo(Stream stream, CancellationToken cancellationToken = default) - { -#if NET6_0_OR_GREATER - _multipartContent.CopyTo(stream, default, cancellationToken); -#else - _multipartContent.CopyToAsync(stream).GetAwaiter().GetResult(); -#endif - } - - public override async Task WriteToAsync(Stream stream, CancellationToken cancellationToken = default) - { -#if NET6_0_OR_GREATER - await _multipartContent.CopyToAsync(stream).ConfigureAwait(false); -#else - await _multipartContent.CopyToAsync(stream).ConfigureAwait(false); -#endif - } - - public override void Dispose() - { - _multipartContent.Dispose(); - } - } -} diff --git a/tests/Utility/ResponseToolsTests.cs b/tests/Utility/ResponseToolsTests.cs new file mode 100644 index 000000000..3f1fb8fe8 --- /dev/null +++ b/tests/Utility/ResponseToolsTests.cs @@ -0,0 +1,411 @@ +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Moq; +using NUnit.Framework; +using OpenAI.Agents; +using OpenAI.Embeddings; +using OpenAI.Responses; + +namespace OpenAI.Tests.Utility; + +[TestFixture] +[Category("Utility")] +public class ResponseToolsTests : ToolsTestsBase +{ + private Mock mockEmbeddingClient; + + [SetUp] + public void Setup() + { + mockEmbeddingClient = new Mock("text-embedding-ada-002", new ApiKeyCredential("test-key")); + } + + [Test] + public void CanAddLocalTools() + { + var tools = new ResponseTools(); + tools.AddFunctionTools(typeof(TestTools)); + + Assert.That(tools.Tools, Has.Count.EqualTo(6)); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Echo"))); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Add"))); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Multiply"))); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("IsGreaterThan"))); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Divide"))); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("ConcatWithBool"))); + } + + [Test] + public void CanAddLocalAsyncTools() + { + var tools = new ResponseTools(); + tools.AddFunctionTools(typeof(TestToolsAsync)); + + Assert.That(tools.Tools, Has.Count.EqualTo(6)); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("EchoAsync"))); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("AddAsync"))); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("MultiplyAsync"))); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("IsGreaterThanAsync"))); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("DivideAsync"))); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("ConcatWithBoolAsync"))); + } + + [Test] + public async Task CanCallToolAsync() + { + var tools = new ResponseTools(); + tools.AddFunctionTools(typeof(TestTools)); + + var toolCalls = new[] + { + new FunctionCallResponseItem("call1", "Echo", BinaryData.FromString(@"{""message"": ""Hello""}")), + new FunctionCallResponseItem("call2", "Add", BinaryData.FromString(@"{""a"": 2, ""b"": 3}")), + new FunctionCallResponseItem("call3", "Multiply", BinaryData.FromString(@"{""x"": 2.5, ""y"": 3.0}")), + new FunctionCallResponseItem("call4", "IsGreaterThan", BinaryData.FromString(@"{""value1"": 100, ""value2"": 50}")), + new FunctionCallResponseItem("call5", "Divide", BinaryData.FromString(@"{""numerator"": 10.0, ""denominator"": 2.0}")), + new FunctionCallResponseItem("call6", "ConcatWithBool", BinaryData.FromString(@"{""text"": ""Test"", ""flag"": true}")) + }; + + foreach (var toolCall in toolCalls) + { + var result = await tools.CallAsync(toolCall); + Assert.That(result.CallId, Is.EqualTo(toolCall.CallId)); + switch (toolCall.CallId) + { + case "call1": + Assert.That(result.FunctionOutput, Is.EqualTo("Hello")); + break; + case "call2": + Assert.That(result.FunctionOutput, Is.EqualTo("5")); + break; + case "call3": + Assert.That(result.FunctionOutput, Is.EqualTo("7.5")); + break; + case "call4": + Assert.That(result.FunctionOutput, Is.EqualTo("True")); + break; + case "call5": + Assert.That(result.FunctionOutput, Is.EqualTo("5")); + break; + case "call6": + Assert.That(result.FunctionOutput, Is.EqualTo("Test:True")); + break; + } + } + } + + [Test] + public async Task CanCallAsyncToolsAsync() + { + var tools = new ResponseTools(); + tools.AddFunctionTools(typeof(TestToolsAsync)); + + var toolCalls = new[] + { + new FunctionCallResponseItem("call1", "EchoAsync", BinaryData.FromString(@"{""message"": ""Hello""}")), + new FunctionCallResponseItem("call2", "AddAsync", BinaryData.FromString(@"{""a"": 2, ""b"": 3}")), + new FunctionCallResponseItem("call3", "MultiplyAsync", BinaryData.FromString(@"{""x"": 2.5, ""y"": 3.0}")), + new FunctionCallResponseItem("call4", "IsGreaterThanAsync", BinaryData.FromString(@"{""value1"": 100, ""value2"": 50}")), + new FunctionCallResponseItem("call5", "DivideAsync", BinaryData.FromString(@"{""numerator"": 10.0, ""denominator"": 2.0}")), + new FunctionCallResponseItem("call6", "ConcatWithBoolAsync", BinaryData.FromString(@"{""text"": ""Test"", ""flag"": true}")) + }; + + foreach (var toolCall in toolCalls) + { + var result = await tools.CallAsync(toolCall); + Assert.That(result.CallId, Is.EqualTo(toolCall.CallId)); + switch (toolCall.CallId) + { + case "call1": + Assert.That(result.FunctionOutput, Is.EqualTo("Hello")); + break; + case "call2": + Assert.That(result.FunctionOutput, Is.EqualTo("5")); + break; + case "call3": + Assert.That(result.FunctionOutput, Is.EqualTo("7.5")); + break; + case "call4": + Assert.That(result.FunctionOutput, Is.EqualTo("True")); + break; + case "call5": + Assert.That(result.FunctionOutput, Is.EqualTo("5")); + break; + case "call6": + Assert.That(result.FunctionOutput, Is.EqualTo("Test:True")); + break; + } + } + } + + [Test] + public void CreatesResponseOptionsWithTools() + { + var tools = new ResponseTools(); + tools.AddFunctionTools(typeof(TestTools)); + + var options = tools.ToResponseCreationOptions(); + + Assert.That(options.Tools, Has.Count.EqualTo(6)); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Echo"))); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Add"))); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Multiply"))); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("IsGreaterThan"))); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Divide"))); + Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("ConcatWithBool"))); + } + + [Test] + public async Task CanFilterToolsByRelevance() + { + // Setup mock embedding client to return a mock response + var embeddings = new[] + { + OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.8f, 0.5f }), + OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.6f, 0.4f }), + OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.3f, 0.2f }) + }; + var embeddingCollection = OpenAIEmbeddingsModelFactory.OpenAIEmbeddingCollection( + items: embeddings, + model: "text-embedding-ada-002", + usage: OpenAIEmbeddingsModelFactory.EmbeddingTokenUsage(30, 30)); + var mockResponse = new MockPipelineResponse(200); + + mockEmbeddingClient + .Setup(c => c.GenerateEmbeddingAsync( + It.IsAny(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(ClientResult.FromValue(embeddings[0], mockResponse)); + + mockEmbeddingClient + .Setup(c => c.GenerateEmbeddingsAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(ClientResult.FromValue(embeddingCollection, mockResponse)); + + var tools = new ResponseTools(mockEmbeddingClient.Object); + tools.AddFunctionTools(typeof(TestTools)); + + var options = await tools.ToResponseCreationOptionsAsync("Need to add two numbers", 1, 0.5f); + + Assert.That(options.Tools, Has.Count.LessThanOrEqualTo(1)); + } + + [Test] + public async Task ReturnsErrorForNonExistentTool() + { + var tools = new ResponseTools(); + var toolCall = new FunctionCallResponseItem("call1", "NonExistentTool", BinaryData.FromString("{}")); + + var result = await tools.CallAsync(toolCall); + Assert.That(result.FunctionOutput, Does.StartWith("I don't have a tool called")); + } + + [Test] + public async Task AddMcpToolsAsync_AddsToolsCorrectly() + { + // Arrange + var mcpEndpoint = new Uri("http://localhost:1234"); + var mockToolsResponse = BinaryData.FromString(@" + { + ""tools"": [ + { + ""name"": ""mcp-tool-1"", + ""description"": ""This is the first MCP tool."", + ""inputSchema"": { + ""type"": ""object"", + ""properties"": { + ""param1"": { + ""type"": ""string"", + ""description"": ""The first param."" + }, + ""param2"": { + ""type"": ""string"", + ""description"": ""The second param."" + } + }, + ""required"": [""param1""] + } + }, + { + ""name"": ""mcp-tool-2"", + ""description"": ""This is the second MCP tool."", + ""inputSchema"": { + ""type"": ""object"", + ""properties"": { + ""param1"": { + ""type"": ""string"", + ""description"": ""The first param."" + }, + ""param2"": { + ""type"": ""string"", + ""description"": ""The second param."" + } + }, + ""required"": [] + } + } + ] + }"); + + var responsesByTool = new Dictionary + { + ["mcp-tool-1"] = "\"tool1 result\"", + ["mcp-tool-2"] = "\"tool2 result\"" + }; + + var testClient = new TestMcpClient( + mcpEndpoint, + mockToolsResponse, + toolName => BinaryData.FromString(responsesByTool[toolName.Split('_').Last()])); + var tools = new ResponseTools(); + + // Act + await tools.AddMcpToolsAsync(testClient); + + // Assert + Assert.That(tools.Tools, Has.Count.EqualTo(2)); + var toolNames = tools.Tools.Select(t => (string)t.GetType().GetProperty("Name").GetValue(t)).ToList(); + Assert.That(toolNames, Contains.Item("localhost1234_-_mcp-tool-1")); + Assert.That(toolNames, Contains.Item("localhost1234_-_mcp-tool-2")); + + // Verify we can call the tools with different responses + var toolCall = new FunctionCallResponseItem("call1", "localhost1234_-_mcp-tool-1", BinaryData.FromString(@"{""param1"": ""test""}")); + var result = await tools.CallAsync(toolCall); + Assert.That(result.FunctionOutput, Is.EqualTo("\"tool1 result\"")); + + var toolCall2 = new FunctionCallResponseItem("call2", "localhost1234_-_mcp-tool-2", BinaryData.FromString(@"{""param2"": ""test""}")); + var result2 = await tools.CallAsync(toolCall2); + Assert.That(result2.FunctionOutput, Is.EqualTo("\"tool2 result\"")); + } + + [Test] + public async Task CreateResponseOptions_WithMaxToolsParameter_FiltersTools() + { + // Arrange + var mcpEndpoint = new Uri("http://localhost:1234"); + var mockToolsResponse = BinaryData.FromString(@" + { + ""tools"": [ + { + ""name"": ""math-tool"", + ""description"": ""Tool for performing mathematical calculations"", + ""inputSchema"": { + ""type"": ""object"", + ""properties"": { + ""expression"": { + ""type"": ""string"", + ""description"": ""The mathematical expression to evaluate"" + } + } + } + }, + { + ""name"": ""weather-tool"", + ""description"": ""Tool for getting weather information"", + ""inputSchema"": { + ""type"": ""object"", + ""properties"": { + ""location"": { + ""type"": ""string"", + ""description"": ""The location to get weather for"" + } + } + } + }, + { + ""name"": ""translate-tool"", + ""description"": ""Tool for translating text between languages"", + ""inputSchema"": { + ""type"": ""object"", + ""properties"": { + ""text"": { + ""type"": ""string"", + ""description"": ""Text to translate"" + }, + ""targetLanguage"": { + ""type"": ""string"", + ""description"": ""Target language code"" + } + } + } + } + ] + }"); + + // Setup mock embedding responses + var embeddings = new[] + { + OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.8f, 0.5f }), + OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.6f, 0.4f }), + OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.3f, 0.2f }) + }; + var embeddingCollection = OpenAIEmbeddingsModelFactory.OpenAIEmbeddingCollection( + items: embeddings, + model: "text-embedding-ada-002", + usage: OpenAIEmbeddingsModelFactory.EmbeddingTokenUsage(30, 30)); + var mockResponse = new MockPipelineResponse(200); + + mockEmbeddingClient + .Setup(c => c.GenerateEmbeddingAsync( + It.IsAny(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(ClientResult.FromValue(embeddings[0], mockResponse)); + + mockEmbeddingClient + .Setup(c => c.GenerateEmbeddingsAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(ClientResult.FromValue(embeddingCollection, mockResponse)); + + var responsesByTool = new Dictionary + { + ["math-tool"] = "\"math result\"", + ["weather-tool"] = "\"weather result\"", + ["translate-tool"] = "\"translate result\"" + }; + + var testClient = new TestMcpClient( + mcpEndpoint, + mockToolsResponse, + toolName => BinaryData.FromString(responsesByTool[toolName.Split('_').Last()])); + var tools = new ResponseTools(mockEmbeddingClient.Object); + + // Add the tools + await tools.AddMcpToolsAsync(testClient); + + // Act & Assert + // Test with maxTools = 1 + var options1 = await tools.ToResponseCreationOptionsAsync("calculate 2+2", 1, 0.5f); + Assert.That(options1.Tools, Has.Count.EqualTo(1)); + + // Test with maxTools = 2 + var options2 = await tools.ToResponseCreationOptionsAsync("calculate 2+2", 2, 0.5f); + Assert.That(options2.Tools, Has.Count.EqualTo(2)); + + // Test that tool choice affects results + var optionsWithToolChoice = await tools.ToResponseCreationOptionsAsync("calculate 2+2", 1, 0.5f); + optionsWithToolChoice.ToolChoice = ResponseToolChoice.CreateRequiredChoice(); + + Assert.That(optionsWithToolChoice.ToolChoice, Is.Not.Null); + Assert.That(optionsWithToolChoice.Tools, Has.Count.EqualTo(1)); + + // Verify we can still call the filtered tools + var toolCall = new FunctionCallResponseItem( + "call1", + "localhost1234_-_math-tool", + BinaryData.FromString(@"{""expression"": ""2+2""}")); + var result = await tools.CallAsync(toolCall); + Assert.That(result.CallId, Is.EqualTo("call1")); + Assert.That(result.FunctionOutput, Is.EqualTo("\"math result\"")); + } +} \ No newline at end of file diff --git a/tests/Utility/ToolsTestsBase.cs b/tests/Utility/ToolsTestsBase.cs new file mode 100644 index 000000000..dc4637989 --- /dev/null +++ b/tests/Utility/ToolsTestsBase.cs @@ -0,0 +1,94 @@ + +using System; +using System.Threading.Tasks; +using OpenAI.Agents; + +namespace OpenAI.Tests.Utility; + +public class ToolsTestsBase +{ + internal class TestTools + { + public static string Echo(string message) => message; + public static int Add(int a, int b) => a + b; + public static double Multiply(double x, double y) => x * y; + public static bool IsGreaterThan(long value1, long value2) => value1 > value2; + public static float Divide(float numerator, float denominator) => numerator / denominator; + public static string ConcatWithBool(string text, bool flag) => $"{text}:{flag}"; + } + + internal class TestToolsAsync + { + public static async Task EchoAsync(string message) + { + await Task.Delay(1); // Simulate async work + return message; + } + + public static async Task AddAsync(int a, int b) + { + await Task.Delay(1); // Simulate async work + return a + b; + } + + public static async Task MultiplyAsync(double x, double y) + { + await Task.Delay(1); // Simulate async work + return x * y; + } + + public static async Task IsGreaterThanAsync(long value1, long value2) + { + await Task.Delay(1); // Simulate async work + return value1 > value2; + } + + public static async Task DivideAsync(float numerator, float denominator) + { + await Task.Delay(1); // Simulate async work + return numerator / denominator; + } + + public static async Task ConcatWithBoolAsync(string text, bool flag) + { + await Task.Delay(1); // Simulate async work + return $"{text}:{flag}"; + } + } + + internal class TestMcpClient : McpClient + { + private readonly BinaryData _toolsResponse; + private readonly Func _toolCallResponseFactory; + private bool _isStarted; + + public TestMcpClient(Uri endpoint, BinaryData toolsResponse, Func toolCallResponseFactory = null) + : base(endpoint) + { + _toolsResponse = toolsResponse; + _toolCallResponseFactory = toolCallResponseFactory ?? (_ => BinaryData.FromString("\"test result\"")); + } + + public override Task StartAsync() + { + _isStarted = true; + return Task.CompletedTask; + } + + public override Task ListToolsAsync() + { + if (!_isStarted) + throw new InvalidOperationException("Session is not initialized. Call StartAsync() first."); + + return Task.FromResult(_toolsResponse); + } + + public override Task CallToolAsync(string toolName, BinaryData parameters) + { + if (!_isStarted) + throw new InvalidOperationException("Session is not initialized. Call StartAsync() first."); + + return Task.FromResult(_toolCallResponseFactory(toolName)); + } + } +} \ No newline at end of file