Skip to content

Commit f04d6ae

Browse files
authored
Add support for MCP Servers (Azure#48919)
1 parent fbadfd3 commit f04d6ae

File tree

13 files changed

+799
-35
lines changed

13 files changed

+799
-35
lines changed

sdk/cloudmachine/Azure.Projects.OpenAI/api/Azure.Projects.OpenAI.net8.0.cs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,21 @@ public ChatProcessor(OpenAI.Chat.ChatClient chat, OpenAI.Embeddings.EmbeddingCli
3030
protected virtual void OnGround(System.Collections.Generic.List<OpenAI.Chat.ChatMessage> conversation, string prompt) { }
3131
protected virtual void OnLength(System.Collections.Generic.List<OpenAI.Chat.ChatMessage> conversation, OpenAI.Chat.ChatCompletion completion) { }
3232
protected virtual void OnStop(System.Collections.Generic.List<OpenAI.Chat.ChatMessage> conversation, OpenAI.Chat.ChatCompletion completion) { }
33-
protected virtual void OnToolCalls(System.Collections.Generic.List<OpenAI.Chat.ChatMessage> conversation, OpenAI.Chat.ChatCompletion completion) { }
33+
protected virtual System.Threading.Tasks.Task OnToolCalls(System.Collections.Generic.List<OpenAI.Chat.ChatMessage> conversation, OpenAI.Chat.ChatCompletion completion) { throw null; }
3434
protected virtual void OnToolError(System.Collections.Generic.List<string> failed, System.Collections.Generic.List<OpenAI.Chat.ChatMessage> conversation, OpenAI.Chat.ChatCompletion completion) { }
35-
public OpenAI.Chat.ChatCompletion TakeTurn(System.Collections.Generic.List<OpenAI.Chat.ChatMessage> conversation, string prompt) { throw null; }
35+
public System.Threading.Tasks.Task<OpenAI.Chat.ChatCompletion> TakeTurnAsync(System.Collections.Generic.List<OpenAI.Chat.ChatMessage> conversation, string prompt) { throw null; }
3636
}
3737
public partial class ChatTools
3838
{
3939
public ChatTools(params System.Type[] tools) { }
4040
public System.Collections.Generic.IList<OpenAI.Chat.ChatTool> Definitions { get { throw null; } }
4141
public void Add(System.Reflection.MethodInfo function) { }
4242
public void Add(System.Type functions) { }
43+
public System.Threading.Tasks.Task AddMcpServerAsync(System.Uri serverEndpoint) { throw null; }
4344
public string Call(OpenAI.Chat.ChatToolCall call) { throw null; }
4445
public string Call(string name, object[] arguments) { throw null; }
4546
public System.Collections.Generic.IEnumerable<OpenAI.Chat.ToolChatMessage> CallAll(System.Collections.Generic.IEnumerable<OpenAI.Chat.ChatToolCall> toolCalls) { throw null; }
46-
public System.Collections.Generic.IEnumerable<OpenAI.Chat.ToolChatMessage> CallAll(System.Collections.Generic.IEnumerable<OpenAI.Chat.ChatToolCall> toolCalls, out System.Collections.Generic.List<string>? failed) { throw null; }
47+
public System.Threading.Tasks.Task<Azure.Projects.OpenAI.ToolCallResult> CallAllWithErrors(System.Collections.Generic.IEnumerable<OpenAI.Chat.ChatToolCall> toolCalls) { throw null; }
4748
public static implicit operator OpenAI.Chat.ChatCompletionOptions (Azure.Projects.OpenAI.ChatTools tools) { throw null; }
4849
public OpenAI.Chat.ChatCompletionOptions ToOptions() { throw null; }
4950
}
@@ -78,6 +79,15 @@ protected override void EmitConstructs(Azure.Projects.ProjectInfrastructure infr
7879
protected override void EmitFeatures(Azure.Projects.ProjectInfrastructure infrastructure) { }
7980
}
8081
[System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)]
82+
public partial struct ToolCallResult
83+
{
84+
private object _dummy;
85+
private int _dummyPrimitive;
86+
public ToolCallResult(System.Collections.Generic.IEnumerable<OpenAI.Chat.ToolChatMessage> messages, System.Collections.Generic.List<string>? failed = null) { throw null; }
87+
public System.Collections.Generic.List<string>? Failed { get { throw null; } }
88+
public System.Collections.Generic.IEnumerable<OpenAI.Chat.ToolChatMessage> Messages { get { throw null; } }
89+
}
90+
[System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)]
8191
public readonly partial struct VectorbaseEntry
8292
{
8393
private readonly object _dummy;

sdk/cloudmachine/Azure.Projects.OpenAI/api/Azure.Projects.OpenAI.netstandard2.0.cs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,21 @@ public ChatProcessor(OpenAI.Chat.ChatClient chat, OpenAI.Embeddings.EmbeddingCli
3030
protected virtual void OnGround(System.Collections.Generic.List<OpenAI.Chat.ChatMessage> conversation, string prompt) { }
3131
protected virtual void OnLength(System.Collections.Generic.List<OpenAI.Chat.ChatMessage> conversation, OpenAI.Chat.ChatCompletion completion) { }
3232
protected virtual void OnStop(System.Collections.Generic.List<OpenAI.Chat.ChatMessage> conversation, OpenAI.Chat.ChatCompletion completion) { }
33-
protected virtual void OnToolCalls(System.Collections.Generic.List<OpenAI.Chat.ChatMessage> conversation, OpenAI.Chat.ChatCompletion completion) { }
33+
protected virtual System.Threading.Tasks.Task OnToolCalls(System.Collections.Generic.List<OpenAI.Chat.ChatMessage> conversation, OpenAI.Chat.ChatCompletion completion) { throw null; }
3434
protected virtual void OnToolError(System.Collections.Generic.List<string> failed, System.Collections.Generic.List<OpenAI.Chat.ChatMessage> conversation, OpenAI.Chat.ChatCompletion completion) { }
35-
public OpenAI.Chat.ChatCompletion TakeTurn(System.Collections.Generic.List<OpenAI.Chat.ChatMessage> conversation, string prompt) { throw null; }
35+
public System.Threading.Tasks.Task<OpenAI.Chat.ChatCompletion> TakeTurnAsync(System.Collections.Generic.List<OpenAI.Chat.ChatMessage> conversation, string prompt) { throw null; }
3636
}
3737
public partial class ChatTools
3838
{
3939
public ChatTools(params System.Type[] tools) { }
4040
public System.Collections.Generic.IList<OpenAI.Chat.ChatTool> Definitions { get { throw null; } }
4141
public void Add(System.Reflection.MethodInfo function) { }
4242
public void Add(System.Type functions) { }
43+
public System.Threading.Tasks.Task AddMcpServerAsync(System.Uri serverEndpoint) { throw null; }
4344
public string Call(OpenAI.Chat.ChatToolCall call) { throw null; }
4445
public string Call(string name, object[] arguments) { throw null; }
4546
public System.Collections.Generic.IEnumerable<OpenAI.Chat.ToolChatMessage> CallAll(System.Collections.Generic.IEnumerable<OpenAI.Chat.ChatToolCall> toolCalls) { throw null; }
46-
public System.Collections.Generic.IEnumerable<OpenAI.Chat.ToolChatMessage> CallAll(System.Collections.Generic.IEnumerable<OpenAI.Chat.ChatToolCall> toolCalls, out System.Collections.Generic.List<string>? failed) { throw null; }
47+
public System.Threading.Tasks.Task<Azure.Projects.OpenAI.ToolCallResult> CallAllWithErrors(System.Collections.Generic.IEnumerable<OpenAI.Chat.ChatToolCall> toolCalls) { throw null; }
4748
public static implicit operator OpenAI.Chat.ChatCompletionOptions (Azure.Projects.OpenAI.ChatTools tools) { throw null; }
4849
public OpenAI.Chat.ChatCompletionOptions ToOptions() { throw null; }
4950
}
@@ -78,6 +79,15 @@ protected override void EmitConstructs(Azure.Projects.ProjectInfrastructure infr
7879
protected override void EmitFeatures(Azure.Projects.ProjectInfrastructure infrastructure) { }
7980
}
8081
[System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)]
82+
public partial struct ToolCallResult
83+
{
84+
private object _dummy;
85+
private int _dummyPrimitive;
86+
public ToolCallResult(System.Collections.Generic.IEnumerable<OpenAI.Chat.ToolChatMessage> messages, System.Collections.Generic.List<string>? failed = null) { throw null; }
87+
public System.Collections.Generic.List<string>? Failed { get { throw null; } }
88+
public System.Collections.Generic.IEnumerable<OpenAI.Chat.ToolChatMessage> Messages { get { throw null; } }
89+
}
90+
[System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)]
8191
public readonly partial struct VectorbaseEntry
8292
{
8393
private readonly object _dummy;

sdk/cloudmachine/Azure.Projects.OpenAI/src/CHatProcessor.cs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using System;
55
using System.Collections.Generic;
6+
using System.Threading.Tasks;
67
using Azure.AI.OpenAI;
78
using OpenAI.Chat;
89
using OpenAI.Embeddings;
@@ -58,7 +59,7 @@ public ChatProcessor(ChatClient chat, EmbeddingClient? embeddings, ChatTools? to
5859
/// <param name="prompt"></param>
5960
/// <returns></returns>
6061
/// <exception cref="NotImplementedException"></exception>
61-
public ChatCompletion TakeTurn(List<ChatMessage> conversation, string prompt)
62+
public async Task<ChatCompletion> TakeTurnAsync(List<ChatMessage> conversation, string prompt)
6263
{
6364
OnGround(conversation, prompt);
6465

@@ -76,7 +77,7 @@ public ChatCompletion TakeTurn(List<ChatMessage> conversation, string prompt)
7677
OnLength(conversation, completion);
7778
goto complete;
7879
case ChatFinishReason.ToolCalls:
79-
OnToolCalls(conversation, completion);
80+
await OnToolCalls(conversation, completion).ConfigureAwait(false);
8081
goto complete;
8182
default:
8283
//case ChatFinishReason.ContentFilter:
@@ -142,21 +143,21 @@ protected virtual void OnLength(List<ChatMessage> conversation, ChatCompletion c
142143
/// </summary>
143144
/// <param name="conversation"></param>
144145
/// <param name="completion"></param>
145-
protected virtual void OnToolCalls(List<ChatMessage> conversation, ChatCompletion completion)
146+
protected virtual async Task OnToolCalls(List<ChatMessage> conversation, ChatCompletion completion)
146147
{
147148
if (Tools == null)
148149
throw new InvalidOperationException("No tools defined.");
149150

150151
// for some reason I am getting tool calls for tools that dont exist.
151-
IEnumerable<ToolChatMessage> toolResults = Tools.CallAll(completion.ToolCalls, out List<string>? failed);
152-
if (failed != null)
152+
var toolResults = await Tools.CallAllWithErrors(completion.ToolCalls).ConfigureAwait(false);
153+
if (toolResults.Failed != null)
153154
{
154-
OnToolError(failed, conversation, completion);
155+
OnToolError(toolResults.Failed, conversation, completion);
155156
}
156157
else
157158
{
158159
conversation.Add(completion);
159-
conversation.AddRange(toolResults);
160+
conversation.AddRange(toolResults.Messages);
160161
}
161162
}
162163

sdk/cloudmachine/Azure.Projects.OpenAI/src/ChatTools.cs

Lines changed: 91 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using System.Reflection;
99
using System.Text;
1010
using System.Text.Json;
11+
using System.Threading.Tasks;
1112
using OpenAI.Chat;
1213

1314
namespace Azure.Projects.OpenAI;
@@ -18,8 +19,12 @@ public class ChatTools
1819
private static readonly BinaryData s_noparams = BinaryData.FromString("""{ "type" : "object", "properties" : {} }""");
1920

2021
private readonly Dictionary<string, MethodInfo> _methods = [];
22+
private readonly Dictionary<string, Func<string, BinaryData, Task<BinaryData>>> _mcpMethods = [];
2123
private readonly List<ChatTool> _definitions = [];
2224

25+
private List<McpClient> _mcpClients = [];
26+
private Dictionary<string, McpClient> _mcpClientsByEndpoint = [];
27+
2328
/// <summary>
2429
/// Initializes a new instance of the <see cref="ChatTools"/> class.
2530
/// </summary>
@@ -30,6 +35,20 @@ public ChatTools(params Type[] tools)
3035
Add(functionHolder);
3136
}
3237

38+
/// <summary>
39+
/// Adds a new MCP Server connection to be used for function calls.
40+
/// </summary>
41+
/// <param name="serverEndpoint">The Uri of the MCP Server.</param>
42+
public async Task AddMcpServerAsync(Uri serverEndpoint)
43+
{
44+
var client = new McpClient(serverEndpoint);
45+
_mcpClientsByEndpoint[serverEndpoint.AbsoluteUri] = client;
46+
await client.StartAsync().ConfigureAwait(false);
47+
BinaryData tools = await client.ListToolsAsync().ConfigureAwait(false);
48+
Add(tools, client);
49+
_mcpClients.Add(client);
50+
}
51+
3352
/// <summary>
3453
/// Gets the tool definitions.
3554
/// </summary>
@@ -49,6 +68,42 @@ public static implicit operator ChatCompletionOptions(ChatTools tools)
4968
return options;
5069
}
5170

71+
/// <summary>
72+
/// Adds tool definitions from a JSON array in BinaryData format.
73+
/// </summary>
74+
/// <param name="toolDefinitions">BinaryData containing a JSON array of tool definitions</param>
75+
/// <param name="client">The McpClient.</param>
76+
/// <exception cref="ArgumentNullException">Thrown when toolDefinitions is null</exception>
77+
/// <exception cref="JsonException">Thrown when JSON parsing fails</exception>
78+
internal void Add(BinaryData toolDefinitions, McpClient client)
79+
{
80+
using var document = JsonDocument.Parse(toolDefinitions);
81+
if (!document.RootElement.TryGetProperty("tools", out JsonElement toolsElement))
82+
{
83+
throw new JsonException("The JSON document must contain a 'tools' array.");
84+
}
85+
86+
var tools = toolsElement.EnumerateArray();
87+
// the replacement is to deal with OpenAI's tool name regex validation.
88+
var serverKey = client.ServerEndpoint.AbsoluteUri.Replace('/', '_').Replace(':', '_');
89+
90+
foreach (var tool in tools)
91+
{
92+
var name = $"{serverKey}__.__{tool.GetProperty("name").GetString()!}";
93+
var description = tool.GetProperty("description").GetString()!;
94+
var inputSchema = tool.GetProperty("inputSchema").GetRawText();
95+
96+
var chatTool = ChatTool.CreateFunctionTool(
97+
name,
98+
description,
99+
BinaryData.FromString(inputSchema));
100+
101+
_definitions.Add(chatTool);
102+
103+
_mcpMethods[name] = client.CallToolAsync;
104+
}
105+
}
106+
52107
/// <summary>
53108
/// Adds a set of functions to the chat functions.
54109
/// </summary>
@@ -128,6 +183,26 @@ public string Call(ChatToolCall call)
128183
return result;
129184
}
130185

186+
private async Task<string> CallMcp(ChatToolCall call)
187+
{
188+
if (_mcpMethods.TryGetValue(call.FunctionName, out Func<string, BinaryData, Task<BinaryData>>? method))
189+
{
190+
#if !NETSTANDARD2_0
191+
var actualFunctionName = call.FunctionName.Split("__.__", 2)[1];
192+
#else
193+
var separator = "__.__";
194+
var index = call.FunctionName.IndexOf(separator);
195+
var actualFunctionName = call.FunctionName.Substring(index + separator.Length);
196+
#endif
197+
var result = await method(actualFunctionName, call.FunctionArguments).ConfigureAwait(false);
198+
return result.ToString();
199+
}
200+
else
201+
{
202+
throw new NotImplementedException($"MCP tool {call.FunctionName} not found.");
203+
}
204+
}
205+
131206
/// <summary>
132207
/// Calls all the specified <see cref="ChatToolCall"/>s.
133208
/// </summary>
@@ -148,24 +223,32 @@ public IEnumerable<ToolChatMessage> CallAll(IEnumerable<ChatToolCall> toolCalls)
148223
/// Calls all the specified <see cref="ChatToolCall"/>s.
149224
/// </summary>
150225
/// <param name="toolCalls"></param>
151-
/// <param name="failed"></param>
152226
/// <returns></returns>
153-
public IEnumerable<ToolChatMessage> CallAll(IEnumerable<ChatToolCall> toolCalls, out List<string>? failed)
227+
public async Task<ToolCallResult> CallAllWithErrors(IEnumerable<ChatToolCall> toolCalls)
154228
{
155-
failed = null;
229+
List<string>? failed = null;
230+
bool isMcpTool = false;
156231
var messages = new List<ToolChatMessage>();
157232
foreach (ChatToolCall toolCall in toolCalls)
158233
{
159234
if (!_methods.ContainsKey(toolCall.FunctionName))
160235
{
161-
if (failed == null) failed = new List<string>();
162-
failed.Add(toolCall.FunctionName);
163-
continue;
236+
if (_mcpMethods.ContainsKey(toolCall.FunctionName))
237+
{
238+
isMcpTool = true;
239+
}
240+
else
241+
{
242+
failed ??= new List<string>();
243+
failed.Add(toolCall.FunctionName);
244+
continue;
245+
}
164246
}
165-
var result = Call(toolCall);
247+
248+
var result = isMcpTool ? await CallMcp(toolCall).ConfigureAwait(false) : Call(toolCall);
166249
messages.Add(new ToolChatMessage(toolCall.Id, result));
167250
}
168-
return messages;
251+
return new(messages, failed);
169252
}
170253

171254
/// <summary>
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using System;
5+
using System.ClientModel.Primitives;
6+
using System.Threading.Tasks;
7+
8+
internal class McpClient
9+
{
10+
private McpSession _session;
11+
private ClientPipeline _pipeline = ClientPipeline.Create();
12+
13+
public Uri ServerEndpoint {get;}
14+
15+
public McpClient(Uri endpoint)
16+
{
17+
_session = new McpSession(endpoint, _pipeline);
18+
ServerEndpoint = endpoint;
19+
}
20+
21+
public async Task StartAsync()
22+
{
23+
await _session.EnsureInitializedAsync().ConfigureAwait(false);
24+
}
25+
26+
public void Stop()
27+
{
28+
_session.Stop();
29+
}
30+
31+
public async Task<BinaryData> ListToolsAsync()
32+
{
33+
if (_session == null)
34+
throw new InvalidOperationException("Session is not initialized. Call StartAsync() first.");
35+
36+
return await _session.SendMethod("tools/list").ConfigureAwait(false);
37+
}
38+
39+
public async Task<BinaryData> CallToolAsync(string toolName, BinaryData parameters)
40+
{
41+
if (_session == null)
42+
throw new InvalidOperationException("Session is not initialized. Call StartAsync() first.");
43+
44+
Console.WriteLine($"Calling tool {toolName}...");
45+
return await _session.CallTool(toolName, parameters).ConfigureAwait(false);
46+
}
47+
}

0 commit comments

Comments
 (0)