Skip to content

Commit cb9c443

Browse files
committed
Improve integration with MEAI typical usage pattern
In MEAI, the expectation is that you get chat clients for specific models from the same OpenAI client. There's no concept of a default model at that level. This aligns the implementation with that approach. In addition, we add an analyzer that ensures the user can still invoke `GetChatClient(model).AsIChatClient()` while getting the proper behavior which in our `GrokClient` case is that you should use the returned client from `GetChatClient(model)` directly without an additional wrapping done by the MEAI `AsIChatClient`, since we implement the `IChatClient` in the returned client directly. This allows us to cache and use multiple chat clients dynamically matching the model in each request's ChatOptions. This also fixes a bug where we were not properly setting the search mode to On instead of Auto when using Tools.
1 parent ac3b656 commit cb9c443

15 files changed

+326
-91
lines changed

Extensions.AI.sln

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Weaving", "src\Weaving\Weav
1313
EndProject
1414
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Samples", "src\Samples\Samples.csproj", "{4B78F0E3-E03B-4283-AB0B-B1D76CAEF1BC}"
1515
EndProject
16+
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AI.CodeAnalysis", "src\AI.CodeAnalysis\AI.CodeAnalysis.csproj", "{F6A9F74B-5C63-4C53-9745-F00BE40AF8C8}"
17+
EndProject
1618
Global
1719
GlobalSection(SolutionConfigurationPlatforms) = preSolution
1820
Debug|Any CPU = Debug|Any CPU
@@ -83,6 +85,18 @@ Global
8385
{4B78F0E3-E03B-4283-AB0B-B1D76CAEF1BC}.Release|x64.Build.0 = Release|Any CPU
8486
{4B78F0E3-E03B-4283-AB0B-B1D76CAEF1BC}.Release|x86.ActiveCfg = Release|Any CPU
8587
{4B78F0E3-E03B-4283-AB0B-B1D76CAEF1BC}.Release|x86.Build.0 = Release|Any CPU
88+
{F6A9F74B-5C63-4C53-9745-F00BE40AF8C8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
89+
{F6A9F74B-5C63-4C53-9745-F00BE40AF8C8}.Debug|Any CPU.Build.0 = Debug|Any CPU
90+
{F6A9F74B-5C63-4C53-9745-F00BE40AF8C8}.Debug|x64.ActiveCfg = Debug|Any CPU
91+
{F6A9F74B-5C63-4C53-9745-F00BE40AF8C8}.Debug|x64.Build.0 = Debug|Any CPU
92+
{F6A9F74B-5C63-4C53-9745-F00BE40AF8C8}.Debug|x86.ActiveCfg = Debug|Any CPU
93+
{F6A9F74B-5C63-4C53-9745-F00BE40AF8C8}.Debug|x86.Build.0 = Debug|Any CPU
94+
{F6A9F74B-5C63-4C53-9745-F00BE40AF8C8}.Release|Any CPU.ActiveCfg = Release|Any CPU
95+
{F6A9F74B-5C63-4C53-9745-F00BE40AF8C8}.Release|Any CPU.Build.0 = Release|Any CPU
96+
{F6A9F74B-5C63-4C53-9745-F00BE40AF8C8}.Release|x64.ActiveCfg = Release|Any CPU
97+
{F6A9F74B-5C63-4C53-9745-F00BE40AF8C8}.Release|x64.Build.0 = Release|Any CPU
98+
{F6A9F74B-5C63-4C53-9745-F00BE40AF8C8}.Release|x86.ActiveCfg = Release|Any CPU
99+
{F6A9F74B-5C63-4C53-9745-F00BE40AF8C8}.Release|x86.Build.0 = Release|Any CPU
86100
EndGlobalSection
87101
GlobalSection(SolutionProperties) = preSolution
88102
HideSolutionNode = FALSE
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<TargetFramework>netstandard2.0</TargetFramework>
5+
<IsRoslynComponent>true</IsRoslynComponent>
6+
<PackFolder>analyzers/dotnet/roslyn4.0/cs</PackFolder>
7+
</PropertyGroup>
8+
9+
<ItemGroup>
10+
<EmbeddedResource Include="..\AI\ChatClientExtensions.cs" Link="ChatClientExtensions.cs" />
11+
</ItemGroup>
12+
13+
<ItemGroup>
14+
<PackageReference Include="NuGetizer" Version="1.2.4" PrivateAssets="all" />
15+
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.0.1" Pack="false" />
16+
<PackageReference Include="PolySharp" Version="1.15.0" PrivateAssets="All" />
17+
<PackageReference Include="ThisAssembly.Resources" Version="2.0.14" PrivateAssets="all" />
18+
</ItemGroup>
19+
20+
</Project>
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using System.Text;
2+
using Microsoft.CodeAnalysis;
3+
using Microsoft.CodeAnalysis.Text;
4+
5+
namespace Devlooped.Extensions.AI;
6+
7+
/// <summary>
8+
/// This generator produces the <see cref="ChatClientExtensions"/> source code so that it
9+
/// exists in the user's target compilation and can successfully overload (and override)
10+
/// the <c>OpenAIClientExtensions.AsIChatClient</c> that would otherwise be used. We
11+
/// need this to ensure that the <see cref="ChatClient"/> can be used directly as an
12+
/// <c>IChatClient</c> instead of wrapping it in the M.E.AI.OpenAI adapter.
13+
/// </summary>
14+
[Generator(LanguageNames.CSharp)]
15+
public class ChatClientExtensionsGenerator : IIncrementalGenerator
16+
{
17+
public void Initialize(IncrementalGeneratorInitializationContext context)
18+
{
19+
context.RegisterSourceOutput(context.CompilationProvider,
20+
(spc, _) =>
21+
{
22+
spc.AddSource(
23+
$"{nameof(ThisAssembly.Resources.ChatClientExtensions)}.g.cs",
24+
SourceText.From(ThisAssembly.Resources.ChatClientExtensions.Text, Encoding.UTF8));
25+
});
26+
}
27+
}

src/AI.Tests/AI.Tests.csproj

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,13 @@
33
<PropertyGroup>
44
<TargetFrameworks>net8.0;net10.0</TargetFrameworks>
55
<NoWarn>OPENAI001;$(NoWarn)</NoWarn>
6+
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
67
</PropertyGroup>
78

9+
<ItemGroup>
10+
<Compile Include="..\AI\ChatClientExtensions.cs" Link="ChatClientExtensions.cs" />
11+
</ItemGroup>
12+
813
<ItemGroup>
914
<PackageReference Include="coverlet.collector" Version="6.0.4" />
1015
<PackageReference Include="Microsoft.Extensions.AI.OpenAI" Version="9.6.0-preview.1.25310.2" />
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
using System.ClientModel.Primitives;
2+
using System.Text.Json;
3+
using System.Text.Json.Nodes;
4+
5+
namespace Devlooped.Extensions.AI;
6+
7+
public static class PipelineTestOutput
8+
{
9+
/// <summary>
10+
/// Sets a <see cref="ClientPipelineOptions.Transport"/> that renders HTTP messages to the
11+
/// console using Spectre.Console rich JSON formatting, but only if the console is interactive.
12+
/// </summary>
13+
/// <typeparam name="TOptions">The options type to configure for HTTP logging.</typeparam>
14+
/// <param name="pipelineOptions">The options instance to configure.</param>
15+
/// <remarks>
16+
/// NOTE: this is the lowst-level logging after all chat pipeline processing has been done.
17+
/// <para>
18+
/// If the options already provide a transport, it will be wrapped with the console
19+
/// logging transport to minimize the impact on existing configurations.
20+
/// </para>
21+
/// </remarks>
22+
public static TOptions UseTestOutput<TOptions>(this TOptions pipelineOptions, ITestOutputHelper output)
23+
where TOptions : ClientPipelineOptions
24+
{
25+
pipelineOptions.Transport = new TestPipelineTransport(pipelineOptions.Transport ?? HttpClientPipelineTransport.Shared, output);
26+
27+
return pipelineOptions;
28+
}
29+
}
30+
31+
public class TestPipelineTransport(PipelineTransport inner, ITestOutputHelper? output = null) : PipelineTransport
32+
{
33+
static readonly JsonSerializerOptions options = new JsonSerializerOptions(JsonSerializerDefaults.General)
34+
{
35+
WriteIndented = true,
36+
};
37+
38+
public List<JsonNode> Requests { get; } = [];
39+
public List<JsonNode> Responses { get; } = [];
40+
41+
protected override async ValueTask ProcessCoreAsync(PipelineMessage message)
42+
{
43+
message.BufferResponse = true;
44+
await inner.ProcessAsync(message);
45+
46+
if (message.Request.Content is not null)
47+
{
48+
using var memory = new MemoryStream();
49+
message.Request.Content.WriteTo(memory);
50+
memory.Position = 0;
51+
using var reader = new StreamReader(memory);
52+
var content = await reader.ReadToEndAsync();
53+
var node = JsonNode.Parse(content);
54+
Requests.Add(node!);
55+
output?.WriteLine(node!.ToJsonString(options));
56+
}
57+
58+
if (message.Response != null)
59+
{
60+
var node = JsonNode.Parse(message.Response.Content.ToString());
61+
Responses.Add(node!);
62+
output?.WriteLine(node!.ToJsonString(options));
63+
}
64+
}
65+
66+
protected override PipelineMessage CreateMessageCore() => inner.CreateMessage();
67+
protected override void ProcessCore(PipelineMessage message) => inner.Process(message);
68+
}

src/AI.Tests/GrokTests.cs

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
namespace Devlooped.Extensions.AI;
2-
1+
using System.ClientModel.Primitives;
2+
using System.Text.Json.Nodes;
33
using Microsoft.Extensions.AI;
44
using static ConfigurationExtensions;
55

6-
public class GrokTests
6+
namespace Devlooped.Extensions.AI;
7+
8+
public class GrokTests(ITestOutputHelper output)
79
{
810
[SecretsFact("XAI_API_KEY")]
911
public async Task GrokInvokesTools()
@@ -23,12 +25,18 @@ public async Task GrokInvokesTools()
2325
Tools = [AIFunctionFactory.Create(() => DateTimeOffset.Now.ToString("O"), "get_date")]
2426
};
2527

26-
var response = await grok.GetResponseAsync(messages, options);
28+
var client = grok.GetChatClient("grok-3");
29+
var chat = Assert.IsType<IChatClient>(client, false);
30+
31+
var response = await chat.GetResponseAsync(messages, options);
2732
var getdate = response.Messages
2833
.SelectMany(x => x.Contents.OfType<FunctionCallContent>())
2934
.Any(x => x.Name == "get_date");
3035

3136
Assert.True(getdate);
37+
// NOTE: the chat client was requested as grok-3 but the chat options wanted a
38+
// different model and the grok client honors that choice.
39+
Assert.Equal("grok-3-mini", response.ModelId);
3240
}
3341

3442
[SecretsFact("XAI_API_KEY")]
@@ -40,7 +48,11 @@ public async Task GrokInvokesToolAndSearch()
4048
{ "user", "What's Tesla stock worth today?" },
4149
};
4250

43-
var grok = new GrokClient(Configuration["XAI_API_KEY"]!)
51+
var transport = new TestPipelineTransport(HttpClientPipelineTransport.Shared, output);
52+
53+
var grok = new GrokClient(Configuration["XAI_API_KEY"]!, new OpenAI.OpenAIClientOptions() { Transport = transport })
54+
.GetChatClient("grok-3")
55+
.AsIChatClient()
4456
.AsBuilder()
4557
.UseFunctionInvocation()
4658
.Build();
@@ -54,14 +66,30 @@ public async Task GrokInvokesToolAndSearch()
5466

5567
var response = await grok.GetResponseAsync(messages, options);
5668

69+
// assert that the request contains the following node
70+
// "search_parameters": {
71+
// "mode": "on"
72+
//}
73+
Assert.All(transport.Requests, x =>
74+
{
75+
var search = Assert.IsType<JsonObject>(x["search_parameters"]);
76+
Assert.Equal("on", search["mode"]?.GetValue<string>());
77+
});
78+
5779
// The get_date result shows up as a tool role
5880
Assert.Contains(response.Messages, x => x.Role == ChatRole.Tool);
5981

60-
var text = response.Text;
82+
// Citations include nasdaq.com at least as a web search source
83+
var node = transport.Responses.LastOrDefault();
84+
Assert.NotNull(node);
85+
var citations = Assert.IsType<JsonArray>(node["citations"], false);
86+
var yahoo = citations.Where(x => x != null).Any(x => x!.ToString().Contains("https://finance.yahoo.com/quote/TSLA/", StringComparison.Ordinal));
6187

62-
Assert.Contains("TSLA", text);
63-
Assert.Contains("$", text);
64-
Assert.Contains("Nasdaq", text, StringComparison.OrdinalIgnoreCase);
88+
Assert.True(yahoo, "Expected at least one citation to nasdaq.com");
89+
90+
// NOTE: the chat client was requested as grok-3 but the chat options wanted a
91+
// different model and the grok client honors that choice.
92+
Assert.Equal("grok-3-mini", response.ModelId);
6593
}
6694

6795
[SecretsFact("XAI_API_KEY")]
@@ -73,20 +101,43 @@ public async Task GrokInvokesHostedSearchTool()
73101
{ "user", "What's Tesla stock worth today? Search X and the news for latest info." },
74102
};
75103

76-
var grok = new GrokClient(Configuration["XAI_API_KEY"]!);
104+
var transport = new TestPipelineTransport(HttpClientPipelineTransport.Shared, output);
105+
106+
var grok = new GrokClient(Configuration["XAI_API_KEY"]!, new OpenAI.OpenAIClientOptions() { Transport = transport });
107+
var client = grok.GetChatClient("grok-3");
108+
var chat = Assert.IsType<IChatClient>(client, false);
77109

78110
var options = new ChatOptions
79111
{
80-
ModelId = "grok-3",
81112
Tools = [new HostedWebSearchTool()]
82113
};
83114

84-
var response = await grok.GetResponseAsync(messages, options);
115+
var response = await chat.GetResponseAsync(messages, options);
85116
var text = response.Text;
86117

87118
Assert.Contains("TSLA", text);
88-
Assert.Contains("$", text);
89-
Assert.Contains("Nasdaq", text, StringComparison.OrdinalIgnoreCase);
119+
120+
// assert that the request contains the following node
121+
// "search_parameters": {
122+
// "mode": "auto"
123+
//}
124+
Assert.All(transport.Requests, x =>
125+
{
126+
var search = Assert.IsType<JsonObject>(x["search_parameters"]);
127+
Assert.Equal("auto", search["mode"]?.GetValue<string>());
128+
});
129+
130+
// Citations include nasdaq.com at least as a web search source
131+
Assert.Single(transport.Responses);
132+
var node = transport.Responses[0];
133+
Assert.NotNull(node);
134+
var citations = Assert.IsType<JsonArray>(node["citations"], false);
135+
var yahoo = citations.Where(x => x != null).Any(x => x!.ToString().Contains("https://finance.yahoo.com/quote/TSLA/", StringComparison.Ordinal));
136+
137+
Assert.True(yahoo, "Expected at least one citation to nasdaq.com");
138+
139+
// Uses the default model set by the client when we asked for it
140+
Assert.Equal("grok-3", response.ModelId);
90141
}
91142

92143
[SecretsFact("XAI_API_KEY")]
@@ -99,6 +150,8 @@ public async Task GrokThinksHard()
99150
};
100151

101152
var grok = new GrokClient(Configuration["XAI_API_KEY"]!)
153+
.GetChatClient("grok-3")
154+
.AsIChatClient()
102155
.AsBuilder()
103156
.UseFunctionInvocation()
104157
.Build();
@@ -115,5 +168,8 @@ public async Task GrokThinksHard()
115168
var text = response.Text;
116169

117170
Assert.Contains("48 years", text);
171+
// NOTE: the chat client was requested as grok-3 but the chat options wanted a
172+
// different model and the grok client honors that choice.
173+
Assert.StartsWith("grok-3-mini", response.ModelId);
118174
}
119175
}

src/AI/AI.csproj

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,21 @@
88
</PropertyGroup>
99

1010
<ItemGroup>
11-
<PackageReference Include="Microsoft.Extensions.Configuration.Abstractions" Version="9.0.6" />
1211
<PackageReference Include="NuGetizer" Version="1.2.4" PrivateAssets="all" />
1312
<PackageReference Include="Microsoft.Extensions.AI" Version="9.6.0" />
1413
<PackageReference Include="Microsoft.Extensions.AI.OpenAI" Version="9.6.0-preview.1.25310.2" />
14+
<PackageReference Include="Microsoft.Extensions.Configuration.Abstractions" Version="9.0.6" />
1515
<PackageReference Include="OpenAI" Version="2.2.0-beta.4" />
1616
<PackageReference Include="Spectre.Console" Version="0.50.0" />
1717
<PackageReference Include="Spectre.Console.Json" Version="0.50.0" />
1818
</ItemGroup>
1919

2020
<ItemGroup>
21+
<ProjectReference Include="..\AI.CodeAnalysis\AI.CodeAnalysis.csproj" ReferenceOutputAssembly="false" />
22+
</ItemGroup>
23+
24+
<ItemGroup>
25+
<None Update="Devlooped.Extensions.AI.targets" PackFolder="build" />
2126
<None Update="Devlooped.Extensions.AI.props" PackFolder="build" />
2227
</ItemGroup>
2328

src/AI/ChatClientExtensions.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using Microsoft.Extensions.AI;
2+
using OpenAI.Chat;
3+
4+
/// <summary>
5+
/// Smarter casting to <see cref="IChatClient"/> when the target <see cref="ChatClient"/>
6+
/// already implements the interface.
7+
/// </summary>
8+
static class ChatClientExtensions
9+
{
10+
/// <summary>Gets an <see cref="IChatClient"/> for use with this <see cref="ChatClient"/>.</summary>
11+
public static IChatClient AsIChatClient(this ChatClient client) =>
12+
client as IChatClient ?? OpenAIClientExtensions.AsIChatClient(client);
13+
}

src/AI/Console/JsonConsoleOptions.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,9 @@ internal Panel CreatePanel(object value)
115115
return panel;
116116
}
117117

118+
#pragma warning disable CS9113 // Parameter is unread. BOGUS
118119
sealed class WrappedJsonText(string json, int maxWidth) : Renderable
120+
#pragma warning restore CS9113 // Parameter is unread. BOGUS
119121
{
120122
readonly JsonText jsonText = new(json);
121123

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
11
<Project>
2-
<PropertyGroup>
3-
<Nullable>enable</Nullable>
4-
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
5-
</PropertyGroup>
62

73
<ItemGroup>
4+
<Compile Update="@(Compile -> WithMetadataValue('NuGetPackageId', 'Devlooped.Extensions.AI'))" Visible="false" />
85
</ItemGroup>
96

107
<ItemGroup>
11-
<Using Include="Microsoft.Extensions.AI"/>
12-
<Using Include="Devlooped.Extensions.AI"/>
8+
<!--<Using Include="Devlooped.Extensions.AI"/>-->
139
</ItemGroup>
14-
10+
1511
</Project>

0 commit comments

Comments
 (0)