Skip to content

Commit f3539cc

Browse files
authored
Enable injecting IMcpServer and friends into ctors (#570)
1 parent 9d835a0 commit f3539cc

File tree

7 files changed

+249
-191
lines changed

7 files changed

+249
-191
lines changed

src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs

Lines changed: 11 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using Microsoft.Extensions.DependencyInjection;
33
using ModelContextProtocol.Protocol;
44
using System.ComponentModel;
5+
using System.Diagnostics;
56
using System.Reflection;
67
using System.Text.Json;
78

@@ -57,8 +58,8 @@ internal sealed class AIFunctionMcpServerPrompt : McpServerPrompt
5758
return Create(
5859
AIFunctionFactory.Create(method, args =>
5960
{
60-
var request = (RequestContext<GetPromptRequestParams>)args.Context![typeof(RequestContext<GetPromptRequestParams>)]!;
61-
return createTargetFunc(request);
61+
Debug.Assert(args.Services is RequestServiceProvider<GetPromptRequestParams>, $"The service provider should be a {nameof(RequestServiceProvider<GetPromptRequestParams>)} for this method to work correctly.");
62+
return createTargetFunc(((RequestServiceProvider<GetPromptRequestParams>)args.Services!).Request);
6263
}, CreateAIFunctionFactoryOptions(method, options)),
6364
options);
6465
}
@@ -74,54 +75,15 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
7475
JsonSchemaCreateOptions = options?.SchemaCreateOptions,
7576
ConfigureParameterBinding = pi =>
7677
{
77-
if (pi.ParameterType == typeof(RequestContext<GetPromptRequestParams>))
78+
if (RequestServiceProvider<GetPromptRequestParams>.IsAugmentedWith(pi.ParameterType) ||
79+
(options?.Services?.GetService<IServiceProviderIsService>() is { } ispis &&
80+
ispis.IsService(pi.ParameterType)))
7881
{
79-
return new()
80-
{
81-
ExcludeFromSchema = true,
82-
BindParameter = (pi, args) => GetRequestContext(args),
83-
};
84-
}
85-
86-
if (pi.ParameterType == typeof(IMcpServer))
87-
{
88-
return new()
89-
{
90-
ExcludeFromSchema = true,
91-
BindParameter = (pi, args) => GetRequestContext(args)?.Server,
92-
};
93-
}
94-
95-
if (pi.ParameterType == typeof(IProgress<ProgressNotificationValue>))
96-
{
97-
// Bind IProgress<ProgressNotificationValue> to the progress token in the request,
98-
// if there is one. If we can't get one, return a nop progress.
9982
return new()
10083
{
10184
ExcludeFromSchema = true,
10285
BindParameter = (pi, args) =>
103-
{
104-
var requestContent = GetRequestContext(args);
105-
if (requestContent?.Server is { } server &&
106-
requestContent?.Params?.ProgressToken is { } progressToken)
107-
{
108-
return new TokenProgress(server, progressToken);
109-
}
110-
111-
return NullProgress.Instance;
112-
},
113-
};
114-
}
115-
116-
if (options?.Services is { } services &&
117-
services.GetService<IServiceProviderIsService>() is { } ispis &&
118-
ispis.IsService(pi.ParameterType))
119-
{
120-
return new()
121-
{
122-
ExcludeFromSchema = true,
123-
BindParameter = (pi, args) =>
124-
GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ??
86+
args.Services?.GetService(pi.ParameterType) ??
12587
(pi.HasDefaultValue ? null :
12688
throw new ArgumentException("No service of the requested type was found.")),
12789
};
@@ -133,24 +95,13 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
13395
{
13496
ExcludeFromSchema = true,
13597
BindParameter = (pi, args) =>
136-
(GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ??
98+
(args?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ??
13799
(pi.HasDefaultValue ? null :
138100
throw new ArgumentException("No service of the requested type was found.")),
139101
};
140102
}
141103

142104
return default;
143-
144-
static RequestContext<GetPromptRequestParams>? GetRequestContext(AIFunctionArguments args)
145-
{
146-
if (args.Context?.TryGetValue(typeof(RequestContext<GetPromptRequestParams>), out var orc) is true &&
147-
orc is RequestContext<GetPromptRequestParams> requestContext)
148-
{
149-
return requestContext;
150-
}
151-
152-
return null;
153-
}
154105
},
155106
};
156107

@@ -226,14 +177,10 @@ public override async ValueTask<GetPromptResult> GetAsync(
226177
Throw.IfNull(request);
227178
cancellationToken.ThrowIfCancellationRequested();
228179

229-
AIFunctionArguments arguments = new()
230-
{
231-
Services = request.Services,
232-
Context = new Dictionary<object, object?>() { [typeof(RequestContext<GetPromptRequestParams>)] = request }
233-
};
180+
request.Services = new RequestServiceProvider<GetPromptRequestParams>(request, request.Services);
181+
AIFunctionArguments arguments = new() { Services = request.Services };
234182

235-
var argDict = request.Params?.Arguments;
236-
if (argDict is not null)
183+
if (request.Params?.Arguments is { } argDict)
237184
{
238185
foreach (var kvp in argDict)
239186
{

src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs

Lines changed: 10 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using ModelContextProtocol.Protocol;
44
using System.Collections.Concurrent;
55
using System.ComponentModel;
6+
using System.Diagnostics;
67
using System.Globalization;
78
using System.Reflection;
89
using System.Text;
@@ -64,8 +65,8 @@ internal sealed class AIFunctionMcpServerResource : McpServerResource
6465
return Create(
6566
AIFunctionFactory.Create(method, args =>
6667
{
67-
var request = (RequestContext<ReadResourceRequestParams>)args.Context![typeof(RequestContext<ReadResourceRequestParams>)]!;
68-
return createTargetFunc(request);
68+
Debug.Assert(args.Services is RequestServiceProvider<ReadResourceRequestParams>, $"The service provider should be a {nameof(RequestServiceProvider<ReadResourceRequestParams>)} for this method to work correctly.");
69+
return createTargetFunc(((RequestServiceProvider<ReadResourceRequestParams>)args.Services!).Request);
6970
}, CreateAIFunctionFactoryOptions(method, options)),
7071
options);
7172
}
@@ -81,54 +82,15 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
8182
JsonSchemaCreateOptions = options?.SchemaCreateOptions,
8283
ConfigureParameterBinding = pi =>
8384
{
84-
if (pi.ParameterType == typeof(RequestContext<ReadResourceRequestParams>))
85-
{
86-
return new()
87-
{
88-
ExcludeFromSchema = true,
89-
BindParameter = (pi, args) => GetRequestContext(args),
90-
};
91-
}
92-
93-
if (pi.ParameterType == typeof(IMcpServer))
94-
{
95-
return new()
96-
{
97-
ExcludeFromSchema = true,
98-
BindParameter = (pi, args) => GetRequestContext(args)?.Server,
99-
};
100-
}
101-
102-
if (pi.ParameterType == typeof(IProgress<ProgressNotificationValue>))
103-
{
104-
// Bind IProgress<ProgressNotificationValue> to the progress token in the request,
105-
// if there is one. If we can't get one, return a nop progress.
106-
return new()
107-
{
108-
ExcludeFromSchema = true,
109-
BindParameter = (pi, args) =>
110-
{
111-
var requestContent = GetRequestContext(args);
112-
if (requestContent?.Server is { } server &&
113-
requestContent?.Params?.ProgressToken is { } progressToken)
114-
{
115-
return new TokenProgress(server, progressToken);
116-
}
117-
118-
return NullProgress.Instance;
119-
},
120-
};
121-
}
122-
123-
if (options?.Services is { } services &&
124-
services.GetService<IServiceProviderIsService>() is { } ispis &&
125-
ispis.IsService(pi.ParameterType))
85+
if (RequestServiceProvider<ReadResourceRequestParams>.IsAugmentedWith(pi.ParameterType) ||
86+
(options?.Services?.GetService<IServiceProviderIsService>() is { } ispis &&
87+
ispis.IsService(pi.ParameterType)))
12688
{
12789
return new()
12890
{
12991
ExcludeFromSchema = true,
13092
BindParameter = (pi, args) =>
131-
GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ??
93+
args.Services?.GetService(pi.ParameterType) ??
13294
(pi.HasDefaultValue ? null :
13395
throw new ArgumentException("No service of the requested type was found.")),
13496
};
@@ -140,7 +102,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
140102
{
141103
ExcludeFromSchema = true,
142104
BindParameter = (pi, args) =>
143-
(GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ??
105+
(args?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ??
144106
(pi.HasDefaultValue ? null :
145107
throw new ArgumentException("No service of the requested type was found.")),
146108
};
@@ -172,17 +134,6 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
172134
}
173135

174136
return default;
175-
176-
static RequestContext<ReadResourceRequestParams>? GetRequestContext(AIFunctionArguments args)
177-
{
178-
if (args.Context?.TryGetValue(typeof(RequestContext<ReadResourceRequestParams>), out var rc) is true &&
179-
rc is RequestContext<ReadResourceRequestParams> requestContext)
180-
{
181-
return requestContext;
182-
}
183-
184-
return null;
185-
}
186137
},
187138
};
188139

@@ -365,11 +316,8 @@ private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resour
365316
}
366317

367318
// Build up the arguments for the AIFunction call, including all of the name/value pairs from the URI.
368-
AIFunctionArguments arguments = new()
369-
{
370-
Services = request.Services,
371-
Context = new Dictionary<object, object?>() { [typeof(RequestContext<ReadResourceRequestParams>)] = request }
372-
};
319+
request.Services = new RequestServiceProvider<ReadResourceRequestParams>(request, request.Services);
320+
AIFunctionArguments arguments = new() { Services = request.Services };
373321

374322
// For templates, populate the arguments from the URI template.
375323
if (match is not null)

src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs

Lines changed: 11 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
using Microsoft.Extensions.Logging.Abstractions;
55
using ModelContextProtocol.Protocol;
66
using System.ComponentModel;
7-
using System.Diagnostics.CodeAnalysis;
7+
using System.Diagnostics;
88
using System.Reflection;
99
using System.Text.Json;
1010
using System.Text.Json.Nodes;
@@ -64,8 +64,8 @@ internal sealed partial class AIFunctionMcpServerTool : McpServerTool
6464
return Create(
6565
AIFunctionFactory.Create(method, args =>
6666
{
67-
var request = (RequestContext<CallToolRequestParams>)args.Context![typeof(RequestContext<CallToolRequestParams>)]!;
68-
return createTargetFunc(request);
67+
Debug.Assert(args.Services is RequestServiceProvider<CallToolRequestParams>, $"The service provider should be a {nameof(RequestServiceProvider<CallToolRequestParams>)} for this method to work correctly.");
68+
return createTargetFunc(((RequestServiceProvider<CallToolRequestParams>)args.Services!).Request);
6969
}, CreateAIFunctionFactoryOptions(method, options)),
7070
options);
7171
}
@@ -81,54 +81,15 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
8181
JsonSchemaCreateOptions = options?.SchemaCreateOptions,
8282
ConfigureParameterBinding = pi =>
8383
{
84-
if (pi.ParameterType == typeof(RequestContext<CallToolRequestParams>))
84+
if (RequestServiceProvider<CallToolRequestParams>.IsAugmentedWith(pi.ParameterType) ||
85+
(options?.Services?.GetService<IServiceProviderIsService>() is { } ispis &&
86+
ispis.IsService(pi.ParameterType)))
8587
{
86-
return new()
87-
{
88-
ExcludeFromSchema = true,
89-
BindParameter = (pi, args) => GetRequestContext(args),
90-
};
91-
}
92-
93-
if (pi.ParameterType == typeof(IMcpServer))
94-
{
95-
return new()
96-
{
97-
ExcludeFromSchema = true,
98-
BindParameter = (pi, args) => GetRequestContext(args)?.Server,
99-
};
100-
}
101-
102-
if (pi.ParameterType == typeof(IProgress<ProgressNotificationValue>))
103-
{
104-
// Bind IProgress<ProgressNotificationValue> to the progress token in the request,
105-
// if there is one. If we can't get one, return a nop progress.
10688
return new()
10789
{
10890
ExcludeFromSchema = true,
10991
BindParameter = (pi, args) =>
110-
{
111-
var requestContent = GetRequestContext(args);
112-
if (requestContent?.Server is { } server &&
113-
requestContent?.Params?.ProgressToken is { } progressToken)
114-
{
115-
return new TokenProgress(server, progressToken);
116-
}
117-
118-
return NullProgress.Instance;
119-
},
120-
};
121-
}
122-
123-
if (options?.Services is { } services &&
124-
services.GetService<IServiceProviderIsService>() is { } ispis &&
125-
ispis.IsService(pi.ParameterType))
126-
{
127-
return new()
128-
{
129-
ExcludeFromSchema = true,
130-
BindParameter = (pi, args) =>
131-
GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ??
92+
args.Services?.GetService(pi.ParameterType) ??
13293
(pi.HasDefaultValue ? null :
13394
throw new ArgumentException("No service of the requested type was found.")),
13495
};
@@ -140,24 +101,13 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
140101
{
141102
ExcludeFromSchema = true,
142103
BindParameter = (pi, args) =>
143-
(GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ??
104+
(args?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ??
144105
(pi.HasDefaultValue ? null :
145106
throw new ArgumentException("No service of the requested type was found.")),
146107
};
147108
}
148109

149110
return default;
150-
151-
static RequestContext<CallToolRequestParams>? GetRequestContext(AIFunctionArguments args)
152-
{
153-
if (args.Context?.TryGetValue(typeof(RequestContext<CallToolRequestParams>), out var orc) is true &&
154-
orc is RequestContext<CallToolRequestParams> requestContext)
155-
{
156-
return requestContext;
157-
}
158-
159-
return null;
160-
}
161111
},
162112
};
163113

@@ -260,14 +210,10 @@ public override async ValueTask<CallToolResult> InvokeAsync(
260210
Throw.IfNull(request);
261211
cancellationToken.ThrowIfCancellationRequested();
262212

263-
AIFunctionArguments arguments = new()
264-
{
265-
Services = request.Services,
266-
Context = new Dictionary<object, object?>() { [typeof(RequestContext<CallToolRequestParams>)] = request }
267-
};
213+
request.Services = new RequestServiceProvider<CallToolRequestParams>(request, request.Services);
214+
AIFunctionArguments arguments = new() { Services = request.Services };
268215

269-
var argDict = request.Params?.Arguments;
270-
if (argDict is not null)
216+
if (request.Params?.Arguments is { } argDict)
271217
{
272218
foreach (var kvp in argDict)
273219
{

0 commit comments

Comments
 (0)