Skip to content

Commit 4bbfa07

Browse files
committed
Enable full Live Search compatibility for Grok
This enables the full spectrum of Live Search parameters supported by Grok as documented at https://docs.x.ai/docs/guides/live-search.
1 parent b72a5b4 commit 4bbfa07

File tree

4 files changed

+303
-37
lines changed

4 files changed

+303
-37
lines changed

src/AI.Tests/GrokTests.cs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,4 +171,69 @@ public async Task GrokThinksHard()
171171
// different model and the grok client honors that choice.
172172
Assert.StartsWith("grok-3-mini", response.ModelId);
173173
}
174+
175+
[SecretsFact("XAI_API_KEY")]
176+
public async Task GrokInvokesSpecificSearchUrl()
177+
{
178+
var messages = new Chat()
179+
{
180+
{ "system", "Sos un asistente del Cerro Catedral, usas la funcionalidad de Live Search en el sitio oficial." },
181+
{ "system", $"Hoy es {DateTime.Now.ToString("o")}" },
182+
{ "user", "Que calidad de nieve hay hoy?" },
183+
};
184+
185+
var requests = new List<JsonNode>();
186+
var responses = new List<JsonNode>();
187+
188+
var grok = new GrokChatClient(Configuration["XAI_API_KEY"]!, "grok-3", OpenAI.OpenAIClientOptions
189+
.Observable(requests.Add, responses.Add)
190+
.WriteTo(output));
191+
192+
var options = new ChatOptions
193+
{
194+
Tools = [new GrokSearchTool(GrokSearch.On)
195+
{
196+
//FromDate = new DateOnly(2025, 1, 1),
197+
//ToDate = DateOnly.FromDateTime(DateTime.Now),
198+
//MaxSearchResults = 10,
199+
Sources =
200+
[
201+
new GrokWebSource
202+
{
203+
AllowedWebsites =
204+
[
205+
"https://catedralaltapatagonia.com",
206+
"https://catedralaltapatagonia.com/parte-de-nieve/",
207+
"https://catedralaltapatagonia.com/tarifas/"
208+
]
209+
},
210+
]
211+
}]
212+
};
213+
214+
var response = await grok.GetResponseAsync(messages, options);
215+
var text = response.Text;
216+
217+
// assert that the request contains the following node
218+
// "search_parameters": {
219+
// "mode": "auto"
220+
//}
221+
Assert.All(requests, x =>
222+
{
223+
var search = Assert.IsType<JsonObject>(x["search_parameters"]);
224+
Assert.Equal("on", search["mode"]?.GetValue<string>());
225+
});
226+
227+
// Citations include catedralaltapatagonia.com at least as a web search source
228+
Assert.Single(responses);
229+
var node = responses[0];
230+
Assert.NotNull(node);
231+
var citations = Assert.IsType<JsonArray>(node["citations"], false);
232+
var catedral = citations.Where(x => x != null).Any(x => x!.ToString().Contains("catedralaltapatagonia.com", StringComparison.Ordinal));
233+
234+
Assert.True(catedral, "Expected at least one citation to catedralaltapatagonia.com");
235+
236+
// Uses the default model set by the client when we asked for it
237+
Assert.Equal("grok-3", response.ModelId);
238+
}
174239
}

src/AI/ClientPipelineExtensions.cs

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,43 +49,48 @@ class ObservePipelinePolicy(Action<JsonNode>? onRequest = default, Action<JsonNo
4949
public override void Process(PipelineMessage message, IReadOnlyList<PipelinePolicy> pipeline, int currentIndex)
5050
{
5151
message.BufferResponse = true;
52+
NotifyRequest(message);
5253
ProcessNext(message, pipeline, currentIndex);
53-
NotifyObservers(message);
54+
NotifyResponse(message);
5455
}
5556

5657
public override async ValueTask ProcessAsync(PipelineMessage message, IReadOnlyList<PipelinePolicy> pipeline, int currentIndex)
5758
{
5859
message.BufferResponse = true;
60+
NotifyRequest(message);
5961
await ProcessNextAsync(message, pipeline, currentIndex);
60-
NotifyObservers(message);
62+
NotifyResponse(message);
6163
}
6264

63-
void NotifyObservers(PipelineMessage message)
65+
void NotifyResponse(PipelineMessage message)
6466
{
65-
if (onRequest != null && message.Request.Content != null)
67+
if (onResponse != null && message.Response != null)
6668
{
67-
using var memory = new MemoryStream();
68-
message.Request.Content.WriteTo(memory);
69-
memory.Position = 0;
70-
using var reader = new StreamReader(memory);
71-
var content = reader.ReadToEnd();
7269
try
7370
{
74-
if (JsonNode.Parse(content) is { } node)
75-
onRequest.Invoke(node!);
71+
if (JsonNode.Parse(message.Response.Content.ToString()) is { } node)
72+
onResponse.Invoke(node!);
7673
}
7774
catch (JsonException)
7875
{
7976
// We ignore invalid JSON
8077
}
8178
}
79+
}
8280

83-
if (onResponse != null && message.Response != null)
81+
void NotifyRequest(PipelineMessage message)
82+
{
83+
if (onRequest != null && message.Request.Content != null)
8484
{
85+
using var memory = new MemoryStream();
86+
message.Request.Content.WriteTo(memory);
87+
memory.Position = 0;
88+
using var reader = new StreamReader(memory);
89+
var content = reader.ReadToEnd();
8590
try
8691
{
87-
if (JsonNode.Parse(message.Response.Content.ToString()) is { } node)
88-
onResponse.Invoke(node!);
92+
if (JsonNode.Parse(content) is { } node)
93+
onRequest.Invoke(node!);
8994
}
9095
catch (JsonException)
9196
{

src/AI/Grok/GrokChatClient.cs

Lines changed: 96 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
using System.ClientModel;
22
using System.ClientModel.Primitives;
33
using System.Collections.Concurrent;
4+
using System.Diagnostics;
45
using System.Text.Json;
6+
using System.Text.Json.Serialization;
57
using Microsoft.Extensions.AI;
68
using OpenAI;
79

@@ -10,7 +12,7 @@ namespace Devlooped.Extensions.AI;
1012
/// <summary>
1113
/// An <see cref="IChatClient"/> implementation for Grok.
1214
/// </summary>
13-
public class GrokChatClient : IChatClient
15+
public partial class GrokChatClient : IChatClient
1416
{
1517
readonly ConcurrentDictionary<string, IChatClient> clients = new();
1618
readonly string modelId;
@@ -52,21 +54,39 @@ IChatClient GetChatClient(string modelId) => clients.GetOrAdd(modelId, model
5254
var result = new GrokCompletionOptions();
5355
var grok = options as GrokChatOptions;
5456
var search = grok?.Search;
57+
var tool = options.Tools?.OfType<GrokSearchTool>().FirstOrDefault();
58+
GrokChatWebSearchOptions? searchOptions = default;
5559

56-
if (options.Tools != null)
60+
if (search is not null && tool is null)
5761
{
58-
if (options.Tools.OfType<GrokSearchTool>().FirstOrDefault() is GrokSearchTool grokSearch)
59-
search = grokSearch.Mode;
60-
else if (options.Tools.OfType<HostedWebSearchTool>().FirstOrDefault() is HostedWebSearchTool webSearch)
61-
search = GrokSearch.Auto;
62-
63-
// Grok doesn't support any other hosted search tools, so remove remaining ones
64-
// so they don't get copied over by the OpenAI client.
65-
//options.Tools = [.. options.Tools.Where(tool => tool is not HostedWebSearchTool)];
62+
searchOptions = new GrokChatWebSearchOptions
63+
{
64+
Mode = search.Value
65+
};
66+
}
67+
else if (tool is null && options.Tools?.OfType<HostedWebSearchTool>().FirstOrDefault() is not null)
68+
{
69+
searchOptions = new GrokChatWebSearchOptions
70+
{
71+
Mode = GrokSearch.Auto
72+
};
73+
}
74+
else if (tool is not null)
75+
{
76+
searchOptions = new GrokChatWebSearchOptions
77+
{
78+
Mode = tool.Mode,
79+
FromDate = tool.FromDate,
80+
ToDate = tool.ToDate,
81+
MaxSearchResults = tool.MaxSearchResults,
82+
Sources = tool.Sources
83+
};
6684
}
6785

68-
if (search != null)
69-
result.Search = search.Value;
86+
if (searchOptions is not null)
87+
{
88+
result.WebSearchOptions = searchOptions;
89+
}
7090

7191
if (grok?.ReasoningEffort != null)
7292
{
@@ -91,19 +111,76 @@ void IDisposable.Dispose() { }
91111
// Allows creating the base OpenAIClient with a pre-created pipeline.
92112
class PipelineClient(ClientPipeline pipeline, OpenAIClientOptions options) : OpenAIClient(pipeline, options) { }
93113

94-
class GrokCompletionOptions : OpenAI.Chat.ChatCompletionOptions
114+
class GrokChatWebSearchOptions : OpenAI.Chat.ChatWebSearchOptions
115+
{
116+
public GrokSearch Mode { get; set; } = GrokSearch.Auto;
117+
public DateOnly? FromDate { get; set; }
118+
public DateOnly? ToDate { get; set; }
119+
public int? MaxSearchResults { get; set; }
120+
public IList<GrokSource>? Sources { get; set; }
121+
}
122+
123+
[JsonSourceGenerationOptions(JsonSerializerDefaults.Web,
124+
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull | JsonIgnoreCondition.WhenWritingDefault,
125+
UnmappedMemberHandling = JsonUnmappedMemberHandling.Skip,
126+
PropertyNameCaseInsensitive = true,
127+
PropertyNamingPolicy = JsonKnownNamingPolicy.SnakeCaseLower
128+
#if DEBUG
129+
, WriteIndented = true
130+
#endif
131+
)]
132+
[JsonSerializable(typeof(GrokChatWebSearchOptions))]
133+
[JsonSerializable(typeof(GrokSearch))]
134+
[JsonSerializable(typeof(GrokSource))]
135+
[JsonSerializable(typeof(GrokRssSource))]
136+
[JsonSerializable(typeof(GrokWebSource))]
137+
[JsonSerializable(typeof(GrokNewsSource))]
138+
[JsonSerializable(typeof(GrokXSource))]
139+
partial class GrokJsonContext : JsonSerializerContext
95140
{
96-
public GrokSearch Search { get; set; } = GrokSearch.Auto;
141+
static readonly Lazy<JsonSerializerOptions> options = new(CreateDefaultOptions);
97142

143+
/// <summary>
144+
/// Provides a pre-configured instance of <see cref="JsonSerializerOptions"/> that aligns with the context's settings.
145+
/// </summary>
146+
public static JsonSerializerOptions DefaultOptions { get => options.Value; }
147+
148+
static JsonSerializerOptions CreateDefaultOptions()
149+
{
150+
JsonSerializerOptions options = new(Default.Options)
151+
{
152+
WriteIndented = Debugger.IsAttached,
153+
Converters =
154+
{
155+
new JsonStringEnumConverter(new LowercaseNamingPolicy()),
156+
},
157+
};
158+
159+
options.MakeReadOnly();
160+
return options;
161+
}
162+
163+
class LowercaseNamingPolicy : JsonNamingPolicy
164+
{
165+
public override string ConvertName(string name) => name.ToLowerInvariant();
166+
}
167+
}
168+
169+
class GrokCompletionOptions : OpenAI.Chat.ChatCompletionOptions
170+
{
98171
protected override void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWriterOptions? options)
99172
{
173+
var search = WebSearchOptions as GrokChatWebSearchOptions;
174+
// This avoids writing the default `web_search_options` property
175+
WebSearchOptions = null;
176+
100177
base.JsonModelWriteCore(writer, options);
101178

102-
// "search_parameters": { "mode": "auto" }
103-
writer.WritePropertyName("search_parameters");
104-
writer.WriteStartObject();
105-
writer.WriteString("mode", Search.ToString().ToLowerInvariant());
106-
writer.WriteEndObject();
179+
if (search != null)
180+
{
181+
writer.WritePropertyName("search_parameters");
182+
JsonSerializer.Serialize(writer, search, GrokJsonContext.DefaultOptions);
183+
}
107184
}
108185
}
109186
}

0 commit comments

Comments
 (0)