Skip to content

Commit c7cb5be

Browse files
authored
.NET: Improve resolving AITool from DI (#3175)
* remove localagenttoolregistry * also give the factory method API
1 parent 3e13909 commit c7cb5be

File tree

4 files changed

+170
-67
lines changed

4 files changed

+170
-67
lines changed

dotnet/src/Microsoft.Agents.AI.Hosting/AgentHostingServiceCollectionExtensions.cs

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

33
using System;
4-
using System.Collections.Generic;
5-
using Microsoft.Agents.AI.Hosting.Local;
4+
using System.Linq;
65
using Microsoft.Extensions.AI;
76
using Microsoft.Extensions.DependencyInjection;
87
using Microsoft.Shared.Diagnostics;
@@ -29,7 +28,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s
2928
return services.AddAIAgent(name, (sp, key) =>
3029
{
3130
var chatClient = sp.GetRequiredService<IChatClient>();
32-
var tools = GetRegisteredToolsForAgent(sp, name);
31+
var tools = sp.GetKeyedServices<AITool>(name).ToList();
3332
return new ChatClientAgent(chatClient, instructions, key, tools: tools);
3433
});
3534
}
@@ -49,7 +48,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s
4948
Throw.IfNullOrEmpty(name);
5049
return services.AddAIAgent(name, (sp, key) =>
5150
{
52-
var tools = GetRegisteredToolsForAgent(sp, name);
51+
var tools = sp.GetKeyedServices<AITool>(name).ToList();
5352
return new ChatClientAgent(chatClient, instructions, key, tools: tools);
5453
});
5554
}
@@ -70,7 +69,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s
7069
return services.AddAIAgent(name, (sp, key) =>
7170
{
7271
var chatClient = chatClientServiceKey is null ? sp.GetRequiredService<IChatClient>() : sp.GetRequiredKeyedService<IChatClient>(chatClientServiceKey);
73-
var tools = GetRegisteredToolsForAgent(sp, name);
72+
var tools = sp.GetKeyedServices<AITool>(name).ToList();
7473
return new ChatClientAgent(chatClient, instructions, key, tools: tools);
7574
});
7675
}
@@ -92,7 +91,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s
9291
return services.AddAIAgent(name, (sp, key) =>
9392
{
9493
var chatClient = chatClientServiceKey is null ? sp.GetRequiredService<IChatClient>() : sp.GetRequiredKeyedService<IChatClient>(chatClientServiceKey);
95-
var tools = GetRegisteredToolsForAgent(sp, name);
94+
var tools = sp.GetKeyedServices<AITool>(name).ToList();
9695
return new ChatClientAgent(chatClient, instructions: instructions, name: key, description: description, tools: tools);
9796
});
9897
}
@@ -127,10 +126,4 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s
127126

128127
return new HostedAgentBuilder(name, services);
129128
}
130-
131-
private static IList<AITool> GetRegisteredToolsForAgent(IServiceProvider serviceProvider, string agentName)
132-
{
133-
var registry = serviceProvider.GetService<LocalAgentToolRegistry>();
134-
return registry?.GetTools(agentName) ?? [];
135-
}
136129
}

dotnet/src/Microsoft.Agents.AI.Hosting/HostedAgentBuilderExtensions.cs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

33
using System;
4-
using System.Linq;
5-
using Microsoft.Agents.AI.Hosting.Local;
64
using Microsoft.Extensions.AI;
75
using Microsoft.Extensions.DependencyInjection;
86
using Microsoft.Shared.Diagnostics;
@@ -70,18 +68,7 @@ public static IHostedAgentBuilder WithAITool(this IHostedAgentBuilder builder, A
7068
Throw.IfNull(builder);
7169
Throw.IfNull(tool);
7270

73-
var agentName = builder.Name;
74-
var services = builder.ServiceCollection;
75-
76-
// Get or create the agent tool registry
77-
var descriptor = services.FirstOrDefault(sd => !sd.IsKeyedService && sd.ServiceType.Equals(typeof(LocalAgentToolRegistry)));
78-
if (descriptor?.ImplementationInstance is not LocalAgentToolRegistry toolRegistry)
79-
{
80-
toolRegistry = new();
81-
services.Add(ServiceDescriptor.Singleton(toolRegistry));
82-
}
83-
84-
toolRegistry.AddTool(agentName, tool);
71+
builder.ServiceCollection.AddKeyedSingleton(builder.Name, tool);
8572

8673
return builder;
8774
}
@@ -105,4 +92,19 @@ public static IHostedAgentBuilder WithAITools(this IHostedAgentBuilder builder,
10592

10693
return builder;
10794
}
95+
96+
/// <summary>
97+
/// Adds AI tool to an agent being configured with the service collection.
98+
/// </summary>
99+
/// <param name="builder">The hosted agent builder.</param>
100+
/// <param name="factory">A factory function that creates a AI tool using the provided service provider.</param>
101+
public static IHostedAgentBuilder WithAITool(this IHostedAgentBuilder builder, Func<IServiceProvider, AITool> factory)
102+
{
103+
Throw.IfNull(builder);
104+
Throw.IfNull(factory);
105+
106+
builder.ServiceCollection.AddKeyedSingleton(builder.Name, (sp, name) => factory(sp));
107+
108+
return builder;
109+
}
108110
}

dotnet/src/Microsoft.Agents.AI.Hosting/Local/LocalAgentToolRegistry.cs

Lines changed: 0 additions & 27 deletions
This file was deleted.

dotnet/tests/Microsoft.Agents.AI.Hosting.UnitTests/HostedAgentBuilderToolsExtensionsTests.cs

Lines changed: 149 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
using System;
44
using System.Collections.Generic;
5+
using System.Linq;
56
using System.Threading;
67
using System.Threading.Tasks;
78
using Microsoft.Extensions.AI;
@@ -17,49 +18,40 @@ public sealed class HostedAgentBuilderToolsExtensionsTests
1718
[Fact]
1819
public void WithAITool_ThrowsWhenBuilderIsNull()
1920
{
20-
// Arrange
2121
var tool = new DummyAITool();
2222

23-
// Act & Assert
2423
Assert.Throws<ArgumentNullException>(() => HostedAgentBuilderExtensions.WithAITool(null!, tool));
2524
}
2625

2726
[Fact]
2827
public void WithAITool_ThrowsWhenToolIsNull()
2928
{
30-
// Arrange
3129
var services = new ServiceCollection();
3230
var builder = services.AddAIAgent("test-agent", "Test instructions");
3331

34-
// Act & Assert
35-
Assert.Throws<ArgumentNullException>(() => builder.WithAITool(null!));
32+
Assert.Throws<ArgumentNullException>(() => builder.WithAITool(tool: null!));
3633
}
3734

3835
[Fact]
3936
public void WithAITools_ThrowsWhenBuilderIsNull()
4037
{
41-
// Arrange
4238
var tools = new[] { new DummyAITool() };
4339

44-
// Act & Assert
4540
Assert.Throws<ArgumentNullException>(() => HostedAgentBuilderExtensions.WithAITools(null!, tools));
4641
}
4742

4843
[Fact]
4944
public void WithAITools_ThrowsWhenToolsArrayIsNull()
5045
{
51-
// Arrange
5246
var services = new ServiceCollection();
5347
var builder = services.AddAIAgent("test-agent", "Test instructions");
5448

55-
// Act & Assert
5649
Assert.Throws<ArgumentNullException>(() => builder.WithAITools(null!));
5750
}
5851

5952
[Fact]
6053
public void RegisteredTools_ResolvesAllToolsForAgent()
6154
{
62-
// Arrange
6355
var services = new ServiceCollection();
6456
services.AddSingleton<IChatClient>(new MockChatClient());
6557

@@ -73,9 +65,13 @@ public void RegisteredTools_ResolvesAllToolsForAgent()
7365

7466
var serviceProvider = services.BuildServiceProvider();
7567

76-
var agent1Tools = ResolveAgentTools(serviceProvider, "test-agent");
68+
var agent1Tools = ResolveToolsFromAgent(serviceProvider, "test-agent");
7769
Assert.Contains(tool1, agent1Tools);
7870
Assert.Contains(tool2, agent1Tools);
71+
72+
var agent1ToolsDI = ResolveToolsFromDI(serviceProvider, "test-agent");
73+
Assert.Contains(tool1, agent1ToolsDI);
74+
Assert.Contains(tool2, agent1ToolsDI);
7975
}
8076

8177
[Fact]
@@ -100,21 +96,160 @@ public void RegisteredTools_IsolatedPerAgent()
10096

10197
var serviceProvider = services.BuildServiceProvider();
10298

103-
var agent1Tools = ResolveAgentTools(serviceProvider, "agent1");
104-
var agent2Tools = ResolveAgentTools(serviceProvider, "agent2");
99+
var agent1Tools = ResolveToolsFromAgent(serviceProvider, "agent1");
100+
var agent2Tools = ResolveToolsFromAgent(serviceProvider, "agent2");
101+
102+
var agent1ToolsDI = ResolveToolsFromDI(serviceProvider, "agent1");
103+
var agent2ToolsDI = ResolveToolsFromDI(serviceProvider, "agent2");
105104

106105
Assert.Contains(tool1, agent1Tools);
107106
Assert.Contains(tool2, agent1Tools);
107+
Assert.Contains(tool1, agent1ToolsDI);
108+
Assert.Contains(tool2, agent1ToolsDI);
109+
108110
Assert.Contains(tool3, agent2Tools);
111+
Assert.Contains(tool3, agent2ToolsDI);
109112
}
110113

111-
private static IList<AITool> ResolveAgentTools(IServiceProvider serviceProvider, string name)
114+
private static IList<AITool> ResolveToolsFromAgent(IServiceProvider serviceProvider, string name)
112115
{
113116
var agent = serviceProvider.GetRequiredKeyedService<AIAgent>(name) as ChatClientAgent;
114117
Assert.NotNull(agent?.ChatOptions?.Tools);
115118
return agent.ChatOptions.Tools;
116119
}
117120

121+
private static List<AITool> ResolveToolsFromDI(IServiceProvider serviceProvider, string name)
122+
{
123+
var tools = serviceProvider.GetKeyedServices<AITool>(name);
124+
Assert.NotNull(tools);
125+
return tools.ToList();
126+
}
127+
128+
[Fact]
129+
public void WithAIToolFactory_ThrowsWhenBuilderIsNull()
130+
{
131+
Assert.Throws<ArgumentNullException>(() => HostedAgentBuilderExtensions.WithAITool(null!, CreateTool));
132+
133+
static AITool CreateTool(IServiceProvider _) => new DummyAITool();
134+
}
135+
136+
[Fact]
137+
public void WithAIToolFactory_ThrowsWhenFactoryIsNull()
138+
{
139+
var services = new ServiceCollection();
140+
var builder = services.AddAIAgent("test-agent", "Test instructions");
141+
142+
Assert.Throws<ArgumentNullException>(() => builder.WithAITool(factory: null!));
143+
}
144+
145+
[Fact]
146+
public void WithAIToolFactory_RegistersToolFromFactory()
147+
{
148+
var services = new ServiceCollection();
149+
services.AddSingleton<IChatClient>(new MockChatClient());
150+
151+
DummyAITool? createdTool = null;
152+
var builder = services.AddAIAgent("test-agent", "Test instructions");
153+
builder.WithAITool(sp =>
154+
{
155+
createdTool = new DummyAITool();
156+
return createdTool;
157+
});
158+
159+
var serviceProvider = services.BuildServiceProvider();
160+
var tools = ResolveToolsFromDI(serviceProvider, "test-agent");
161+
162+
Assert.Single(tools);
163+
Assert.Same(createdTool, tools[0]);
164+
}
165+
166+
[Fact]
167+
public void WithAIToolFactory_CanAccessServicesFromFactory()
168+
{
169+
var services = new ServiceCollection();
170+
var mockChatClient = new MockChatClient();
171+
services.AddSingleton<IChatClient>(mockChatClient);
172+
173+
IChatClient? resolvedChatClient = null;
174+
var builder = services.AddAIAgent("test-agent", "Test instructions");
175+
builder.WithAITool(sp =>
176+
{
177+
resolvedChatClient = sp.GetService<IChatClient>();
178+
return new DummyAITool();
179+
});
180+
181+
var serviceProvider = services.BuildServiceProvider();
182+
_ = ResolveToolsFromDI(serviceProvider, "test-agent");
183+
184+
Assert.Same(mockChatClient, resolvedChatClient);
185+
}
186+
187+
[Fact]
188+
public void WithAIToolFactory_ToolsAreIsolatedPerAgent()
189+
{
190+
var services = new ServiceCollection();
191+
services.AddSingleton<IChatClient>(new MockChatClient());
192+
193+
var tool1 = new DummyAITool();
194+
var tool2 = new DummyAITool();
195+
196+
var builder1 = services.AddAIAgent("agent1", "Agent 1 instructions");
197+
var builder2 = services.AddAIAgent("agent2", "Agent 2 instructions");
198+
199+
builder1.WithAITool(_ => tool1);
200+
builder2.WithAITool(_ => tool2);
201+
202+
var serviceProvider = services.BuildServiceProvider();
203+
var agent1Tools = ResolveToolsFromDI(serviceProvider, "agent1");
204+
var agent2Tools = ResolveToolsFromDI(serviceProvider, "agent2");
205+
206+
Assert.Single(agent1Tools);
207+
Assert.Contains(tool1, agent1Tools);
208+
Assert.DoesNotContain(tool2, agent1Tools);
209+
210+
Assert.Single(agent2Tools);
211+
Assert.Contains(tool2, agent2Tools);
212+
Assert.DoesNotContain(tool1, agent2Tools);
213+
}
214+
215+
[Fact]
216+
public void WithAIToolFactory_CanCombineWithDirectToolRegistration()
217+
{
218+
var services = new ServiceCollection();
219+
services.AddSingleton<IChatClient>(new MockChatClient());
220+
221+
var directTool = new DummyAITool();
222+
var factoryTool = new DummyAITool();
223+
224+
var builder = services.AddAIAgent("test-agent", "Test instructions");
225+
builder
226+
.WithAITool(directTool)
227+
.WithAITool(_ => factoryTool);
228+
229+
var serviceProvider = services.BuildServiceProvider();
230+
var tools = ResolveToolsFromDI(serviceProvider, "test-agent");
231+
232+
Assert.Equal(2, tools.Count);
233+
Assert.Contains(directTool, tools);
234+
Assert.Contains(factoryTool, tools);
235+
}
236+
237+
[Fact]
238+
public void WithAIToolFactory_ToolsAvailableOnAgent()
239+
{
240+
var services = new ServiceCollection();
241+
services.AddSingleton<IChatClient>(new MockChatClient());
242+
243+
var factoryTool = new DummyAITool();
244+
var builder = services.AddAIAgent("test-agent", "Test instructions");
245+
builder.WithAITool(_ => factoryTool);
246+
247+
var serviceProvider = services.BuildServiceProvider();
248+
var agentTools = ResolveToolsFromAgent(serviceProvider, "test-agent");
249+
250+
Assert.Contains(factoryTool, agentTools);
251+
}
252+
118253
/// <summary>
119254
/// Dummy AITool implementation for testing.
120255
/// </summary>

0 commit comments

Comments
 (0)