Skip to content

Commit 76f7da9

Browse files
.Net: Add ONNX ChatClient Extensions + UT (#12477)
# Add IChatClient Extensions for ONNX Connector ## Summary This PR implements missing AddOnnxChatClient extension methods for the ONNX Connector, providing support for the new IChatClient interface alongside the existing IChatCompletionService extensions. ## Changes Made - New Extension Methods - IServiceCollection Extensions: Added AddOnnxRuntimeGenAIChatClient method in a dedicated ServiceCollectionExtensions.DependencyInjection.cs file following the same pattern as other connectors like OpenAI - IKernelBuilder Extensions: Added AddOnnxRuntimeGenAIChatClient method in OnnxKernelBuilderExtensions.ChatClient.cs for seamless kernel configuration --------- Co-authored-by: Mark Wallace <[email protected]>
1 parent ef392d2 commit 76f7da9

File tree

4 files changed

+308
-0
lines changed

4 files changed

+308
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
3+
using System.Linq;
4+
using Microsoft.Extensions.AI;
5+
using Microsoft.Extensions.DependencyInjection;
6+
using Microsoft.SemanticKernel;
7+
using Xunit;
8+
9+
namespace SemanticKernel.Connectors.Onnx.UnitTests;
10+
11+
/// <summary>
12+
/// Unit tests for <see cref="OnnxChatClientKernelBuilderExtensions"/> and Onnx IChatClient service collection extensions.
13+
/// </summary>
14+
public class OnnxChatClientExtensionsTests
15+
{
16+
[Fact]
17+
public void AddOnnxRuntimeGenAIChatClientToServiceCollection()
18+
{
19+
// Arrange
20+
var collection = new ServiceCollection();
21+
22+
// Act
23+
collection.AddOnnxRuntimeGenAIChatClient("modelId");
24+
25+
// Assert
26+
var serviceDescriptor = collection.FirstOrDefault(x => x.ServiceType == typeof(IChatClient));
27+
Assert.NotNull(serviceDescriptor);
28+
Assert.Equal(ServiceLifetime.Singleton, serviceDescriptor.Lifetime);
29+
}
30+
31+
[Fact]
32+
public void AddOnnxRuntimeGenAIChatClientToKernelBuilder()
33+
{
34+
// Arrange
35+
var collection = new ServiceCollection();
36+
var kernelBuilder = collection.AddKernel();
37+
38+
// Act
39+
kernelBuilder.AddOnnxRuntimeGenAIChatClient("modelPath");
40+
41+
// Assert
42+
var serviceDescriptor = collection.FirstOrDefault(x => x.ServiceType == typeof(IChatClient));
43+
Assert.NotNull(serviceDescriptor);
44+
Assert.Equal(ServiceLifetime.Singleton, serviceDescriptor.Lifetime);
45+
}
46+
47+
[Fact]
48+
public void AddOnnxRuntimeGenAIChatClientWithServiceId()
49+
{
50+
// Arrange
51+
var collection = new ServiceCollection();
52+
53+
// Act
54+
collection.AddOnnxRuntimeGenAIChatClient("modelPath", serviceId: "test-service");
55+
56+
// Assert
57+
var serviceDescriptor = collection.FirstOrDefault(x => x.ServiceType == typeof(IChatClient) && x.ServiceKey?.ToString() == "test-service");
58+
Assert.NotNull(serviceDescriptor);
59+
Assert.Equal(ServiceLifetime.Singleton, serviceDescriptor.Lifetime);
60+
}
61+
62+
[Fact]
63+
public void AddOnnxRuntimeGenAIChatClientToKernelBuilderWithServiceId()
64+
{
65+
// Arrange
66+
var collection = new ServiceCollection();
67+
var kernelBuilder = collection.AddKernel();
68+
69+
// Act
70+
kernelBuilder.AddOnnxRuntimeGenAIChatClient("modelPath", serviceId: "test-service");
71+
72+
// Assert
73+
var serviceDescriptor = collection.FirstOrDefault(x => x.ServiceType == typeof(IChatClient) && x.ServiceKey?.ToString() == "test-service");
74+
Assert.NotNull(serviceDescriptor);
75+
Assert.Equal(ServiceLifetime.Singleton, serviceDescriptor.Lifetime);
76+
}
77+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
3+
using Microsoft.Extensions.AI;
4+
using Microsoft.Extensions.DependencyInjection;
5+
using Microsoft.ML.OnnxRuntimeGenAI;
6+
7+
namespace Microsoft.SemanticKernel;
8+
9+
/// <summary>Extension methods for <see cref="IKernelBuilder"/>.</summary>
10+
public static class OnnxChatClientKernelBuilderExtensions
11+
{
12+
#region Chat Client
13+
14+
/// <summary>
15+
/// Adds an OnnxRuntimeGenAI <see cref="IChatClient"/> to the <see cref="IKernelBuilder.Services"/>.
16+
/// </summary>
17+
/// <param name="builder">The <see cref="IKernelBuilder"/> instance to augment.</param>
18+
/// <param name="modelPath">The generative AI ONNX model path.</param>
19+
/// <param name="chatClientOptions">The optional options for the chat client.</param>
20+
/// <param name="serviceId">A local identifier for the given AI service</param>
21+
/// <returns>The same instance as <paramref name="builder"/>.</returns>
22+
public static IKernelBuilder AddOnnxRuntimeGenAIChatClient(
23+
this IKernelBuilder builder,
24+
string modelPath,
25+
OnnxRuntimeGenAIChatClientOptions? chatClientOptions = null,
26+
string? serviceId = null)
27+
{
28+
Verify.NotNull(builder);
29+
30+
builder.Services.AddOnnxRuntimeGenAIChatClient(
31+
modelPath,
32+
chatClientOptions,
33+
serviceId);
34+
35+
return builder;
36+
}
37+
38+
#endregion
39+
}

dotnet/src/Connectors/Connectors.Onnx/OnnxServiceCollectionExtensions.DependencyInjection.cs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3+
using System;
34
using System.IO;
5+
using System.Text;
46
using Microsoft.Extensions.AI;
7+
using Microsoft.Extensions.Logging;
8+
using Microsoft.ML.OnnxRuntimeGenAI;
59
using Microsoft.SemanticKernel;
610
using Microsoft.SemanticKernel.Connectors.Onnx;
711
using Microsoft.SemanticKernel.Embeddings;
@@ -57,4 +61,56 @@ public static IServiceCollection AddBertOnnxEmbeddingGenerator(
5761
serviceId,
5862
BertOnnxTextEmbeddingGenerationService.Create(onnxModelStream, vocabStream, options).AsEmbeddingGenerator());
5963
}
64+
65+
/// <summary>
66+
/// Add OnnxRuntimeGenAI Chat Client to the service collection.
67+
/// </summary>
68+
/// <param name="services">The service collection.</param>
69+
/// <param name="modelPath">The generative AI ONNX model path.</param>
70+
/// <param name="chatClientOptions">The options for the chat client.</param>
71+
/// <param name="serviceId">The optional service ID.</param>
72+
/// <returns>The updated service collection.</returns>
73+
public static IServiceCollection AddOnnxRuntimeGenAIChatClient(
74+
this IServiceCollection services,
75+
string modelPath,
76+
OnnxRuntimeGenAIChatClientOptions? chatClientOptions = null,
77+
string? serviceId = null)
78+
{
79+
Verify.NotNull(services);
80+
Verify.NotNullOrWhiteSpace(modelPath);
81+
82+
IChatClient Factory(IServiceProvider serviceProvider, object? _)
83+
{
84+
var loggerFactory = serviceProvider.GetService<ILoggerFactory>();
85+
86+
var chatClient = new OnnxRuntimeGenAIChatClient(modelPath, chatClientOptions ?? new OnnxRuntimeGenAIChatClientOptions()
87+
{
88+
PromptFormatter = static (messages, _) =>
89+
{
90+
StringBuilder promptBuilder = new();
91+
foreach (var message in messages)
92+
{
93+
promptBuilder.Append($"<|{message.Role}|>\n{message.Text}");
94+
}
95+
promptBuilder.Append("<|end|>\n<|assistant|>");
96+
97+
return promptBuilder.ToString();
98+
}
99+
});
100+
101+
var builder = chatClient.AsBuilder()
102+
.UseKernelFunctionInvocation(loggerFactory);
103+
104+
if (loggerFactory is not null)
105+
{
106+
builder.UseLogging(loggerFactory);
107+
}
108+
109+
return builder.Build();
110+
}
111+
112+
services.AddKeyedSingleton<IChatClient>(serviceId, (Func<IServiceProvider, object?, IChatClient>)Factory);
113+
114+
return services;
115+
}
60116
}
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
3+
#pragma warning disable SKEXP0010
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Net.Http;
8+
using System.Text;
9+
using System.Threading.Tasks;
10+
using Microsoft.Extensions.AI;
11+
using Microsoft.Extensions.Configuration;
12+
using Microsoft.Extensions.DependencyInjection;
13+
using Microsoft.SemanticKernel;
14+
using SemanticKernel.IntegrationTests.TestSettings;
15+
using Xunit;
16+
17+
namespace SemanticKernel.IntegrationTests.Connectors.Onnx;
18+
19+
public class OnnxRuntimeGenAIChatClientTests : BaseIntegrationTest
20+
{
21+
[Fact(Skip = "For manual verification only")]
22+
public async Task ItCanUseKernelInvokeAsyncWithChatClientAsync()
23+
{
24+
// Arrange
25+
var kernel = this.CreateAndInitializeKernelWithChatClient();
26+
27+
var func = kernel.CreateFunctionFromPrompt("List the two planets after '{{$input}}', excluding moons, using bullet points.");
28+
29+
// Act
30+
var result = await func.InvokeAsync(kernel, new() { ["input"] = "Jupiter" });
31+
32+
// Assert
33+
Assert.NotNull(result);
34+
Assert.Contains("Saturn", result.GetValue<string>(), StringComparison.InvariantCultureIgnoreCase);
35+
Assert.Contains("Uranus", result.GetValue<string>(), StringComparison.InvariantCultureIgnoreCase);
36+
}
37+
38+
[Fact(Skip = "For manual verification only")]
39+
public async Task ItCanUseKernelInvokeStreamingAsyncWithChatClientAsync()
40+
{
41+
// Arrange
42+
var kernel = this.CreateAndInitializeKernelWithChatClient();
43+
44+
var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin");
45+
46+
StringBuilder fullResult = new();
47+
48+
var prompt = "Where is the most famous fish market in Seattle, Washington, USA?";
49+
50+
// Act
51+
await foreach (var content in kernel.InvokeStreamingAsync<StreamingKernelContent>(plugins["ChatPlugin"]["Chat"], new() { ["input"] = prompt }))
52+
{
53+
fullResult.Append(content);
54+
}
55+
56+
// Assert
57+
Assert.Contains("Pike Place", fullResult.ToString(), StringComparison.OrdinalIgnoreCase);
58+
}
59+
60+
[Fact(Skip = "For manual verification only")]
61+
public async Task ItCanUseServiceGetResponseAsync()
62+
{
63+
using var chatClient = CreateChatClient();
64+
65+
var messages = new List<ChatMessage>
66+
{
67+
new(ChatRole.User, "Where is the most famous fish market in Seattle, Washington, USA?")
68+
};
69+
70+
var response = await chatClient.GetResponseAsync(messages);
71+
72+
// Assert
73+
Assert.NotNull(response);
74+
Assert.Contains("Pike Place", response.Text, StringComparison.OrdinalIgnoreCase);
75+
}
76+
77+
[Fact(Skip = "For manual verification only")]
78+
public async Task ItCanUseServiceGetStreamingResponseAsync()
79+
{
80+
using var chatClient = CreateChatClient();
81+
82+
var messages = new List<ChatMessage>
83+
{
84+
new(ChatRole.User, "Where is the most famous fish market in Seattle, Washington, USA?")
85+
};
86+
87+
StringBuilder fullResult = new();
88+
89+
await foreach (var update in chatClient.GetStreamingResponseAsync(messages))
90+
{
91+
fullResult.Append(update.Text);
92+
}
93+
94+
// Assert
95+
Assert.Contains("Pike Place", fullResult.ToString(), StringComparison.OrdinalIgnoreCase);
96+
}
97+
98+
private static IChatClient CreateChatClient()
99+
{
100+
Assert.NotNull(Configuration.ModelPath);
101+
Assert.NotNull(Configuration.ModelId);
102+
103+
var services = new ServiceCollection();
104+
services.AddOnnxRuntimeGenAIChatClient(Configuration.ModelId);
105+
106+
var serviceProvider = services.BuildServiceProvider();
107+
return serviceProvider.GetRequiredService<IChatClient>();
108+
}
109+
110+
#region internals
111+
112+
private Kernel CreateAndInitializeKernelWithChatClient(HttpClient? httpClient = null)
113+
{
114+
Assert.NotNull(Configuration.ModelPath);
115+
Assert.NotNull(Configuration.ModelId);
116+
117+
var kernelBuilder = base.CreateKernelBuilder();
118+
119+
kernelBuilder.AddOnnxRuntimeGenAIChatClient(
120+
modelPath: Configuration.ModelPath,
121+
serviceId: Configuration.ServiceId);
122+
123+
return kernelBuilder.Build();
124+
}
125+
126+
private static OnnxConfiguration Configuration => new ConfigurationBuilder()
127+
.AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true)
128+
.AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true)
129+
.AddEnvironmentVariables()
130+
.AddUserSecrets<OnnxRuntimeGenAIChatClientTests>()
131+
.Build()
132+
.GetRequiredSection("Onnx")
133+
.Get<OnnxConfiguration>()!;
134+
135+
#endregion
136+
}

0 commit comments

Comments
 (0)