Skip to content

Commit 3fc8553

Browse files
committed
Model MCP and code execution using MS.E.AI primitives
Since opting for those outputs likely consumes additional output tokens, make them opt-in only. Now both are represented with specific tools rather than our generic HostedToolCallContent. Marked its usage as experimental since it's marked such in MS.E.AI too.
1 parent 2a105dc commit 3fc8553

File tree

7 files changed

+254
-116
lines changed

7 files changed

+254
-116
lines changed

src/Extensions.Grok/Extensions.Grok.csproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
<PackageLicenseFile>OSMFEULA.txt</PackageLicenseFile>
1111
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
1212
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
13-
<NoWarn>MEAI001;$(NoWarn)</NoWarn>
13+
<NoWarn>MEAI001;DEAI001;$(NoWarn)</NoWarn>
1414
</PropertyGroup>
1515

1616
<ItemGroup>
1717
<PackageReference Include="Google.Protobuf" Version="3.33.1" />
18-
<PackageReference Include="GrokClient" Version="1.0.0" />
18+
<PackageReference Include="GrokClient" Version="1.0.2" />
1919
<PackageReference Include="Grpc.Net.Client" Version="2.71.0" />
2020
<PackageReference Include="Grpc.Tools" Version="2.76.0" PrivateAssets="all" />
2121
<PackageReference Include="NuGetizer" Version="1.4.5" PrivateAssets="all" />

src/Extensions.Grok/GrokChatClient.cs

Lines changed: 122 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using System.Text.Json;
22
using Devlooped.Grok;
3-
using Google.Protobuf;
43
using Grpc.Core;
54
using Grpc.Net.Client;
65
using Microsoft.Extensions.AI;
@@ -23,93 +22,133 @@ internal GrokChatClient(GrpcChannel channel, GrokClientOptions clientOptions, st
2322

2423
public async Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
2524
{
26-
var requestDto = MapToRequest(messages, options);
27-
var protoResponse = await client.GetCompletionAsync(requestDto, cancellationToken: cancellationToken);
28-
var lastOutput = protoResponse.Outputs.OrderByDescending(x => x.Index).FirstOrDefault();
25+
var request = MapToRequest(messages, options);
26+
var response = await client.GetCompletionAsync(request, cancellationToken: cancellationToken);
27+
var lastOutput = response.Outputs.OrderByDescending(x => x.Index).FirstOrDefault();
2928

3029
if (lastOutput == null)
3130
{
3231
return new ChatResponse()
3332
{
34-
ResponseId = protoResponse.Id,
35-
ModelId = protoResponse.Model,
36-
CreatedAt = protoResponse.Created.ToDateTimeOffset(),
37-
Usage = MapToUsage(protoResponse.Usage),
33+
ResponseId = response.Id,
34+
ModelId = response.Model,
35+
CreatedAt = response.Created.ToDateTimeOffset(),
36+
Usage = MapToUsage(response.Usage),
3837
};
3938
}
4039

4140
var message = new ChatMessage(MapRole(lastOutput.Message.Role), default(string));
42-
var citations = protoResponse.Citations?.Distinct().Select(MapCitation).ToList<AIAnnotation>();
41+
var citations = response.Citations?.Distinct().Select(MapCitation).ToList<AIAnnotation>();
4342

44-
foreach (var output in protoResponse.Outputs.OrderBy(x => x.Index))
43+
foreach (var output in response.Outputs.OrderBy(x => x.Index))
4544
{
4645
if (output.Message.Content is { Length: > 0 } text)
4746
{
48-
var content = new TextContent(text)
47+
// Special-case output from tools
48+
if (output.Message.Role == MessageRole.RoleTool &&
49+
output.Message.ToolCalls.Count == 1 &&
50+
output.Message.ToolCalls[0] is { } toolCall)
4951
{
50-
Annotations = citations
51-
};
52+
if (toolCall.Type == ToolCallType.McpTool)
53+
{
54+
message.Contents.Add(new McpServerToolCallContent(toolCall.Id, toolCall.Function.Name, null)
55+
{
56+
RawRepresentation = toolCall
57+
});
58+
message.Contents.Add(new McpServerToolResultContent(toolCall.Id)
59+
{
60+
RawRepresentation = toolCall,
61+
Output = [new TextContent(text)]
62+
});
63+
continue;
64+
}
65+
else if (toolCall.Type == ToolCallType.CodeExecutionTool)
66+
{
67+
message.Contents.Add(new CodeInterpreterToolCallContent()
68+
{
69+
CallId = toolCall.Id,
70+
RawRepresentation = toolCall
71+
});
72+
message.Contents.Add(new CodeInterpreterToolResultContent()
73+
{
74+
CallId = toolCall.Id,
75+
RawRepresentation = toolCall,
76+
Outputs = [new TextContent(text)]
77+
});
78+
continue;
79+
}
80+
}
81+
82+
var content = new TextContent(text) { Annotations = citations };
5283

5384
foreach (var citation in output.Message.Citations)
54-
{
5585
(content.Annotations ??= []).Add(MapInlineCitation(citation));
56-
}
86+
5787
message.Contents.Add(content);
5888
}
5989

6090
foreach (var toolCall in output.Message.ToolCalls)
61-
{
62-
if (toolCall.Type == ToolCallType.ClientSideTool)
63-
{
64-
var arguments = !string.IsNullOrEmpty(toolCall.Function.Arguments)
65-
? JsonSerializer.Deserialize<IDictionary<string, object?>>(toolCall.Function.Arguments)
66-
: null;
67-
68-
var content = new FunctionCallContent(
69-
toolCall.Id,
70-
toolCall.Function.Name,
71-
arguments);
72-
73-
message.Contents.Add(content);
74-
}
75-
else
76-
{
77-
message.Contents.Add(new HostedToolCallContent(toolCall));
78-
}
79-
}
91+
message.Contents.Add(MapToolCall(toolCall));
8092
}
8193

8294
return new ChatResponse(message)
8395
{
84-
ResponseId = protoResponse.Id,
85-
ModelId = protoResponse.Model,
86-
CreatedAt = protoResponse.Created.ToDateTimeOffset(),
96+
ResponseId = response.Id,
97+
ModelId = response.Model,
98+
CreatedAt = response.Created.ToDateTimeOffset(),
8799
FinishReason = lastOutput != null ? MapFinishReason(lastOutput.FinishReason) : null,
88-
Usage = MapToUsage(protoResponse.Usage),
100+
Usage = MapToUsage(response.Usage),
89101
};
90102
}
91103

104+
AIContent MapToolCall(ToolCall toolCall) => toolCall.Type switch
105+
{
106+
ToolCallType.ClientSideTool => new FunctionCallContent(
107+
toolCall.Id,
108+
toolCall.Function.Name,
109+
!string.IsNullOrEmpty(toolCall.Function.Arguments)
110+
? JsonSerializer.Deserialize<IDictionary<string, object?>>(toolCall.Function.Arguments)
111+
: null)
112+
{
113+
RawRepresentation = toolCall
114+
},
115+
ToolCallType.McpTool => new McpServerToolCallContent(toolCall.Id, toolCall.Function.Name, null)
116+
{
117+
RawRepresentation = toolCall
118+
},
119+
ToolCallType.CodeExecutionTool => new CodeInterpreterToolCallContent()
120+
{
121+
CallId = toolCall.Id,
122+
RawRepresentation = toolCall
123+
},
124+
_ => new HostedToolCallContent()
125+
{
126+
CallId = toolCall.Id,
127+
RawRepresentation = toolCall
128+
}
129+
};
130+
92131
public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
93132
{
94133
return CompleteChatStreamingCore(messages, options, cancellationToken);
95134

96135
async IAsyncEnumerable<ChatResponseUpdate> CompleteChatStreamingCore(IEnumerable<ChatMessage> messages, ChatOptions? options, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken)
97136
{
98-
var requestDto = MapToRequest(messages, options);
99-
var call = client.GetCompletionChunk(requestDto, cancellationToken: cancellationToken);
100-
137+
var request = MapToRequest(messages, options);
138+
var call = client.GetCompletionChunk(request, cancellationToken: cancellationToken);
139+
101140
await foreach (var chunk in call.ResponseStream.ReadAllAsync(cancellationToken))
102141
{
103-
var outputChunk = chunk.Outputs[0];
104-
var text = outputChunk.Delta.Content is { Length: > 0 } delta ? delta : null;
142+
var output = chunk.Outputs[0];
143+
var text = output.Delta.Content is { Length: > 0 } delta ? delta : null;
105144

106145
// Use positional arguments for ChatResponseUpdate
107-
var update = new ChatResponseUpdate(MapRole(outputChunk.Delta.Role), text)
146+
var update = new ChatResponseUpdate(MapRole(output.Delta.Role), text)
108147
{
109148
ResponseId = chunk.Id,
110149
ModelId = chunk.Model,
111150
CreatedAt = chunk.Created?.ToDateTimeOffset(),
112-
FinishReason = outputChunk.FinishReason != FinishReason.ReasonInvalid ? MapFinishReason(outputChunk.FinishReason) : null,
151+
FinishReason = output.FinishReason != FinishReason.ReasonInvalid ? MapFinishReason(output.FinishReason) : null,
113152
};
114153

115154
if (chunk.Citations is { Count: > 0 } citations)
@@ -122,31 +161,11 @@ async IAsyncEnumerable<ChatResponseUpdate> CompleteChatStreamingCore(IEnumerable
122161
}
123162

124163
foreach (var citation in citations.Distinct())
125-
{
126164
(textContent.Annotations ??= []).Add(MapCitation(citation));
127-
}
128165
}
129166

130-
foreach (var toolCall in outputChunk.Delta.ToolCalls)
131-
{
132-
if (toolCall.Type == ToolCallType.ClientSideTool)
133-
{
134-
var arguments = !string.IsNullOrEmpty(toolCall.Function.Arguments)
135-
? JsonSerializer.Deserialize<IDictionary<string, object?>>(toolCall.Function.Arguments)
136-
: null;
137-
138-
var content = new FunctionCallContent(
139-
toolCall.Id,
140-
toolCall.Function.Name,
141-
arguments);
142-
143-
update.Contents.Add(content);
144-
}
145-
else
146-
{
147-
update.Contents.Add(new HostedToolCallContent(toolCall));
148-
}
149-
}
167+
foreach (var toolCall in output.Delta.ToolCalls)
168+
update.Contents.Add(MapToolCall(toolCall));
150169

151170
if (update.Contents.Any())
152171
yield return update;
@@ -191,6 +210,8 @@ GetCompletionsRequest MapToRequest(IEnumerable<ChatMessage> messages, ChatOption
191210
{
192211
var request = new GetCompletionsRequest
193212
{
213+
// By default always include citations in the final output if available
214+
Include = { IncludeOption.InlineCitations },
194215
Model = options?.ModelId ?? defaultModelId,
195216
};
196217

@@ -211,6 +232,10 @@ GetCompletionsRequest MapToRequest(IEnumerable<ChatMessage> messages, ChatOption
211232
{
212233
gmsg.Content.Add(new Content { Text = textContent.Text });
213234
}
235+
else if (content.RawRepresentation is ToolCall toolCall)
236+
{
237+
gmsg.ToolCalls.Add(toolCall);
238+
}
214239
else if (content is FunctionCallContent functionCall)
215240
{
216241
gmsg.ToolCalls.Add(new ToolCall
@@ -224,10 +249,6 @@ GetCompletionsRequest MapToRequest(IEnumerable<ChatMessage> messages, ChatOption
224249
}
225250
});
226251
}
227-
else if (content is HostedToolCallContent serverFunction)
228-
{
229-
gmsg.ToolCalls.Add(serverFunction.ToolCall);
230-
}
231252
else if (content is FunctionResultContent resultContent)
232253
{
233254
request.Messages.Add(new Message
@@ -236,19 +257,49 @@ GetCompletionsRequest MapToRequest(IEnumerable<ChatMessage> messages, ChatOption
236257
Content = { new Content { Text = JsonSerializer.Serialize(resultContent.Result) ?? "null" } }
237258
});
238259
}
260+
else if (content is McpServerToolResultContent mcpResult &&
261+
mcpResult.RawRepresentation is ToolCall mcpToolCall &&
262+
mcpResult.Output is { Count: 1 } &&
263+
mcpResult.Output[0] is TextContent mcpText)
264+
{
265+
request.Messages.Add(new Message
266+
{
267+
Role = MessageRole.RoleTool,
268+
ToolCalls = { mcpToolCall },
269+
Content = { new Content { Text = mcpText.Text } }
270+
});
271+
}
272+
else if (content is CodeInterpreterToolResultContent codeResult &&
273+
codeResult.RawRepresentation is ToolCall codeToolCall &&
274+
codeResult.Outputs is { Count: 1 } &&
275+
codeResult.Outputs[0] is TextContent codeText)
276+
{
277+
request.Messages.Add(new Message
278+
{
279+
Role = MessageRole.RoleTool,
280+
ToolCalls = { codeToolCall },
281+
Content = { new Content { Text = codeText.Text } }
282+
});
283+
}
239284
}
240285

241286
if (gmsg.Content.Count == 0 && gmsg.ToolCalls.Count == 0)
242287
continue;
243288

289+
// If we have only tool calls and no content, the gRPC enpoint fails, so add an empty one.
244290
if (gmsg.Content.Count == 0)
245291
gmsg.Content.Add(new Content());
246292

247293
request.Messages.Add(gmsg);
248294
}
249295

296+
IList<IncludeOption> includes = [IncludeOption.InlineCitations];
250297
if (options is GrokChatOptions grokOptions)
251298
{
299+
// NOTE: overrides our default include for inline citations, potentially.
300+
request.Include.Clear();
301+
request.Include.AddRange(grokOptions.Include);
302+
252303
if (grokOptions.Search.HasFlag(GrokSearch.X))
253304
{
254305
(options.Tools ??= []).Insert(0, new GrokXSearchTool());

src/Extensions.Grok/GrokChatOptions.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System.ComponentModel;
2+
using Devlooped.Grok;
23
using Microsoft.Extensions.AI;
34

45
namespace Devlooped.Extensions.AI.Grok;
@@ -33,4 +34,8 @@ public class GrokChatOptions : ChatOptions
3334
/// <summary>Configures Grok's agentic search capabilities.</summary>
3435
/// <remarks>See https://docs.x.ai/docs/guides/tools/search-tools.</remarks>
3536
public GrokSearch Search { get; set; } = GrokSearch.None;
37+
38+
/// <summary>Additional outputs to include in responses.</summary>
39+
/// <remarks>Defaults to including <see cref="IncludeOption.InlineCitations"/>.</remarks>
40+
public IList<IncludeOption> Include { get; set; } = [IncludeOption.InlineCitations];
3641
}
Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,13 @@
1-
using System.Diagnostics;
2-
using System.Text.Json.Serialization;
3-
using Devlooped.Grok;
1+
using System.Diagnostics.CodeAnalysis;
42
using Microsoft.Extensions.AI;
53

64
namespace Devlooped.Extensions.AI;
75

86
/// <summary>Represents a hosted tool agentic call.</summary>
97
/// <param name="toolCall">The tool call details.</param>
10-
[DebuggerDisplay("{DebuggerDisplay,nq}")]
11-
[method: JsonConstructor]
12-
public sealed class HostedToolCallContent(ToolCall toolCall) : AIContent
8+
[Experimental("DEAI001")]
9+
public class HostedToolCallContent : AIContent
1310
{
14-
/// <summary>Gets the tool call details.</summary>
15-
public ToolCall ToolCall => toolCall;
16-
17-
/// <summary>Gets a string representing this instance to display in the debugger.</summary>
18-
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
19-
string DebuggerDisplay
20-
{
21-
get
22-
{
23-
var display = $"ToolCall = {toolCall.Id}, ";
24-
25-
display += toolCall.Function.Arguments is not null ?
26-
$"{toolCall.Function.Name}({toolCall.Function.Arguments})" :
27-
$"{toolCall.Function.Name}()";
28-
29-
return display;
30-
}
31-
}
11+
/// <summary>Gets or sets the tool call ID.</summary>
12+
public virtual string? CallId { get; set; }
3213
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using System.Diagnostics;
2+
using System.Diagnostics.CodeAnalysis;
3+
using Microsoft.Extensions.AI;
4+
5+
namespace Devlooped.Extensions.AI;
6+
7+
/// <summary>Represents a hosted tool agentic call.</summary>
8+
/// <param name="toolCall">The tool call details.</param>
9+
[DebuggerDisplay("{DebuggerDisplay,nq}")]
10+
[Experimental("DEAI001")]
11+
public class HostedToolResultContent : AIContent
12+
{
13+
/// <summary>Gets or sets the tool call ID.</summary>
14+
public virtual string? CallId { get; set; }
15+
16+
/// <summary>Gets or sets the resulting contents from the tool.</summary>
17+
public virtual IList<AIContent>? Outputs { get; set; }
18+
}

0 commit comments

Comments
 (0)