Skip to content

Commit f017650

Browse files
committed
Create a smarter OpenAI chat client that honors model ID
When using the `IChatClient` API, the `ChatOptions.ModelId` is ignored by the built-in OpenAI implementation obtained by invoking `OpenAIClient.GetChatClient("model").AsIChatClient`. This is because `GetChatClient` stores the model passed in and doesn't respect the options-specified one after initial creation. We now align our own implementation of `IChatClient` for OpenAI so it behaves like Grok's, allowing more flexible usage while following the natural expectation set by having a writable `ChatOptions.ModelId` in the first place :).
1 parent 229c42c commit f017650

File tree

5 files changed

+215
-44
lines changed

5 files changed

+215
-44
lines changed

src/AI.Tests/AI.Tests.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
<PropertyGroup>
44
<TargetFramework>net10.0</TargetFramework>
55
<NoWarn>OPENAI001;$(NoWarn)</NoWarn>
6+
<LangVersion>Preview</LangVersion>
67
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
78
</PropertyGroup>
89

src/AI.Tests/Extensions/PipelineTestOutput.cs

Lines changed: 61 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,57 +12,86 @@ public static class PipelineTestOutput
1212
/// </summary>
1313
/// <typeparam name="TOptions">The options type to configure for HTTP logging.</typeparam>
1414
/// <param name="pipelineOptions">The options instance to configure.</param>
15+
/// <param name="output">The test output helper to write to.</param>
16+
/// <param name="onRequest">A callback to process the <see cref="JsonNode"/> that was sent.</param>
17+
/// <param name="onResponse">A callback to process the <see cref="JsonNode"/> that was received.</param>
1518
/// <remarks>
1619
/// NOTE: this is the lowest-level logging after all chat pipeline processing has been done.
1720
/// <para>
1821
/// If the options already provide a transport, it will be wrapped with the console
1922
/// logging transport to minimize the impact on existing configurations.
2023
/// </para>
2124
/// </remarks>
22-
public static TOptions UseTestOutput<TOptions>(this TOptions pipelineOptions, ITestOutputHelper output)
25+
public static TOptions WriteTo<TOptions>(this TOptions pipelineOptions, ITestOutputHelper? output = default, Action<JsonNode>? onRequest = default, Action<JsonNode>? onResponse = default)
2326
where TOptions : ClientPipelineOptions
2427
{
25-
pipelineOptions.Transport = new TestPipelineTransport(pipelineOptions.Transport ?? HttpClientPipelineTransport.Shared, output);
26-
28+
pipelineOptions.AddPolicy(new TestOutputPolicy(output ?? NullTestOutputHelper.Default, onRequest, onResponse), PipelinePosition.BeforeTransport);
2729
return pipelineOptions;
2830
}
29-
}
3031

31-
public class TestPipelineTransport(PipelineTransport inner, ITestOutputHelper? output = null) : PipelineTransport
32-
{
33-
static readonly JsonSerializerOptions options = new JsonSerializerOptions(JsonSerializerDefaults.General)
32+
class NullTestOutputHelper : ITestOutputHelper
3433
{
35-
WriteIndented = true,
36-
};
37-
38-
public List<JsonNode> Requests { get; } = [];
39-
public List<JsonNode> Responses { get; } = [];
34+
public static ITestOutputHelper Default { get; } = new NullTestOutputHelper();
35+
NullTestOutputHelper() { }
36+
public void WriteLine(string message) { }
37+
public void WriteLine(string format, params object[] args) { }
38+
}
4039

41-
protected override async ValueTask ProcessCoreAsync(PipelineMessage message)
40+
class TestOutputPolicy(ITestOutputHelper output, Action<JsonNode>? onRequest = default, Action<JsonNode>? onResponse = default) : PipelinePolicy
4241
{
43-
message.BufferResponse = true;
44-
await inner.ProcessAsync(message);
42+
static readonly JsonSerializerOptions options = new JsonSerializerOptions(JsonSerializerDefaults.General)
43+
{
44+
WriteIndented = true,
45+
};
4546

46-
if (message.Request.Content is not null)
47+
public override void Process(PipelineMessage message, IReadOnlyList<PipelinePolicy> pipeline, int currentIndex)
4748
{
48-
using var memory = new MemoryStream();
49-
message.Request.Content.WriteTo(memory);
50-
memory.Position = 0;
51-
using var reader = new StreamReader(memory);
52-
var content = await reader.ReadToEndAsync();
53-
var node = JsonNode.Parse(content);
54-
Requests.Add(node!);
55-
output?.WriteLine(node!.ToJsonString(options));
49+
message.BufferResponse = true;
50+
ProcessNext(message, pipeline, currentIndex);
51+
52+
if (message.Request.Content is not null)
53+
{
54+
using var memory = new MemoryStream();
55+
message.Request.Content.WriteTo(memory);
56+
memory.Position = 0;
57+
using var reader = new StreamReader(memory);
58+
var content = reader.ReadToEnd();
59+
var node = JsonNode.Parse(content);
60+
onRequest?.Invoke(node!);
61+
output?.WriteLine(node!.ToJsonString(options));
62+
}
63+
64+
if (message.Response != null)
65+
{
66+
var node = JsonNode.Parse(message.Response.Content.ToString());
67+
onResponse?.Invoke(node!);
68+
output?.WriteLine(node!.ToJsonString(options));
69+
}
5670
}
5771

58-
if (message.Response != null)
72+
public override async ValueTask ProcessAsync(PipelineMessage message, IReadOnlyList<PipelinePolicy> pipeline, int currentIndex)
5973
{
60-
var node = JsonNode.Parse(message.Response.Content.ToString());
61-
Responses.Add(node!);
62-
output?.WriteLine(node!.ToJsonString(options));
74+
message.BufferResponse = true;
75+
await ProcessNextAsync(message, pipeline, currentIndex);
76+
77+
if (message.Request.Content is not null)
78+
{
79+
using var memory = new MemoryStream();
80+
message.Request.Content.WriteTo(memory);
81+
memory.Position = 0;
82+
using var reader = new StreamReader(memory);
83+
var content = await reader.ReadToEndAsync();
84+
var node = JsonNode.Parse(content);
85+
onRequest?.Invoke(node!);
86+
output?.WriteLine(node!.ToJsonString(options));
87+
}
88+
89+
if (message.Response != null)
90+
{
91+
var node = JsonNode.Parse(message.Response.Content.ToString());
92+
onResponse?.Invoke(node!);
93+
output?.WriteLine(node!.ToJsonString(options));
94+
}
6395
}
6496
}
65-
66-
protected override PipelineMessage CreateMessageCore() => inner.CreateMessage();
67-
protected override void ProcessCore(PipelineMessage message) => inner.Process(message);
68-
}
97+
}

src/AI.Tests/GrokTests.cs

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using System.ClientModel.Primitives;
2-
using System.Text.Json.Nodes;
1+
using System.Text.Json.Nodes;
32
using Microsoft.Extensions.AI;
43
using static ConfigurationExtensions;
54

@@ -49,9 +48,11 @@ public async Task GrokInvokesToolAndSearch()
4948
{ "user", "What's Tesla stock worth today?" },
5049
};
5150

52-
var transport = new TestPipelineTransport(HttpClientPipelineTransport.Shared, output);
51+
var requests = new List<JsonNode>();
52+
var responses = new List<JsonNode>();
5353

54-
var grok = new GrokChatClient(Configuration["XAI_API_KEY"]!, "grok-3", new OpenAI.OpenAIClientOptions() { Transport = transport })
54+
var grok = new GrokChatClient(Configuration["XAI_API_KEY"]!, "grok-3",
55+
new OpenAI.OpenAIClientOptions().WriteTo(output, requests.Add, responses.Add))
5556
.AsBuilder()
5657
.UseFunctionInvocation()
5758
.Build();
@@ -69,7 +70,7 @@ public async Task GrokInvokesToolAndSearch()
6970
// "search_parameters": {
7071
// "mode": "on"
7172
//}
72-
Assert.All(transport.Requests, x =>
73+
Assert.All(requests, x =>
7374
{
7475
var search = Assert.IsType<JsonObject>(x["search_parameters"]);
7576
Assert.Equal("on", search["mode"]?.GetValue<string>());
@@ -79,7 +80,7 @@ public async Task GrokInvokesToolAndSearch()
7980
Assert.Contains(response.Messages, x => x.Role == ChatRole.Tool);
8081

8182
// Citations include nasdaq.com at least as a web search source
82-
var node = transport.Responses.LastOrDefault();
83+
var node = responses.LastOrDefault();
8384
Assert.NotNull(node);
8485
var citations = Assert.IsType<JsonArray>(node["citations"], false);
8586
var yahoo = citations.Where(x => x != null).Any(x => x!.ToString().Contains("https://finance.yahoo.com/quote/TSLA/", StringComparison.Ordinal));
@@ -100,16 +101,18 @@ public async Task GrokInvokesHostedSearchTool()
100101
{ "user", "What's Tesla stock worth today? Search X and the news for latest info." },
101102
};
102103

103-
var transport = new TestPipelineTransport(HttpClientPipelineTransport.Shared, output);
104+
var requests = new List<JsonNode>();
105+
var responses = new List<JsonNode>();
104106

105-
var chat = new GrokChatClient(Configuration["XAI_API_KEY"]!, "grok-3", new OpenAI.OpenAIClientOptions() { Transport = transport });
107+
var grok = new GrokChatClient(Configuration["XAI_API_KEY"]!, "grok-3",
108+
new OpenAI.OpenAIClientOptions().WriteTo(output, requests.Add, responses.Add));
106109

107110
var options = new ChatOptions
108111
{
109112
Tools = [new HostedWebSearchTool()]
110113
};
111114

112-
var response = await chat.GetResponseAsync(messages, options);
115+
var response = await grok.GetResponseAsync(messages, options);
113116
var text = response.Text;
114117

115118
Assert.Contains("TSLA", text);
@@ -118,15 +121,15 @@ public async Task GrokInvokesHostedSearchTool()
118121
// "search_parameters": {
119122
// "mode": "auto"
120123
//}
121-
Assert.All(transport.Requests, x =>
124+
Assert.All(requests, x =>
122125
{
123126
var search = Assert.IsType<JsonObject>(x["search_parameters"]);
124127
Assert.Equal("auto", search["mode"]?.GetValue<string>());
125128
});
126129

127130
// Citations include nasdaq.com at least as a web search source
128-
Assert.Single(transport.Responses);
129-
var node = transport.Responses[0];
131+
Assert.Single(responses);
132+
var node = responses[0];
130133
Assert.NotNull(node);
131134
var citations = Assert.IsType<JsonArray>(node["citations"], false);
132135
var yahoo = citations.Where(x => x != null).Any(x => x!.ToString().Contains("https://finance.yahoo.com/quote/TSLA/", StringComparison.Ordinal));

src/AI.Tests/OpenAITests.cs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
using System.Text.Json.Nodes;
using Microsoft.Extensions.AI;
using static ConfigurationExtensions;

namespace Devlooped.Extensions.AI;

/// <summary>
/// Tests for <see cref="OpenAIChatClient"/>, verifying that per-call
/// <see cref="ChatOptions.ModelId"/> overrides the client's default model.
/// </summary>
public class OpenAITests(ITestOutputHelper output)
{
    [SecretsFact("OPENAI_API_KEY")]
    public async Task OpenAISwitchesModel()
    {
        var messages = new Chat()
        {
            { "user", "What products does Tesla make?" },
        };

        var chat = new OpenAIChatClient(Configuration["OPENAI_API_KEY"]!, "gpt-4.1-nano", new OpenAI.OpenAIClientOptions().WriteTo(output));

        var options = new ChatOptions
        {
            ModelId = "gpt-4.1-mini",
        };

        var response = await chat.GetResponseAsync(messages, options);

        // NOTE: the chat client was created for gpt-4.1-nano but the chat options
        // requested a different model, and the OpenAI client honors that choice.
        Assert.StartsWith("gpt-4.1-mini", response.ModelId);
    }

    [SecretsFact("OPENAI_API_KEY")]
    public async Task OpenAIThinks()
    {
        var messages = new Chat()
        {
            { "system", "You are an intelligent AI assistant that's an expert on financial matters." },
            { "user", "If you have a debt of 100k and accumulate a compounding 5% debt on top of it every year, how long before you are a negative millonaire? (round up to full integer value)" },
        };

        // Capture raw request payloads so we can assert on the wire-level JSON.
        var requests = new List<JsonNode>();

        var chat = new OpenAIChatClient(Configuration["OPENAI_API_KEY"]!, "o3-mini", new OpenAI.OpenAIClientOptions()
            .WriteTo(output, requests.Add));

        var options = new ChatOptions
        {
            ModelId = "o4-mini",
            ReasoningEffort = ReasoningEffort.Medium
        };

        var response = await chat.GetResponseAsync(messages, options);

        var text = response.Text;

        Assert.Contains("48 years", text);
        // NOTE: the chat client was created for o3-mini but the chat options
        // requested a different model, and the OpenAI client honors that choice.
        Assert.StartsWith("o4-mini", response.ModelId);

        // Reasoning effort should have been set to medium on every request payload.
        Assert.All(requests, x =>
        {
            var reasoning = Assert.IsType<JsonObject>(x["reasoning"]);
            Assert.Equal("medium", reasoning["effort"]?.GetValue<string>());
        });
    }
}

src/AI/OpenAI/OpenAIChatClient.cs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Concurrent;
using Microsoft.Extensions.AI;
using OpenAI;
using OpenAI.Responses;

namespace Devlooped.Extensions.AI;

/// <summary>
/// An <see cref="IChatClient"/> implementation for OpenAI that honors
/// <see cref="ChatOptions.ModelId"/> on a per-call basis, unlike the client
/// obtained via <c>OpenAIClient.GetChatClient(model).AsIChatClient()</c>,
/// which is pinned to the model it was created with.
/// </summary>
public class OpenAIChatClient : IChatClient
{
    // One inner IChatClient per model, created lazily and cached for reuse.
    readonly ConcurrentDictionary<string, IChatClient> clients = new();
    readonly string modelId;
    readonly ClientPipeline pipeline;
    readonly OpenAIClientOptions? options;

    /// <summary>
    /// Initializes the client with the specified API key, default model ID, and optional OpenAI client options.
    /// </summary>
    /// <param name="apiKey">The OpenAI API key used to authenticate requests.</param>
    /// <param name="modelId">The default model, used when <see cref="ChatOptions.ModelId"/> is not set.</param>
    /// <param name="options">Optional client options (e.g. custom transport or pipeline policies).</param>
    public OpenAIChatClient(string apiKey, string modelId, OpenAIClientOptions? options = default)
    {
        ArgumentNullException.ThrowIfNull(apiKey);
        ArgumentNullException.ThrowIfNull(modelId);

        this.modelId = modelId;
        this.options = options;

        // NOTE: by caching the pipeline, we speed up creation of new chat clients per model,
        // since the pipeline will be the same for all of them.
        pipeline = new OpenAIClient(new ApiKeyCredential(apiKey), options).Pipeline;
    }

    /// <inheritdoc/>
    public Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellation = default)
        => GetChatClient(options?.ModelId ?? modelId).GetResponseAsync(messages, SetOptions(options), cancellation);

    /// <inheritdoc/>
    public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellation = default)
        => GetChatClient(options?.ModelId ?? modelId).GetStreamingResponseAsync(messages, SetOptions(options), cancellation);

    // Gets (or creates and caches) the inner client for the given model, reusing the shared pipeline.
    // FIX: use the lambda parameter rather than closing over the outer argument with an unused parameter.
    IChatClient GetChatClient(string modelId) => clients.GetOrAdd(modelId, model
        => new PipelineClient(pipeline, options).GetOpenAIResponseClient(model).AsIChatClient());

    // Maps the ReasoningEffort extension onto OpenAI's response creation options.
    // NOTE(review): this mutates the incoming options and replaces any pre-set
    // RawRepresentationFactory — callers setting their own factory lose it; confirm intended.
    static ChatOptions? SetOptions(ChatOptions? options)
    {
        if (options is null)
            return null;

        if (options.ReasoningEffort is ReasoningEffort effort)
        {
            options.RawRepresentationFactory = _ => new ResponseCreationOptions
            {
                ReasoningOptions = new ResponseReasoningOptions(effort switch
                {
                    ReasoningEffort.High => ResponseReasoningEffortLevel.High,
                    ReasoningEffort.Medium => ResponseReasoningEffortLevel.Medium,
                    _ => ResponseReasoningEffortLevel.Low
                })
            };
        }

        return options;
    }

    // Nothing to dispose: inner clients share the cached pipeline.
    void IDisposable.Dispose() { }

    /// <inheritdoc/>
    public object? GetService(Type serviceType, object? serviceKey = null)
        // FIX: answer non-keyed queries for this client's own type so AsBuilder()
        // pipelines and callers using GetService<IChatClient>() can reach it,
        // instead of unconditionally returning null.
        => serviceKey is null && serviceType?.IsInstanceOfType(this) == true ? this : null;

    // Allows creating the base OpenAIClient with a pre-created pipeline.
    class PipelineClient(ClientPipeline pipeline, OpenAIClientOptions? options) : OpenAIClient(pipeline, options) { }
}

0 commit comments

Comments
 (0)