Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 11 additions & 64 deletions src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Microsoft.Extensions.DependencyInjection;
using ModelContextProtocol.Protocol;
using System.ComponentModel;
using System.Diagnostics;
using System.Reflection;
using System.Text.Json;

Expand Down Expand Up @@ -57,8 +58,8 @@ internal sealed class AIFunctionMcpServerPrompt : McpServerPrompt
return Create(
AIFunctionFactory.Create(method, args =>
{
var request = (RequestContext<GetPromptRequestParams>)args.Context![typeof(RequestContext<GetPromptRequestParams>)]!;
return createTargetFunc(request);
Debug.Assert(args.Services is RequestServiceProvider<GetPromptRequestParams>, $"The service provider should be a {nameof(RequestServiceProvider<GetPromptRequestParams>)} for this method to work correctly.");
return createTargetFunc(((RequestServiceProvider<GetPromptRequestParams>)args.Services!).Request);
}, CreateAIFunctionFactoryOptions(method, options)),
options);
}
Expand All @@ -74,54 +75,15 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
JsonSchemaCreateOptions = options?.SchemaCreateOptions,
ConfigureParameterBinding = pi =>
{
if (pi.ParameterType == typeof(RequestContext<GetPromptRequestParams>))
if (RequestServiceProvider<GetPromptRequestParams>.IsAugmentedWith(pi.ParameterType) ||
(options?.Services?.GetService<IServiceProviderIsService>() is { } ispis &&
ispis.IsService(pi.ParameterType)))
{
return new()
{
ExcludeFromSchema = true,
BindParameter = (pi, args) => GetRequestContext(args),
};
}

if (pi.ParameterType == typeof(IMcpServer))
{
return new()
{
ExcludeFromSchema = true,
BindParameter = (pi, args) => GetRequestContext(args)?.Server,
};
}

if (pi.ParameterType == typeof(IProgress<ProgressNotificationValue>))
{
// Bind IProgress<ProgressNotificationValue> to the progress token in the request,
// if there is one. If we can't get one, return a nop progress.
return new()
{
ExcludeFromSchema = true,
BindParameter = (pi, args) =>
{
var requestContent = GetRequestContext(args);
if (requestContent?.Server is { } server &&
requestContent?.Params?.ProgressToken is { } progressToken)
{
return new TokenProgress(server, progressToken);
}

return NullProgress.Instance;
},
};
}

if (options?.Services is { } services &&
services.GetService<IServiceProviderIsService>() is { } ispis &&
ispis.IsService(pi.ParameterType))
{
return new()
{
ExcludeFromSchema = true,
BindParameter = (pi, args) =>
GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ??
args.Services?.GetService(pi.ParameterType) ??
(pi.HasDefaultValue ? null :
throw new ArgumentException("No service of the requested type was found.")),
};
Expand All @@ -133,24 +95,13 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
{
ExcludeFromSchema = true,
BindParameter = (pi, args) =>
(GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ??
(args?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ??
(pi.HasDefaultValue ? null :
throw new ArgumentException("No service of the requested type was found.")),
};
}

return default;

static RequestContext<GetPromptRequestParams>? GetRequestContext(AIFunctionArguments args)
{
if (args.Context?.TryGetValue(typeof(RequestContext<GetPromptRequestParams>), out var orc) is true &&
orc is RequestContext<GetPromptRequestParams> requestContext)
{
return requestContext;
}

return null;
}
},
};

Expand Down Expand Up @@ -226,14 +177,10 @@ public override async ValueTask<GetPromptResult> GetAsync(
Throw.IfNull(request);
cancellationToken.ThrowIfCancellationRequested();

AIFunctionArguments arguments = new()
{
Services = request.Services,
Context = new Dictionary<object, object?>() { [typeof(RequestContext<GetPromptRequestParams>)] = request }
};
request.Services = new RequestServiceProvider<GetPromptRequestParams>(request, request.Services);
AIFunctionArguments arguments = new() { Services = request.Services };

var argDict = request.Params?.Arguments;
if (argDict is not null)
if (request.Params?.Arguments is { } argDict)
{
foreach (var kvp in argDict)
{
Expand Down
72 changes: 10 additions & 62 deletions src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using ModelContextProtocol.Protocol;
using System.Collections.Concurrent;
using System.ComponentModel;
using System.Diagnostics;
using System.Globalization;
using System.Reflection;
using System.Text;
Expand Down Expand Up @@ -64,8 +65,8 @@ internal sealed class AIFunctionMcpServerResource : McpServerResource
return Create(
AIFunctionFactory.Create(method, args =>
{
var request = (RequestContext<ReadResourceRequestParams>)args.Context![typeof(RequestContext<ReadResourceRequestParams>)]!;
return createTargetFunc(request);
Debug.Assert(args.Services is RequestServiceProvider<ReadResourceRequestParams>, $"The service provider should be a {nameof(RequestServiceProvider<ReadResourceRequestParams>)} for this method to work correctly.");
return createTargetFunc(((RequestServiceProvider<ReadResourceRequestParams>)args.Services!).Request);
}, CreateAIFunctionFactoryOptions(method, options)),
options);
}
Expand All @@ -81,54 +82,15 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
JsonSchemaCreateOptions = options?.SchemaCreateOptions,
ConfigureParameterBinding = pi =>
{
if (pi.ParameterType == typeof(RequestContext<ReadResourceRequestParams>))
{
return new()
{
ExcludeFromSchema = true,
BindParameter = (pi, args) => GetRequestContext(args),
};
}

if (pi.ParameterType == typeof(IMcpServer))
{
return new()
{
ExcludeFromSchema = true,
BindParameter = (pi, args) => GetRequestContext(args)?.Server,
};
}

if (pi.ParameterType == typeof(IProgress<ProgressNotificationValue>))
{
// Bind IProgress<ProgressNotificationValue> to the progress token in the request,
// if there is one. If we can't get one, return a nop progress.
return new()
{
ExcludeFromSchema = true,
BindParameter = (pi, args) =>
{
var requestContent = GetRequestContext(args);
if (requestContent?.Server is { } server &&
requestContent?.Params?.ProgressToken is { } progressToken)
{
return new TokenProgress(server, progressToken);
}

return NullProgress.Instance;
},
};
}

if (options?.Services is { } services &&
services.GetService<IServiceProviderIsService>() is { } ispis &&
ispis.IsService(pi.ParameterType))
if (RequestServiceProvider<ReadResourceRequestParams>.IsAugmentedWith(pi.ParameterType) ||
(options?.Services?.GetService<IServiceProviderIsService>() is { } ispis &&
ispis.IsService(pi.ParameterType)))
{
return new()
{
ExcludeFromSchema = true,
BindParameter = (pi, args) =>
GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ??
args.Services?.GetService(pi.ParameterType) ??
(pi.HasDefaultValue ? null :
throw new ArgumentException("No service of the requested type was found.")),
};
Expand All @@ -140,7 +102,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
{
ExcludeFromSchema = true,
BindParameter = (pi, args) =>
(GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ??
(args?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ??
(pi.HasDefaultValue ? null :
throw new ArgumentException("No service of the requested type was found.")),
};
Expand Down Expand Up @@ -172,17 +134,6 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
}

return default;

static RequestContext<ReadResourceRequestParams>? GetRequestContext(AIFunctionArguments args)
{
if (args.Context?.TryGetValue(typeof(RequestContext<ReadResourceRequestParams>), out var rc) is true &&
rc is RequestContext<ReadResourceRequestParams> requestContext)
{
return requestContext;
}

return null;
}
},
};

Expand Down Expand Up @@ -365,11 +316,8 @@ private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resour
}

// Build up the arguments for the AIFunction call, including all of the name/value pairs from the URI.
AIFunctionArguments arguments = new()
{
Services = request.Services,
Context = new Dictionary<object, object?>() { [typeof(RequestContext<ReadResourceRequestParams>)] = request }
};
request.Services = new RequestServiceProvider<ReadResourceRequestParams>(request, request.Services);
AIFunctionArguments arguments = new() { Services = request.Services };

// For templates, populate the arguments from the URI template.
if (match is not null)
Expand Down
76 changes: 11 additions & 65 deletions src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
using Microsoft.Extensions.Logging.Abstractions;
using ModelContextProtocol.Protocol;
using System.ComponentModel;
using System.Diagnostics.CodeAnalysis;
using System.Diagnostics;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Nodes;
Expand Down Expand Up @@ -64,8 +64,8 @@ internal sealed partial class AIFunctionMcpServerTool : McpServerTool
return Create(
AIFunctionFactory.Create(method, args =>
{
var request = (RequestContext<CallToolRequestParams>)args.Context![typeof(RequestContext<CallToolRequestParams>)]!;
return createTargetFunc(request);
Debug.Assert(args.Services is RequestServiceProvider<CallToolRequestParams>, $"The service provider should be a {nameof(RequestServiceProvider<CallToolRequestParams>)} for this method to work correctly.");
return createTargetFunc(((RequestServiceProvider<CallToolRequestParams>)args.Services!).Request);
}, CreateAIFunctionFactoryOptions(method, options)),
options);
}
Expand All @@ -81,54 +81,15 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
JsonSchemaCreateOptions = options?.SchemaCreateOptions,
ConfigureParameterBinding = pi =>
{
if (pi.ParameterType == typeof(RequestContext<CallToolRequestParams>))
if (RequestServiceProvider<CallToolRequestParams>.IsAugmentedWith(pi.ParameterType) ||
(options?.Services?.GetService<IServiceProviderIsService>() is { } ispis &&
ispis.IsService(pi.ParameterType)))
{
return new()
{
ExcludeFromSchema = true,
BindParameter = (pi, args) => GetRequestContext(args),
};
}

if (pi.ParameterType == typeof(IMcpServer))
{
return new()
{
ExcludeFromSchema = true,
BindParameter = (pi, args) => GetRequestContext(args)?.Server,
};
}

if (pi.ParameterType == typeof(IProgress<ProgressNotificationValue>))
{
// Bind IProgress<ProgressNotificationValue> to the progress token in the request,
// if there is one. If we can't get one, return a nop progress.
return new()
{
ExcludeFromSchema = true,
BindParameter = (pi, args) =>
{
var requestContent = GetRequestContext(args);
if (requestContent?.Server is { } server &&
requestContent?.Params?.ProgressToken is { } progressToken)
{
return new TokenProgress(server, progressToken);
}

return NullProgress.Instance;
},
};
}

if (options?.Services is { } services &&
services.GetService<IServiceProviderIsService>() is { } ispis &&
ispis.IsService(pi.ParameterType))
{
return new()
{
ExcludeFromSchema = true,
BindParameter = (pi, args) =>
GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ??
args.Services?.GetService(pi.ParameterType) ??
(pi.HasDefaultValue ? null :
throw new ArgumentException("No service of the requested type was found.")),
};
Expand All @@ -140,24 +101,13 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
{
ExcludeFromSchema = true,
BindParameter = (pi, args) =>
(GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ??
(args?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ??
(pi.HasDefaultValue ? null :
throw new ArgumentException("No service of the requested type was found.")),
};
}

return default;

static RequestContext<CallToolRequestParams>? GetRequestContext(AIFunctionArguments args)
{
if (args.Context?.TryGetValue(typeof(RequestContext<CallToolRequestParams>), out var orc) is true &&
orc is RequestContext<CallToolRequestParams> requestContext)
{
return requestContext;
}

return null;
}
},
};

Expand Down Expand Up @@ -260,14 +210,10 @@ public override async ValueTask<CallToolResult> InvokeAsync(
Throw.IfNull(request);
cancellationToken.ThrowIfCancellationRequested();

AIFunctionArguments arguments = new()
{
Services = request.Services,
Context = new Dictionary<object, object?>() { [typeof(RequestContext<CallToolRequestParams>)] = request }
};
request.Services = new RequestServiceProvider<CallToolRequestParams>(request, request.Services);
AIFunctionArguments arguments = new() { Services = request.Services };

var argDict = request.Params?.Arguments;
if (argDict is not null)
if (request.Params?.Arguments is { } argDict)
{
foreach (var kvp in argDict)
{
Expand Down
Loading