Skip to content

Commit 98faaa5

Browse files
committed
Add agentic code execution tool support
See https://docs.x.ai/docs/guides/tools/code-execution-tool.
1 parent b1525f6 commit 98faaa5

File tree

3 files changed

+47
-21
lines changed

3 files changed

+47
-21
lines changed

src/Extensions.Grok/GrokChatClient.cs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
using System.Text.Json;
2-
3-
using Microsoft.Extensions.AI;
4-
2+
using Devlooped.Grok;
53
using Grpc.Core;
64
using Grpc.Net.Client;
7-
using Devlooped.Grok;
5+
using Microsoft.Extensions.AI;
86
using static Devlooped.Grok.Chat;
97

108
namespace Devlooped.Extensions.AI.Grok;
@@ -25,7 +23,7 @@ internal GrokChatClient(GrpcChannel channel, GrokClientOptions clientOptions, st
2523
public async Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
2624
{
2725
var requestDto = MapToRequest(messages, options);
28-
26+
2927
var protoResponse = await client.GetCompletionAsync(requestDto, cancellationToken: cancellationToken);
3028

3129
var chatMessages = protoResponse.Outputs
@@ -34,7 +32,7 @@ public async Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messag
3432
.ToList();
3533

3634
var lastOutput = protoResponse.Outputs.LastOrDefault();
37-
35+
3836
return new ChatResponse(chatMessages)
3937
{
4038
ResponseId = protoResponse.Id,
@@ -52,13 +50,13 @@ public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerabl
5250
async IAsyncEnumerable<ChatResponseUpdate> CompleteChatStreamingCore(IEnumerable<ChatMessage> messages, ChatOptions? options, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken)
5351
{
5452
var requestDto = MapToRequest(messages, options);
55-
53+
5654
using var call = client.GetCompletionChunk(requestDto, cancellationToken: cancellationToken);
57-
55+
5856
await foreach (var chunk in call.ResponseStream.ReadAllAsync(cancellationToken))
5957
{
6058
var outputChunk = chunk.Outputs[0];
61-
59+
6260
// Use positional arguments for ChatResponseUpdate
6361
var update = new ChatResponseUpdate(
6462
outputChunk.Delta.Role != MessageRole.InvalidRole ? MapRole(outputChunk.Delta.Role) : null,
@@ -79,7 +77,7 @@ async IAsyncEnumerable<ChatResponseUpdate> CompleteChatStreamingCore(IEnumerable
7977
textContent = new TextContent(string.Empty);
8078
update.Contents.Add(textContent);
8179
}
82-
80+
8381
foreach (var citation in citations.Distinct())
8482
{
8583
(textContent.Annotations ??= []).Add(new CitationAnnotation { Url = new(citation) });
@@ -150,7 +148,7 @@ GetCompletionsRequest MapToRequest(IEnumerable<ChatMessage> messages, ChatOption
150148
foreach (var message in messages)
151149
{
152150
var gmsg = new Message { Role = MapRole(message.Role) };
153-
151+
154152
foreach (var content in message.Contents)
155153
{
156154
if (content is TextContent textContent && !string.IsNullOrEmpty(textContent.Text))
@@ -250,6 +248,10 @@ GetCompletionsRequest MapToRequest(IEnumerable<ChatMessage> messages, ChatOption
250248
request.Tools.Add(new Tool { WebSearch = new WebSearch() });
251249
}
252250
}
251+
else if (tool is HostedCodeInterpreterTool)
252+
{
253+
request.Tools.Add(new Tool { CodeExecution = new CodeExecution { } });
254+
}
253255
}
254256
}
255257

@@ -272,7 +274,7 @@ GetCompletionsRequest MapToRequest(IEnumerable<ChatMessage> messages, ChatOption
272274
_ when role == ChatRole.Tool => MessageRole.RoleTool,
273275
_ => MessageRole.RoleUser
274276
};
275-
277+
276278
static ChatRole MapRole(MessageRole role) => role switch
277279
{
278280
MessageRole.RoleSystem => ChatRole.System,
@@ -299,8 +301,8 @@ GetCompletionsRequest MapToRequest(IEnumerable<ChatMessage> messages, ChatOption
299301
TotalTokenCount = usage.TotalTokens
300302
};
301303

302-
public object? GetService(Type serviceType, object? serviceKey = null) =>
304+
public object? GetService(Type serviceType, object? serviceKey = null) =>
303305
serviceType == typeof(GrokChatClient) ? this : null;
304306

305-
public void Dispose() {}
307+
public void Dispose() { }
306308
}

src/Extensions.Grok/GrokClient.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ public class GrokClient(string apiKey, GrokClientOptions options)
2626
internal GrpcChannel Channel => channels.GetOrAdd((Endpoint, ApiKey), key =>
2727
{
2828
var handler = new AuthenticationHeaderHandler(ApiKey)
29-
{
30-
InnerHandler = Options.ChannelOptions?.HttpHandler ?? new HttpClientHandler()
29+
{
30+
InnerHandler = Options.ChannelOptions?.HttpHandler ?? new HttpClientHandler()
3131
};
3232

3333
var options = Options.ChannelOptions ?? new GrpcChannelOptions();

src/Tests/GrokTests.cs

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,13 @@ public async Task GrokInvokesToolAndSearch()
6565
{
6666
ModelId = "grok-4-1-fast-non-reasoning",
6767
Search = GrokSearch.Web,
68-
Tools = [AIFunctionFactory.Create(() =>
68+
Tools = [AIFunctionFactory.Create(() =>
6969
{
7070
getDateCalls++;
7171
return DateTimeOffset.Now.ToString("O");
7272
}, "get_date", "Gets the current date")],
7373
};
74-
74+
7575
var response = await grok.GetResponseAsync(messages, options);
7676

7777
// The get_date result shows up as a tool role
@@ -178,14 +178,14 @@ public async Task GrokInvokesGrokSearchToolIncludesDomain()
178178

179179
var options = new ChatOptions
180180
{
181-
Tools = [new GrokSearchTool
182-
{
181+
Tools = [new GrokSearchTool
182+
{
183183
AllowedDomains = ["microsoft.com", "news.microsoft.com"],
184184
}]
185185
};
186186

187187
var response = await grok.GetResponseAsync(messages, options);
188-
188+
189189
Assert.NotNull(response.Text);
190190
Assert.Contains("Microsoft", response.Text);
191191

@@ -242,4 +242,28 @@ public async Task GrokInvokesGrokSearchToolExcludesDomain()
242242

243243
Assert.DoesNotContain(urls, x => x.Host == "blogs.microsoft.com");
244244
}
245+
246+
[SecretsFact("XAI_API_KEY")]
247+
public async Task GrokInvokesHostedCodeExecution()
248+
{
249+
var messages = new Chat()
250+
{
251+
{ "user", "Calculate the compound interest for $10,000 at 5% annually for 10 years" },
252+
};
253+
254+
var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4-fast");
255+
256+
var options = new ChatOptions
257+
{
258+
Tools = [new HostedCodeInterpreterTool()]
259+
};
260+
261+
var response = await grok.GetResponseAsync(messages, options);
262+
var text = response.Text;
263+
264+
Assert.Contains("$6,288.95", text);
265+
Assert.Contains(
266+
response.Messages.SelectMany(x => x.Contents).OfType<HostedToolCallContent>(),
267+
x => x.ToolCall.Type == Devlooped.Grok.ToolCallType.CodeExecutionTool);
268+
}
245269
}

0 commit comments

Comments
 (0)