Skip to content

Commit d5400b0

Browse files
committed
Add allow for citations
1 parent e41e0b8 commit d5400b0

File tree

2 files changed

+89
-7
lines changed

2 files changed

+89
-7
lines changed

src/Extensions.Grok/GrokChatClient.cs

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public async Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messag
3232
var protoResponse = await client.GetCompletionAsync(requestDto, cancellationToken: cancellationToken);
3333

3434
var chatMessages = protoResponse.Outputs
35-
.Select(x => MapToChatMessage(x.Message))
35+
.Select(x => MapToChatMessage(x.Message, protoResponse.Citations))
3636
.Where(x => x.Contents.Count > 0)
3737
.ToList();
3838

@@ -63,7 +63,7 @@ async IAsyncEnumerable<ChatResponseUpdate> CompleteChatStreamingCore(IEnumerable
6363
var outputChunk = chunk.Outputs[0];
6464

6565
// Use positional arguments for ChatResponseUpdate
66-
yield return new ChatResponseUpdate(
66+
var update = new ChatResponseUpdate(
6767
outputChunk.Delta.Role != MessageRole.InvalidRole ? MapRole(outputChunk.Delta.Role) : null,
6868
outputChunk.Delta.Content
6969
)
@@ -73,6 +73,23 @@ async IAsyncEnumerable<ChatResponseUpdate> CompleteChatStreamingCore(IEnumerable
7373
CreatedAt = chunk.Created.ToDateTimeOffset(),
7474
FinishReason = outputChunk.FinishReason != FinishReason.ReasonInvalid ? MapFinishReason(outputChunk.FinishReason) : null,
7575
};
76+
77+
if (chunk.Citations is { Count: > 0 } citations)
78+
{
79+
var textContent = update.Contents.OfType<TextContent>().FirstOrDefault();
80+
if (textContent == null)
81+
{
82+
textContent = new TextContent(string.Empty);
83+
update.Contents.Add(textContent);
84+
}
85+
86+
foreach (var citation in citations.Distinct())
87+
{
88+
(textContent.Annotations ??= []).Add(new CitationAnnotation { Url = new(citation) });
89+
}
90+
}
91+
92+
yield return update;
7693
}
7794
}
7895
}
@@ -181,13 +198,21 @@ GetCompletionsRequest MapToRequest(IEnumerable<ChatMessage> messages, ChatOption
181198
_ => ChatRole.Assistant
182199
};
183200

184-
static ChatMessage MapToChatMessage(CompletionMessage message)
201+
static ChatMessage MapToChatMessage(CompletionMessage message, IList<string>? citations = null)
185202
{
186203
var chatMessage = new ChatMessage() { Role = MapRole(message.Role) };
187204

188-
if (!string.IsNullOrEmpty(message.Content))
205+
if (!string.IsNullOrEmpty(message.Content) || (citations is { Count: > 0 }))
189206
{
190-
chatMessage.Contents.Add(new TextContent(message.Content) { Annotations = [] });
207+
var textContent = new TextContent(message.Content ?? string.Empty);
208+
if (citations is { Count: > 0 })
209+
{
210+
foreach (var citation in citations.Distinct())
211+
{
212+
(textContent.Annotations ??= []).Add(new CitationAnnotation { Url = new(citation) });
213+
}
214+
}
215+
chatMessage.Contents.Add(textContent);
191216
}
192217

193218
foreach (var toolCall in message.ToolCalls)

src/Tests/GrokTests.cs

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,10 +324,14 @@ public async Task GrokGrpcInvokesHostedSearchTool()
324324

325325
Assert.Contains("TSLA", text);
326326
Assert.NotNull(response.ModelId);
327+
Assert.Contains(new Uri("https://finance.yahoo.com/quote/TSLA/news/"), response.Messages
328+
.SelectMany(x => x.Contents)
329+
.SelectMany(x => x.Annotations?.OfType<CitationAnnotation>() ?? [])
330+
.Select(x => x.Url));
327331
}
328332

329333
[SecretsFact("XAI_API_KEY")]
330-
public async Task GrokGrpcInvokesGrokSearchTool()
334+
public async Task GrokGrpcInvokesGrokSearchToolIncludesDomain()
331335
{
332336
var messages = new Chat()
333337
{
@@ -341,13 +345,66 @@ public async Task GrokGrpcInvokesGrokSearchTool()
341345
{
342346
Tools = [new GrokSearchTool
343347
{
344-
AllowedDomains = ["microsoft.com", "news.microsoft.com"]
348+
AllowedDomains = ["microsoft.com", "news.microsoft.com"],
345349
}]
346350
};
347351

348352
var response = await grok.GetResponseAsync(messages, options);
349353

350354
Assert.NotNull(response.Text);
351355
Assert.Contains("Microsoft", response.Text);
356+
357+
var urls = response.Messages
358+
.SelectMany(x => x.Contents)
359+
.SelectMany(x => x.Annotations?.OfType<CitationAnnotation>() ?? [])
360+
.Where(x => x.Url is not null)
361+
.Select(x => x.Url!)
362+
.ToList();
363+
364+
foreach (var url in urls)
365+
{
366+
output.WriteLine(url.ToString());
367+
}
368+
369+
Assert.All(urls, x => x.Host.EndsWith(".microsoft.com"));
370+
}
371+
372+
[SecretsFact("XAI_API_KEY")]
373+
public async Task GrokGrpcInvokesGrokSearchToolExcludesDomain()
374+
{
375+
var messages = new Chat()
376+
{
377+
{ "system", "You are an AI assistant that knows how to search the web." },
378+
{ "user", "What is the latest news about Microsoft?" },
379+
};
380+
381+
var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4-fast");
382+
383+
var options = new ChatOptions
384+
{
385+
Tools = [new GrokSearchTool
386+
{
387+
ExcludedDomains = ["blogs.microsoft.com"]
388+
}]
389+
};
390+
391+
var response = await grok.GetResponseAsync(messages, options);
392+
393+
Assert.NotNull(response.Text);
394+
Assert.Contains("Microsoft", response.Text);
395+
396+
var urls = response.Messages
397+
.SelectMany(x => x.Contents)
398+
.SelectMany(x => x.Annotations?.OfType<CitationAnnotation>() ?? [])
399+
.Where(x => x.Url is not null)
400+
.Select(x => x.Url!)
401+
.ToList();
402+
403+
foreach (var url in urls)
404+
{
405+
output.WriteLine(url.ToString());
406+
}
407+
408+
Assert.DoesNotContain(urls, x => x.Host == "blogs.microsoft.com");
352409
}
353410
}

0 commit comments

Comments
 (0)