Skip to content

Commit 14b64c3

Browse files
committed
Add streaming test
1 parent da48114 commit 14b64c3

File tree

2 files changed

+100
-10
lines changed

2 files changed

+100
-10
lines changed

src/Extensions.Grok/GrokChatClient.cs

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System.Text.Json;
22
using Devlooped.Grok;
3+
using Google.Protobuf;
34
using Grpc.Core;
45
using Grpc.Net.Client;
56
using Microsoft.Extensions.AI;
@@ -95,22 +96,19 @@ public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerabl
9596
async IAsyncEnumerable<ChatResponseUpdate> CompleteChatStreamingCore(IEnumerable<ChatMessage> messages, ChatOptions? options, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken)
9697
{
9798
var requestDto = MapToRequest(messages, options);
98-
99-
using var call = client.GetCompletionChunk(requestDto, cancellationToken: cancellationToken);
100-
99+
var call = client.GetCompletionChunk(requestDto, cancellationToken: cancellationToken);
100+
101101
await foreach (var chunk in call.ResponseStream.ReadAllAsync(cancellationToken))
102102
{
103103
var outputChunk = chunk.Outputs[0];
104+
var text = outputChunk.Delta.Content is { Length: > 0 } delta ? delta : null;
104105

105106
// Use positional arguments for ChatResponseUpdate
106-
var update = new ChatResponseUpdate(
107-
outputChunk.Delta.Role != MessageRole.InvalidRole ? MapRole(outputChunk.Delta.Role) : null,
108-
outputChunk.Delta.Content
109-
)
107+
var update = new ChatResponseUpdate(MapRole(outputChunk.Delta.Role), text)
110108
{
111109
ResponseId = chunk.Id,
112110
ModelId = chunk.Model,
113-
CreatedAt = chunk.Created.ToDateTimeOffset(),
111+
CreatedAt = chunk.Created?.ToDateTimeOffset(),
114112
FinishReason = outputChunk.FinishReason != FinishReason.ReasonInvalid ? MapFinishReason(outputChunk.FinishReason) : null,
115113
};
116114

@@ -129,7 +127,29 @@ async IAsyncEnumerable<ChatResponseUpdate> CompleteChatStreamingCore(IEnumerable
129127
}
130128
}
131129

132-
yield return update;
130+
foreach (var toolCall in outputChunk.Delta.ToolCalls)
131+
{
132+
if (toolCall.Type == ToolCallType.ClientSideTool)
133+
{
134+
var arguments = !string.IsNullOrEmpty(toolCall.Function.Arguments)
135+
? JsonSerializer.Deserialize<IDictionary<string, object?>>(toolCall.Function.Arguments)
136+
: null;
137+
138+
var content = new FunctionCallContent(
139+
toolCall.Id,
140+
toolCall.Function.Name,
141+
arguments);
142+
143+
update.Contents.Add(content);
144+
}
145+
else
146+
{
147+
update.Contents.Add(new HostedToolCallContent(toolCall));
148+
}
149+
}
150+
151+
if (update.Contents.Any())
152+
yield return update;
133153
}
134154
}
135155
}
@@ -221,6 +241,9 @@ GetCompletionsRequest MapToRequest(IEnumerable<ChatMessage> messages, ChatOption
221241
if (gmsg.Content.Count == 0 && gmsg.ToolCalls.Count == 0)
222242
continue;
223243

244+
if (gmsg.Content.Count == 0)
245+
gmsg.Content.Add(new Content());
246+
224247
request.Messages.Add(gmsg);
225248
}
226249

src/Tests/GrokTests.cs

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Text.Json.Nodes;
1+
using System.Text.Json;
2+
using System.Text.Json.Nodes;
23
using Azure;
34
using Devlooped.Extensions.AI.Grok;
45
using Microsoft.Extensions.AI;
@@ -319,4 +320,70 @@ public async Task GrokInvokesHostedMcp()
319320
response.Messages.SelectMany(x => x.Contents).OfType<HostedToolCallContent>(),
320321
x => x.ToolCall.Type == Devlooped.Grok.ToolCallType.McpTool);
321322
}
323+
324+
[SecretsFact("XAI_API_KEY", "GITHUB_TOKEN")]
325+
public async Task GrokStreamsUpdatesFromAllTools()
326+
{
327+
var messages = new Chat()
328+
{
329+
{ "user",
330+
"""
331+
What's the oldest stable version released on the devlooped/GrokClient repo on GitHub?,
332+
what is the current price of Tesla stock,
333+
and what is the current date? Respond with the following JSON:
334+
{
335+
"today": "[get_date result]",
336+
"release": "[first stable release of devlooped/GrokClient]",
337+
"price": [$TSLA price]
338+
}
339+
"""
340+
},
341+
};
342+
343+
var grok = new GrokClient(Configuration["XAI_API_KEY"]!)
344+
.AsIChatClient("grok-4-fast")
345+
.AsBuilder()
346+
.UseFunctionInvocation()
347+
.UseLogging(output.AsLoggerFactory())
348+
.Build();
349+
350+
var getDateCalls = 0;
351+
var options = new ChatOptions
352+
{
353+
Tools =
354+
[
355+
new HostedWebSearchTool(),
356+
new HostedMcpServerTool("GitHub", "https://api.githubcopilot.com/mcp/") {
357+
AuthorizationToken = Configuration["GITHUB_TOKEN"]!,
358+
AllowedTools = ["list_releases", "get_release_by_tag"],
359+
},
360+
AIFunctionFactory.Create(() => {
361+
getDateCalls++;
362+
return DateTimeOffset.Now.ToString("O");
363+
}, "get_date", "Gets the current date")
364+
]
365+
};
366+
367+
var updates = await grok.GetStreamingResponseAsync(messages, options).ToListAsync();
368+
var response = updates.ToChatResponse();
369+
var typed = JsonSerializer.Deserialize<Response>(response.Text, new JsonSerializerOptions(JsonSerializerDefaults.Web));
370+
371+
Assert.NotNull(typed);
372+
373+
Assert.Contains(
374+
response.Messages.SelectMany(x => x.Contents).OfType<HostedToolCallContent>(),
375+
x => x.ToolCall.Type == Devlooped.Grok.ToolCallType.McpTool);
376+
377+
Assert.Contains(
378+
response.Messages.SelectMany(x => x.Contents).OfType<HostedToolCallContent>(),
379+
x => x.ToolCall.Type == Devlooped.Grok.ToolCallType.WebSearchTool);
380+
381+
Assert.Equal(1, getDateCalls);
382+
383+
Assert.Equal(DateOnly.FromDateTime(DateTime.Today), typed.Today);
384+
Assert.EndsWith("1.0.0", typed.Release);
385+
Assert.True(typed.Price > 100);
386+
}
387+
388+
record Response(DateOnly Today, string Release, decimal Price);
322389
}

0 commit comments

Comments
 (0)