Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,50 @@ var options = new ChatOptions
};
```

To receive the actual search results and file references, include `CollectionsSearchCallOutput` in the options:

```csharp
var options = new GrokChatOptions
{
Include = [IncludeOption.CollectionsSearchCallOutput],
Tools = [new HostedFileSearchTool {
Inputs = [new HostedVectorStoreContent("[collection_id]")]
}]
};

var response = await grok.GetResponseAsync(messages, options);

// Access the search results with file references
var results = response.Messages
.SelectMany(x => x.Contents)
.OfType<CollectionSearchToolResultContent>();

foreach (var result in results)
{
// Each result contains files that were found and referenced
var files = result.Outputs?.OfType<HostedFileContent>();
foreach (var file in files ?? [])
{
Console.WriteLine($"File: {file.Name} (ID: {file.FileId})");

// Files include citation annotations with snippets
foreach (var citation in file.Annotations?.OfType<CitationAnnotation>() ?? [])
{
Console.WriteLine($" Title: {citation.Title}");
Console.WriteLine($" Snippet: {citation.Snippet}");
Console.WriteLine($" URL: {citation.Url}"); // collections://[collection_id]/files/[file_id]
}
}
}
```

Citations from collection search include:
- **Title**: Extracted from the first line of the chunk content (if available), typically the file name or heading
- **Snippet**: The relevant text excerpt from the document
- **FileId**: Identifier of the source file in the collection
- **Url**: A `collections://` URI pointing to the specific file within the collection
- **ToolName**: Always set to `"collections_search"`

Learn more about [collection search](https://docs.x.ai/docs/guides/tools/collections-search-tool).

## Remote MCP
Expand Down
23 changes: 16 additions & 7 deletions src/xAI.Tests/ChatClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ public async Task GrokInvokesToolAndSearch()
Assert.Equal(options.ModelId, response.ModelId);

var calls = response.Messages
.SelectMany(x => x.Contents.OfType<HostedToolCallContent>())
.Select(x => x.RawRepresentation as xAI.Protocol.ToolCall)
.SelectMany(x => x.Contents.Select(x => x.RawRepresentation as xAI.Protocol.ToolCall))
.Where(x => x is not null)
.ToList();

Expand Down Expand Up @@ -317,7 +316,6 @@ public async Task GrokInvokesHostedCollectionSearch()

var options = new GrokChatOptions
{
Include = { IncludeOption.CollectionsSearchCallOutput },
Tools = [new HostedFileSearchTool {
Inputs = [new HostedVectorStoreContent("collection_91559d9b-a55d-42fe-b2ad-ecf8904d9049")]
}]
Expand All @@ -329,9 +327,21 @@ public async Task GrokInvokesHostedCollectionSearch()
Assert.Contains("11,74", text);
Assert.Contains(response.Messages
.SelectMany(x => x.Contents)
.OfType<HostedToolCallContent>()
.OfType<CollectionSearchToolCallContent>()
.Select(x => x.RawRepresentation as xAI.Protocol.ToolCall),
x => x?.Type == xAI.Protocol.ToolCallType.CollectionsSearchTool);
// No actual search results content since we didn't specify it in Include
Assert.Empty(response.Messages.SelectMany(x => x.Contents).OfType<CollectionSearchToolResultContent>());

options.Include = [IncludeOption.CollectionsSearchCallOutput];
response = await grok.GetResponseAsync(messages, options);

// Now it also contains the file reference as result content
Assert.Contains(response.Messages
.SelectMany(x => x.Contents)
.OfType<CollectionSearchToolResultContent>()
.SelectMany(x => (x.Outputs ?? []).OfType<HostedFileContent>()),
x => x.Name == "LNS0004592.txt");
}

[SecretsFact("XAI_API_KEY", "GITHUB_TOKEN")]
Expand Down Expand Up @@ -458,9 +468,8 @@ public async Task GrokStreamsUpdatesFromAllTools()
.OfType<McpServerToolCallContent>());

Assert.Contains(response.Messages
.SelectMany(x => x.Contents)
.OfType<HostedToolCallContent>()
.Select(x => x.RawRepresentation as xAI.Protocol.ToolCall),
.SelectMany(x => x.Contents.Select(x => x.RawRepresentation as xAI.Protocol.ToolCall))
.Where(x => x != null),
x => x?.Type == xAI.Protocol.ToolCallType.WebSearchTool);

Assert.Equal(1, getDateCalls);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
using System.Diagnostics.CodeAnalysis;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.AI;

namespace xAI;

/// <summary>Represents a hosted tool agentic call.</summary>
[Experimental("xAI001")]
public class HostedToolCallContent : AIContent
public class CollectionSearchToolCallContent : AIContent
{
/// <summary>Gets or sets the tool call ID.</summary>
public virtual string? CallId { get; set; }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.AI;

namespace xAI;

/// <summary>Represents a hosted tool agentic call.</summary>
[DebuggerDisplay("{DebuggerDisplay,nq}")]
[Experimental("xAI001")]
public class HostedToolResultContent : AIContent
public class CollectionSearchToolResultContent : AIContent
{
/// <summary>Gets or sets the tool call ID.</summary>
public virtual string? CallId { get; set; }
Expand Down
116 changes: 23 additions & 93 deletions src/xAI/GrokChatClient.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Text.Json;
using Google.Protobuf;
using Grpc.Core;
using Grpc.Net.Client;
using Microsoft.Extensions.AI;
Expand Down Expand Up @@ -39,82 +40,23 @@ public async Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messag
var response = await client.GetCompletionAsync(request, cancellationToken: cancellationToken);
var lastOutput = response.Outputs.OrderByDescending(x => x.Index).FirstOrDefault();

if (lastOutput == null)
{
return new ChatResponse()
{
ResponseId = response.Id,
ModelId = response.Model,
CreatedAt = response.Created.ToDateTimeOffset(),
Usage = MapToUsage(response.Usage),
};
}

var message = new ChatMessage(MapRole(lastOutput.Message.Role), default(string));
var citations = response.Citations?.Distinct().Select(MapCitation).ToList<AIAnnotation>();

foreach (var output in response.Outputs.OrderBy(x => x.Index))
{
if (output.Message.Content is { Length: > 0 } text)
{
// Special-case output from tools
if (output.Message.Role == MessageRole.RoleTool &&
output.Message.ToolCalls.Count == 1 &&
output.Message.ToolCalls[0] is { } toolCall)
{
if (toolCall.Type == ToolCallType.McpTool)
{
message.Contents.Add(new McpServerToolCallContent(toolCall.Id, toolCall.Function.Name, null)
{
RawRepresentation = toolCall
});
message.Contents.Add(new McpServerToolResultContent(toolCall.Id)
{
RawRepresentation = toolCall,
Output = [new TextContent(text)]
});
continue;
}
else if (toolCall.Type == ToolCallType.CodeExecutionTool)
{
message.Contents.Add(new CodeInterpreterToolCallContent()
{
CallId = toolCall.Id,
RawRepresentation = toolCall
});
message.Contents.Add(new CodeInterpreterToolResultContent()
{
CallId = toolCall.Id,
RawRepresentation = toolCall,
Outputs = [new TextContent(text)]
});
continue;
}
}

var content = new TextContent(text) { Annotations = citations };

foreach (var citation in output.Message.Citations)
(content.Annotations ??= []).Add(MapInlineCitation(citation));

message.Contents.Add(content);
}

foreach (var toolCall in output.Message.ToolCalls)
message.Contents.Add(MapToolCall(toolCall));
}

return new ChatResponse(message)
var result = new ChatResponse()
{
ResponseId = response.Id,
ModelId = response.Model,
CreatedAt = response.Created?.ToDateTimeOffset(),
FinishReason = lastOutput != null ? MapFinishReason(lastOutput.FinishReason) : null,
Usage = MapToUsage(response.Usage),
};

var citations = response.Citations?.Distinct().Select(MapCitation).ToList<AIAnnotation>();

((List<ChatMessage>)result.Messages).AddRange(response.Outputs.AsChatMessages(citations));

return result;
}

AIContent MapToolCall(ToolCall toolCall) => toolCall.Type switch
AIContent? MapToolCall(ToolCall toolCall) => toolCall.Type switch
{
ToolCallType.ClientSideTool => new FunctionCallContent(
toolCall.Id,
Expand All @@ -134,11 +76,12 @@ public async Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messag
CallId = toolCall.Id,
RawRepresentation = toolCall
},
_ => new HostedToolCallContent()
ToolCallType.CollectionsSearchTool => new CollectionSearchToolCallContent()
{
CallId = toolCall.Id,
RawRepresentation = toolCall
}
},
_ => null
};

public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
Expand All @@ -161,44 +104,30 @@ async IAsyncEnumerable<ChatResponseUpdate> CompleteChatStreamingCore(IEnumerable
ResponseId = chunk.Id,
ModelId = chunk.Model,
CreatedAt = chunk.Created?.ToDateTimeOffset(),
RawRepresentation = chunk,
FinishReason = output.FinishReason != FinishReason.ReasonInvalid ? MapFinishReason(output.FinishReason) : null,
};

if (chunk.Citations is { Count: > 0 } citations)
var citations = chunk.Citations?.Distinct().Select(MapCitation).ToList<AIAnnotation>();
if (citations?.Count > 0)
{
var textContent = update.Contents.OfType<TextContent>().FirstOrDefault();
if (textContent == null)
{
textContent = new TextContent(string.Empty);
update.Contents.Add(textContent);
}

foreach (var citation in citations.Distinct())
(textContent.Annotations ??= []).Add(MapCitation(citation));
((List<AIAnnotation>)(textContent.Annotations ??= [])).AddRange(citations);
}

foreach (var toolCall in output.Delta.ToolCalls)
update.Contents.Add(MapToolCall(toolCall));
((List<AIContent>)update.Contents).AddRange(output.Delta.ToolCalls.AsContents(text, citations));

if (update.Contents.Any())
yield return update;
}
}
}

static CitationAnnotation MapInlineCitation(InlineCitation citation) => citation.CitationCase switch
{
InlineCitation.CitationOneofCase.WebCitation => new CitationAnnotation { Url = new(citation.WebCitation.Url) },
InlineCitation.CitationOneofCase.XCitation => new CitationAnnotation { Url = new(citation.XCitation.Url) },
InlineCitation.CitationOneofCase.CollectionsCitation => new CitationAnnotation
{
FileId = citation.CollectionsCitation.FileId,
Snippet = citation.CollectionsCitation.ChunkContent,
ToolName = "file_search",
},
_ => new CitationAnnotation()
};

static CitationAnnotation MapCitation(string citation)
{
var url = new Uri(citation);
Expand All @@ -210,12 +139,13 @@ static CitationAnnotation MapCitation(string citation)
var file = url.AbsolutePath[7..];
return new CitationAnnotation
{
ToolName = "collections_search",
FileId = file,
AdditionalProperties = new AdditionalPropertiesDictionary
{
{ "collection_id", collection }
}
{
{ "collection_id", collection }
},
FileId = file,
ToolName = "collections_search",
Url = new Uri($"collections://{collection}/files/{file}"),
};
}

Expand Down
Loading