Skip to content

Commit 86d8ac6

Browse files
committed
Improve tool conversion and make API public
Added comprehensive tests for existing conversions. Add support for specifying additional headers for MCP tools. Add support for setting Instructions on the collection search API in the HostedFileSearchTool. Validate that both AllowedDomains/ExcludedDomains and AllowedXHandles/ExcludedXHandles aren't used simultaneously (not supported in the API).
1 parent f1d230a commit 86d8ac6

File tree

4 files changed

+392
-77
lines changed

4 files changed

+392
-77
lines changed
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Google.Protobuf.WellKnownTypes;
5+
using Microsoft.Extensions.AI;
6+
using OpenAI.Responses;
7+
using xAI.Protocol;
8+
9+
namespace xAI;
10+
11+
public class GrokConversionTests
12+
{
13+
[Fact]
14+
public void AsTool_WithWebSearch()
15+
{
16+
var webSearch = new HostedWebSearchTool();
17+
18+
var tool = webSearch.AsProtocolTool();
19+
20+
Assert.NotNull(tool?.WebSearch);
21+
}
22+
23+
[Fact]
24+
public void AsTool_WithWebSearch_ThrowsIfAllowedAndExcluded()
25+
{
26+
var webSearch = new GrokSearchTool
27+
{
28+
AllowedDomains = ["Foo"],
29+
ExcludedDomains = ["Bar"]
30+
};
31+
32+
Assert.Throws<NotSupportedException>(() => webSearch.AsProtocolTool());
33+
}
34+
35+
[Fact]
36+
public void AsTool_WithWebSearch_AllowedDomains()
37+
{
38+
var webSearch = new GrokSearchTool
39+
{
40+
AllowedDomains = ["foo.com", "bar.com"],
41+
};
42+
43+
var tool = webSearch.AsProtocolTool();
44+
45+
Assert.NotNull(tool?.WebSearch);
46+
Assert.Equal(["foo.com", "bar.com"], tool.WebSearch.AllowedDomains);
47+
}
48+
49+
[Fact]
50+
public void AsTool_WithWebSearch_ExcludedDomains()
51+
{
52+
var webSearch = new GrokSearchTool
53+
{
54+
ExcludedDomains = ["foo.com", "bar.com"],
55+
};
56+
57+
var tool = webSearch.AsProtocolTool();
58+
59+
Assert.NotNull(tool?.WebSearch);
60+
Assert.Equal(["foo.com", "bar.com"], tool.WebSearch.ExcludedDomains);
61+
}
62+
63+
[Fact]
64+
public void AsTool_WithWebSearch_ImageUnderstanding()
65+
{
66+
var webSearch = new GrokSearchTool
67+
{
68+
EnableImageUnderstanding = true
69+
};
70+
71+
var tool = webSearch.AsProtocolTool();
72+
73+
Assert.NotNull(tool?.WebSearch);
74+
Assert.True(tool.WebSearch.EnableImageUnderstanding);
75+
}
76+
77+
[Fact]
78+
public void AsTool_WithXSearch_ThrowsIfAllowedAndExcluded()
79+
{
80+
var webSearch = new GrokXSearchTool
81+
{
82+
AllowedHandles = ["Foo"],
83+
ExcludedHandles = ["Bar"]
84+
};
85+
86+
Assert.Throws<NotSupportedException>(() => webSearch.AsProtocolTool());
87+
}
88+
89+
[Fact]
90+
public void AsTool_WithXSearch_AllowedHandles()
91+
{
92+
var webSearch = new GrokXSearchTool
93+
{
94+
AllowedHandles = ["foo", "bar"],
95+
};
96+
97+
var tool = webSearch.AsProtocolTool();
98+
99+
Assert.NotNull(tool?.XSearch);
100+
Assert.Equal(["foo", "bar"], tool.XSearch.AllowedXHandles);
101+
}
102+
103+
[Fact]
104+
public void AsTool_WithXSearch_ExcludedDomains()
105+
{
106+
var webSearch = new GrokXSearchTool
107+
{
108+
ExcludedHandles = ["foo", "bar"],
109+
};
110+
111+
var tool = webSearch.AsProtocolTool();
112+
113+
Assert.NotNull(tool?.XSearch);
114+
Assert.Equal(["foo", "bar"], tool.XSearch.ExcludedXHandles);
115+
}
116+
117+
[Fact]
118+
public void AsTool_WithXSearch_ImageUnderstanding()
119+
{
120+
var webSearch = new GrokXSearchTool
121+
{
122+
EnableImageUnderstanding = true
123+
};
124+
125+
var tool = webSearch.AsProtocolTool();
126+
127+
Assert.NotNull(tool?.XSearch);
128+
Assert.True(tool.XSearch.EnableImageUnderstanding);
129+
}
130+
131+
[Fact]
132+
public void AsTool_WithXSearch_VideoUnderstanding()
133+
{
134+
var webSearch = new GrokXSearchTool
135+
{
136+
EnableVideoUnderstanding = true
137+
};
138+
139+
var tool = webSearch.AsProtocolTool();
140+
141+
Assert.NotNull(tool?.XSearch);
142+
Assert.True(tool.XSearch.EnableVideoUnderstanding);
143+
}
144+
145+
[Fact]
146+
public void AsTool_WithXSearch_FromTo()
147+
{
148+
var webSearch = new GrokXSearchTool
149+
{
150+
FromDate = DateOnly.FromDateTime(DateTime.UtcNow.Subtract(TimeSpan.FromDays(1))),
151+
ToDate = DateOnly.FromDateTime(DateTime.UtcNow)
152+
};
153+
154+
var tool = webSearch.AsProtocolTool();
155+
156+
Assert.NotNull(tool?.XSearch);
157+
Assert.Equal(tool.XSearch.FromDate, Timestamp.FromDateTime(webSearch.FromDate.Value.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc)));
158+
Assert.Equal(tool.XSearch.ToDate, Timestamp.FromDateTime(webSearch.ToDate.Value.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc)));
159+
}
160+
161+
[Fact]
162+
public void AsTool_WithFunctionTool()
163+
{
164+
var functionTool = AIFunctionFactory.Create(() => "", "Name", "Description");
165+
166+
var tool = functionTool.AsProtocolTool();
167+
168+
Assert.NotNull(tool?.Function);
169+
Assert.Equal("Name", tool.Function.Name);
170+
Assert.Equal("Description", tool.Function.Description);
171+
}
172+
173+
[Fact]
174+
public void AsTool_WithCodeExecution()
175+
{
176+
var codeTool = new HostedCodeInterpreterTool();
177+
178+
var tool = codeTool.AsProtocolTool();
179+
180+
Assert.NotNull(tool?.CodeExecution);
181+
}
182+
183+
[Fact]
184+
public void AsTool_WithHostedFileSearchTool()
185+
{
186+
var collectionId = Guid.NewGuid().ToString();
187+
var instructions = "Return N/A if no results found";
188+
var fileSearch = new HostedFileSearchTool()
189+
{
190+
MaximumResultCount = 50,
191+
Inputs = [new HostedVectorStoreContent(collectionId)]
192+
}.WithInstructions(instructions);
193+
194+
var tool = fileSearch.AsProtocolTool();
195+
196+
Assert.NotNull(tool?.CollectionsSearch);
197+
Assert.Contains(collectionId, tool.CollectionsSearch.CollectionIds);
198+
Assert.Equal(50, tool.CollectionsSearch.Limit);
199+
Assert.Equal(instructions, tool.CollectionsSearch.Instructions);
200+
}
201+
202+
[Fact]
203+
public void AsTool_WithHostedMcpTool()
204+
{
205+
var accessToken = Guid.NewGuid().ToString();
206+
var headers = new Dictionary<string, string>
207+
{
208+
["foo"] = "baz"
209+
};
210+
var mcpTool = new HostedMcpServerTool("foo", "foo.com", new Dictionary<string, object?>
211+
{
212+
["x-extra"] = "bar",
213+
[nameof(MCP.ExtraHeaders)] = headers
214+
})
215+
{
216+
AllowedTools = ["list"],
217+
AuthorizationToken = accessToken,
218+
};
219+
220+
var tool = mcpTool.AsProtocolTool();
221+
222+
Assert.NotNull(tool?.Mcp);
223+
Assert.Equal("foo", tool.Mcp.ServerLabel);
224+
Assert.Equal("foo.com", tool.Mcp.ServerUrl);
225+
Assert.Contains("list", tool.Mcp.AllowedToolNames);
226+
Assert.Equal(accessToken, tool.Mcp.Authorization);
227+
Assert.Contains(KeyValuePair.Create("x-extra", "bar"), tool.Mcp.ExtraHeaders);
228+
Assert.Contains(KeyValuePair.Create("foo", "baz"), tool.Mcp.ExtraHeaders);
229+
}
230+
}

src/xAI/Extensions/ChatExtensions.cs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1-
using Microsoft.Extensions.AI;
1+
using System.ComponentModel;
2+
using Microsoft.Extensions.AI;
3+
using Microsoft.Extensions.Options;
4+
using xAI.Protocol;
25

36
namespace xAI;
47

58
/// <summary>Extensions for <see cref="ChatOptions"/>.</summary>
9+
[EditorBrowsable(EditorBrowsableState.Never)]
610
public static partial class ChatOptionsExtensions
711
{
812
extension(ChatOptions options)
@@ -14,4 +18,34 @@ public string? EndUserId
1418
set => (options.AdditionalProperties ??= [])["EndUserId"] = value;
1519
}
1620
}
21+
}
22+
23+
/// <summary>Grok-specific extensions for <see cref="HostedFileSearchTool"/>.</summary>
24+
[EditorBrowsable(EditorBrowsableState.Never)]
25+
public static partial class HostedFileSearchToolExtensions
26+
{
27+
extension(HostedFileSearchTool tool)
28+
{
29+
/// <summary>
30+
/// User-defined instructions to be included in the search query. Defaults to generic search
31+
/// instructions used by the collections search backend if unset.
32+
/// </summary>
33+
public HostedFileSearchTool WithInstructions(string instructions) => new(new Dictionary<string, object?>
34+
{
35+
[nameof(CollectionsSearch.Instructions)] = Throw.IfNullOrEmpty(instructions)
36+
})
37+
{
38+
Inputs = tool.Inputs,
39+
MaximumResultCount = tool.MaximumResultCount,
40+
};
41+
}
42+
}
43+
44+
static partial class AIToolExtensions
45+
{
46+
extension(AITool tool)
47+
{
48+
public T? GetProperty<T>(string name) =>
49+
tool.AdditionalProperties?.TryGetValue(name, out var value) is true && value is T typed ? typed : default;
50+
}
1751
}

src/xAI/GrokChatClient.cs

Lines changed: 2 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -328,82 +328,8 @@ codeResult.RawRepresentation is ToolCall codeToolCall &&
328328

329329
if (options?.Tools is not null)
330330
{
331-
foreach (var tool in options.Tools)
332-
{
333-
if (tool is AIFunction functionTool)
334-
{
335-
var function = new Function
336-
{
337-
Name = functionTool.Name,
338-
Description = functionTool.Description,
339-
Parameters = JsonSerializer.Serialize(functionTool.JsonSchema)
340-
};
341-
request.Tools.Add(new Tool { Function = function });
342-
}
343-
else if (tool is HostedWebSearchTool webSearchTool)
344-
{
345-
if (webSearchTool is GrokXSearchTool xSearch)
346-
{
347-
var toolProto = new XSearch
348-
{
349-
EnableImageUnderstanding = xSearch.EnableImageUnderstanding,
350-
EnableVideoUnderstanding = xSearch.EnableVideoUnderstanding,
351-
};
352-
353-
if (xSearch.AllowedHandles is { } allowed) toolProto.AllowedXHandles.AddRange(allowed);
354-
if (xSearch.ExcludedHandles is { } excluded) toolProto.ExcludedXHandles.AddRange(excluded);
355-
if (xSearch.FromDate is { } from) toolProto.FromDate = Google.Protobuf.WellKnownTypes.Timestamp.FromDateTimeOffset(from.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc));
356-
if (xSearch.ToDate is { } to) toolProto.ToDate = Google.Protobuf.WellKnownTypes.Timestamp.FromDateTimeOffset(to.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc));
357-
358-
request.Tools.Add(new Tool { XSearch = toolProto });
359-
}
360-
else if (webSearchTool is GrokSearchTool grokSearch)
361-
{
362-
var toolProto = new WebSearch
363-
{
364-
EnableImageUnderstanding = grokSearch.EnableImageUnderstanding,
365-
};
366-
367-
if (grokSearch.AllowedDomains is { } allowed) toolProto.AllowedDomains.AddRange(allowed);
368-
if (grokSearch.ExcludedDomains is { } excluded) toolProto.ExcludedDomains.AddRange(excluded);
369-
370-
request.Tools.Add(new Tool { WebSearch = toolProto });
371-
}
372-
else
373-
{
374-
request.Tools.Add(new Tool { WebSearch = new WebSearch() });
375-
}
376-
}
377-
else if (tool is HostedCodeInterpreterTool)
378-
{
379-
request.Tools.Add(new Tool { CodeExecution = new CodeExecution { } });
380-
}
381-
else if (tool is HostedFileSearchTool fileSearch)
382-
{
383-
var toolProto = new CollectionsSearch();
384-
385-
if (fileSearch.Inputs?.OfType<HostedVectorStoreContent>() is { } vectorStores)
386-
toolProto.CollectionIds.AddRange(vectorStores.Select(x => x.VectorStoreId).Distinct());
387-
388-
if (fileSearch.MaximumResultCount is { } maxResults)
389-
toolProto.Limit = maxResults;
390-
391-
request.Tools.Add(new Tool { CollectionsSearch = toolProto });
392-
}
393-
else if (tool is HostedMcpServerTool mcpTool)
394-
{
395-
request.Tools.Add(new Tool
396-
{
397-
Mcp = new MCP
398-
{
399-
Authorization = mcpTool.AuthorizationToken,
400-
ServerLabel = mcpTool.ServerName,
401-
ServerUrl = mcpTool.ServerAddress,
402-
AllowedToolNames = { mcpTool.AllowedTools ?? Array.Empty<string>() }
403-
}
404-
});
405-
}
406-
}
331+
foreach (var tool in options.Tools.Select(x => x.AsProtocolTool(options)))
332+
if (tool is not null) request.Tools.Add(tool);
407333
}
408334

409335
if (options?.ResponseFormat is ChatResponseFormatJson)

0 commit comments

Comments
 (0)