Skip to content

Commit 2fa658e

Browse files
authored
Add WithXx overloads that take target instance (#706)
* Add WithXx overloads that take target instance * Special-case enumerables
1 parent aef509f commit 2fa658e

File tree

4 files changed

+290
-4
lines changed

4 files changed

+290
-4
lines changed

src/ModelContextProtocol/McpServerBuilderExtensions.cs

Lines changed: 138 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,53 @@ public static partial class McpServerBuilderExtensions
5353
return builder;
5454
}
5555

56+
/// <summary>Adds <see cref="McpServerTool"/> instances to the service collection backing <paramref name="builder"/>.</summary>
57+
/// <typeparam name="TToolType">The tool type.</typeparam>
58+
/// <param name="builder">The builder instance.</param>
59+
/// <param name="target">The target instance from which the tools should be sourced.</param>
60+
/// <param name="serializerOptions">The serializer options governing tool parameter marshalling.</param>
61+
/// <returns>The builder provided in <paramref name="builder"/>.</returns>
62+
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
63+
/// <remarks>
64+
/// <para>
65+
/// This method discovers all methods (public and non-public) on the specified <typeparamref name="TToolType"/>
66+
/// type, where the methods are attributed as <see cref="McpServerToolAttribute"/>, and adds an <see cref="McpServerTool"/>
67+
/// instance for each, using <paramref name="target"/> as the associated instance for instance methods.
68+
/// </para>
69+
/// <para>
70+
/// However, if <typeparamref name="TToolType"/> is itself an <see cref="IEnumerable{T}"/> of <see cref="McpServerTool"/>,
71+
/// this method will register those tools directly without scanning for methods on <typeparamref name="TToolType"/>.
72+
/// </para>
73+
/// </remarks>
74+
public static IMcpServerBuilder WithTools<[DynamicallyAccessedMembers(
75+
DynamicallyAccessedMemberTypes.PublicMethods |
76+
DynamicallyAccessedMemberTypes.NonPublicMethods)] TToolType>(
77+
this IMcpServerBuilder builder,
78+
TToolType target,
79+
JsonSerializerOptions? serializerOptions = null)
80+
{
81+
Throw.IfNull(builder);
82+
Throw.IfNull(target);
83+
84+
if (target is IEnumerable<McpServerTool> tools)
85+
{
86+
return builder.WithTools(tools);
87+
}
88+
89+
foreach (var toolMethod in typeof(TToolType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance))
90+
{
91+
if (toolMethod.GetCustomAttribute<McpServerToolAttribute>() is not null)
92+
{
93+
builder.Services.AddSingleton(services => McpServerTool.Create(
94+
toolMethod,
95+
toolMethod.IsStatic ? null : target,
96+
new() { Services = services, SerializerOptions = serializerOptions }));
97+
}
98+
}
99+
100+
return builder;
101+
}
102+
56103
/// <summary>Adds <see cref="McpServerTool"/> instances to the service collection backing <paramref name="builder"/>.</summary>
57104
/// <param name="builder">The builder instance.</param>
58105
/// <param name="tools">The <see cref="McpServerTool"/> instances to add to the server.</param>
@@ -137,7 +184,7 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, IEnume
137184
/// </para>
138185
/// <para>
139186
/// Note that this method performs reflection at runtime and may not work in Native AOT scenarios. For
140-
/// Native AOT compatibility, consider using the generic <see cref="WithTools{TToolType}"/> method instead.
187+
/// Native AOT compatibility, consider using the generic <see cref="M:WithTools"/> method instead.
141188
/// </para>
142189
/// </remarks>
143190
[RequiresUnreferencedCode(WithToolsRequiresUnreferencedCodeMessage)]
@@ -193,6 +240,50 @@ where t.GetCustomAttribute<McpServerToolTypeAttribute>() is not null
193240
return builder;
194241
}
195242

243+
/// <summary>Adds <see cref="McpServerPrompt"/> instances to the service collection backing <paramref name="builder"/>.</summary>
244+
/// <typeparam name="TPromptType">The prompt type.</typeparam>
245+
/// <param name="builder">The builder instance.</param>
246+
/// <param name="target">The target instance from which the prompts should be sourced.</param>
247+
/// <param name="serializerOptions">The serializer options governing prompt parameter marshalling.</param>
248+
/// <returns>The builder provided in <paramref name="builder"/>.</returns>
249+
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
250+
/// <remarks>
251+
/// <para>
252+
/// This method discovers all methods (public and non-public) on the specified <typeparamref name="TPromptType"/>
253+
/// type, where the methods are attributed as <see cref="McpServerPromptAttribute"/>, and adds an <see cref="McpServerPrompt"/>
254+
/// instance for each, using <paramref name="target"/> as the associated instance for instance methods.
255+
/// </para>
256+
/// <para>
257+
/// However, if <typeparamref name="TPromptType"/> is itself an <see cref="IEnumerable{T}"/> of <see cref="McpServerPrompt"/>,
258+
/// this method will register those prompts directly without scanning for methods on <typeparamref name="TPromptType"/>.
259+
/// </para>
260+
/// </remarks>
261+
public static IMcpServerBuilder WithPrompts<[DynamicallyAccessedMembers(
262+
DynamicallyAccessedMemberTypes.PublicMethods |
263+
DynamicallyAccessedMemberTypes.NonPublicMethods)] TPromptType>(
264+
this IMcpServerBuilder builder,
265+
TPromptType target,
266+
JsonSerializerOptions? serializerOptions = null)
267+
{
268+
Throw.IfNull(builder);
269+
Throw.IfNull(target);
270+
271+
if (target is IEnumerable<McpServerPrompt> prompts)
272+
{
273+
return builder.WithPrompts(prompts);
274+
}
275+
276+
foreach (var promptMethod in typeof(TPromptType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance))
277+
{
278+
if (promptMethod.GetCustomAttribute<McpServerPromptAttribute>() is not null)
279+
{
280+
builder.Services.AddSingleton(services => McpServerPrompt.Create(promptMethod, target, new() { Services = services, SerializerOptions = serializerOptions }));
281+
}
282+
}
283+
284+
return builder;
285+
}
286+
196287
/// <summary>Adds <see cref="McpServerPrompt"/> instances to the service collection backing <paramref name="builder"/>.</summary>
197288
/// <param name="builder">The builder instance.</param>
198289
/// <param name="prompts">The <see cref="McpServerPrompt"/> instances to add to the server.</param>
@@ -277,7 +368,7 @@ public static IMcpServerBuilder WithPrompts(this IMcpServerBuilder builder, IEnu
277368
/// </para>
278369
/// <para>
279370
/// Note that this method performs reflection at runtime and may not work in Native AOT scenarios. For
280-
/// Native AOT compatibility, consider using the generic <see cref="WithPrompts{TPromptType}"/> method instead.
371+
/// Native AOT compatibility, consider using the generic <see cref="M:WithPrompts"/> method instead.
281372
/// </para>
282373
/// </remarks>
283374
[RequiresUnreferencedCode(WithPromptsRequiresUnreferencedCodeMessage)]
@@ -311,7 +402,8 @@ where t.GetCustomAttribute<McpServerPromptTypeAttribute>() is not null
311402
/// instance for each. For instance members, an instance will be constructed for each invocation of the resource.
312403
/// </remarks>
313404
public static IMcpServerBuilder WithResources<[DynamicallyAccessedMembers(
314-
DynamicallyAccessedMemberTypes.PublicMethods | DynamicallyAccessedMemberTypes.NonPublicMethods |
405+
DynamicallyAccessedMemberTypes.PublicMethods |
406+
DynamicallyAccessedMemberTypes.NonPublicMethods |
315407
DynamicallyAccessedMemberTypes.PublicConstructors)] TResourceType>(
316408
this IMcpServerBuilder builder)
317409
{
@@ -330,6 +422,48 @@ where t.GetCustomAttribute<McpServerPromptTypeAttribute>() is not null
330422
return builder;
331423
}
332424

425+
/// <summary>Adds <see cref="McpServerResource"/> instances to the service collection backing <paramref name="builder"/>.</summary>
426+
/// <typeparam name="TResourceType">The resource type.</typeparam>
427+
/// <param name="builder">The builder instance.</param>
428+
/// <param name="target">The target instance from which the prompts should be sourced.</param>
429+
/// <returns>The builder provided in <paramref name="builder"/>.</returns>
430+
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
431+
/// <remarks>
432+
/// <para>
433+
/// This method discovers all methods (public and non-public) on the specified <typeparamref name="TResourceType"/>
434+
/// type, where the methods are attributed as <see cref="McpServerResourceAttribute"/>, and adds an <see cref="McpServerResource"/>
435+
/// instance for each, using <paramref name="target"/> as the associated instance for instance methods.
436+
/// </para>
437+
/// <para>
438+
/// However, if <typeparamref name="TResourceType"/> is itself an <see cref="IEnumerable{T}"/> of <see cref="McpServerResource"/>,
439+
/// this method will register those resources directly without scanning for methods on <typeparamref name="TResourceType"/>.
440+
/// </para>
441+
/// </remarks>
442+
public static IMcpServerBuilder WithResources<[DynamicallyAccessedMembers(
443+
DynamicallyAccessedMemberTypes.PublicMethods |
444+
DynamicallyAccessedMemberTypes.NonPublicMethods)] TResourceType>(
445+
this IMcpServerBuilder builder,
446+
TResourceType target)
447+
{
448+
Throw.IfNull(builder);
449+
Throw.IfNull(target);
450+
451+
if (target is IEnumerable<McpServerResource> resources)
452+
{
453+
return builder.WithResources(resources);
454+
}
455+
456+
foreach (var resourceTemplateMethod in typeof(TResourceType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance))
457+
{
458+
if (resourceTemplateMethod.GetCustomAttribute<McpServerResourceAttribute>() is not null)
459+
{
460+
builder.Services.AddSingleton(services => McpServerResource.Create(resourceTemplateMethod, target, new() { Services = services }));
461+
}
462+
}
463+
464+
return builder;
465+
}
466+
333467
/// <summary>Adds <see cref="McpServerResource"/> instances to the service collection backing <paramref name="builder"/>.</summary>
334468
/// <param name="builder">The builder instance.</param>
335469
/// <param name="resourceTemplates">The <see cref="McpServerResource"/> instances to add to the server.</param>
@@ -412,7 +546,7 @@ public static IMcpServerBuilder WithResources(this IMcpServerBuilder builder, IE
412546
/// </para>
413547
/// <para>
414548
/// Note that this method performs reflection at runtime and may not work in Native AOT scenarios. For
415-
/// Native AOT compatibility, consider using the generic <see cref="WithResources{TResourceType}"/> method instead.
549+
/// Native AOT compatibility, consider using the generic <see cref="M:WithResources"/> method instead.
416550
/// </para>
417551
/// </remarks>
418552
[RequiresUnreferencedCode(WithResourcesRequiresUnreferencedCodeMessage)]

tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
using ModelContextProtocol.Client;
55
using ModelContextProtocol.Protocol;
66
using ModelContextProtocol.Server;
7+
using Moq;
8+
using System.Collections;
79
using System.ComponentModel;
10+
using System.Text.Json;
811
using System.Text.Json.Serialization;
912
using System.Threading.Channels;
1013

@@ -217,13 +220,63 @@ public void WithPrompts_InvalidArgs_Throws()
217220

218221
Assert.Throws<ArgumentNullException>("prompts", () => builder.WithPrompts((IEnumerable<McpServerPrompt>)null!));
219222
Assert.Throws<ArgumentNullException>("promptTypes", () => builder.WithPrompts((IEnumerable<Type>)null!));
223+
Assert.Throws<ArgumentNullException>("target", () => builder.WithPrompts<object>(target: null!));
220224

221225
IMcpServerBuilder nullBuilder = null!;
222226
Assert.Throws<ArgumentNullException>("builder", () => nullBuilder.WithPrompts<object>());
227+
Assert.Throws<ArgumentNullException>("builder", () => nullBuilder.WithPrompts(new object()));
223228
Assert.Throws<ArgumentNullException>("builder", () => nullBuilder.WithPrompts(Array.Empty<Type>()));
224229
Assert.Throws<ArgumentNullException>("builder", () => nullBuilder.WithPromptsFromAssembly());
225230
}
226231

232+
[Fact]
233+
public async Task WithPrompts_TargetInstance_UsesTarget()
234+
{
235+
ServiceCollection sc = new();
236+
237+
var target = new SimplePrompts(new ObjectWithId() { Id = "42" });
238+
sc.AddMcpServer().WithPrompts(target);
239+
240+
McpServerPrompt prompt = sc.BuildServiceProvider().GetServices<McpServerPrompt>().First(t => t.ProtocolPrompt.Name == "returns_string");
241+
var result = await prompt.GetAsync(new RequestContext<GetPromptRequestParams>(new Mock<IMcpServer>().Object)
242+
{
243+
Params = new GetPromptRequestParams
244+
{
245+
Name = "returns_string",
246+
Arguments = new Dictionary<string, JsonElement>
247+
{
248+
["message"] = JsonSerializer.SerializeToElement("hello", AIJsonUtilities.DefaultOptions),
249+
}
250+
}
251+
}, TestContext.Current.CancellationToken);
252+
253+
Assert.Equal(target.ReturnsString("hello"), (result.Messages[0].Content as TextContentBlock)?.Text);
254+
}
255+
256+
[Fact]
257+
public async Task WithPrompts_TargetInstance_UsesEnumerableImplementation()
258+
{
259+
ServiceCollection sc = new();
260+
261+
sc.AddMcpServer().WithPrompts(new MyPromptProvider());
262+
263+
var prompts = sc.BuildServiceProvider().GetServices<McpServerPrompt>().ToArray();
264+
Assert.Equal(2, prompts.Length);
265+
Assert.Contains(prompts, t => t.ProtocolPrompt.Name == "Returns42");
266+
Assert.Contains(prompts, t => t.ProtocolPrompt.Name == "Returns43");
267+
}
268+
269+
private sealed class MyPromptProvider : IEnumerable<McpServerPrompt>
270+
{
271+
public IEnumerator<McpServerPrompt> GetEnumerator()
272+
{
273+
yield return McpServerPrompt.Create(() => "42", new() { Name = "Returns42" });
274+
yield return McpServerPrompt.Create(() => "43", new() { Name = "Returns43" });
275+
}
276+
277+
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
278+
}
279+
227280
[Fact]
228281
public void Empty_Enumerables_Is_Allowed()
229282
{

tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@
44
using ModelContextProtocol.Client;
55
using ModelContextProtocol.Protocol;
66
using ModelContextProtocol.Server;
7+
using Moq;
8+
using System.Collections;
79
using System.ComponentModel;
10+
using System.Text.Json;
811
using System.Threading.Channels;
12+
using static ModelContextProtocol.Tests.Configuration.McpServerBuilderExtensionsPromptsTests;
913

1014
namespace ModelContextProtocol.Tests.Configuration;
1115

@@ -243,13 +247,59 @@ public void WithResources_InvalidArgs_Throws()
243247

244248
Assert.Throws<ArgumentNullException>("resourceTemplates", () => builder.WithResources((IEnumerable<McpServerResource>)null!));
245249
Assert.Throws<ArgumentNullException>("resourceTemplateTypes", () => builder.WithResources((IEnumerable<Type>)null!));
250+
Assert.Throws<ArgumentNullException>("target", () => builder.WithResources<object>(target: null!));
246251

247252
IMcpServerBuilder nullBuilder = null!;
248253
Assert.Throws<ArgumentNullException>("builder", () => nullBuilder.WithResources<object>());
254+
Assert.Throws<ArgumentNullException>("builder", () => nullBuilder.WithResources(new object()));
249255
Assert.Throws<ArgumentNullException>("builder", () => nullBuilder.WithResources(Array.Empty<Type>()));
250256
Assert.Throws<ArgumentNullException>("builder", () => nullBuilder.WithResourcesFromAssembly());
251257
}
252258

259+
[Fact]
260+
public async Task WithResources_TargetInstance_UsesTarget()
261+
{
262+
ServiceCollection sc = new();
263+
264+
var target = new ResourceWithId(new ObjectWithId() { Id = "42" });
265+
sc.AddMcpServer().WithResources(target);
266+
267+
McpServerResource resource = sc.BuildServiceProvider().GetServices<McpServerResource>().First(t => t.ProtocolResource?.Name == "returns_string");
268+
var result = await resource.ReadAsync(new RequestContext<ReadResourceRequestParams>(new Mock<IMcpServer>().Object)
269+
{
270+
Params = new()
271+
{
272+
Uri = "returns://string"
273+
}
274+
}, TestContext.Current.CancellationToken);
275+
276+
Assert.Equal(target.ReturnsString(), (result?.Contents[0] as TextResourceContents)?.Text);
277+
}
278+
279+
[Fact]
280+
public async Task WithResources_TargetInstance_UsesEnumerableImplementation()
281+
{
282+
ServiceCollection sc = new();
283+
284+
sc.AddMcpServer().WithResources(new MyResourceProvider());
285+
286+
var resources = sc.BuildServiceProvider().GetServices<McpServerResource>().ToArray();
287+
Assert.Equal(2, resources.Length);
288+
Assert.Contains(resources, t => t.ProtocolResource?.Name == "Returns42");
289+
Assert.Contains(resources, t => t.ProtocolResource?.Name == "Returns43");
290+
}
291+
292+
private sealed class MyResourceProvider : IEnumerable<McpServerResource>
293+
{
294+
public IEnumerator<McpServerResource> GetEnumerator()
295+
{
296+
yield return McpServerResource.Create(() => "42", new() { Name = "Returns42" });
297+
yield return McpServerResource.Create(() => "43", new() { Name = "Returns43" });
298+
}
299+
300+
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
301+
}
302+
253303
[Fact]
254304
public void Empty_Enumerables_Is_Allowed()
255305
{
@@ -307,4 +357,11 @@ public sealed class MoreResources
307357
[McpServerResource, Description("Another neat direct resource")]
308358
public static string AnotherNeatDirectResource() => "This is a neat resource";
309359
}
360+
361+
[McpServerResourceType]
362+
public sealed class ResourceWithId(ObjectWithId id)
363+
{
364+
[McpServerResource(UriTemplate = "returns://string")]
365+
public string ReturnsString() => $"Id: {id.Id}";
366+
}
310367
}

0 commit comments

Comments
 (0)