diff --git a/docs/concepts/filters.md b/docs/concepts/filters.md new file mode 100644 index 00000000..27462e17 --- /dev/null +++ b/docs/concepts/filters.md @@ -0,0 +1,257 @@ +--- +title: Filters +author: halter73 +description: MCP Server Handler Filters +uid: filters +--- + +# MCP Server Handler Filters + +For each handler type in the MCP Server, there are corresponding `AddXXXFilter` methods in `McpServerBuilderExtensions.cs` that allow you to add filters to the handler pipeline. The filters are stored in `McpServerOptions.Filters` and applied during server configuration. + +## Available Filter Methods + +The following filter methods are available: + +- `AddListResourceTemplatesFilter` - Filter for list resource templates handlers +- `AddListToolsFilter` - Filter for list tools handlers +- `AddCallToolFilter` - Filter for call tool handlers +- `AddListPromptsFilter` - Filter for list prompts handlers +- `AddGetPromptFilter` - Filter for get prompt handlers +- `AddListResourcesFilter` - Filter for list resources handlers +- `AddReadResourceFilter` - Filter for read resource handlers +- `AddCompleteFilter` - Filter for completion handlers +- `AddSubscribeToResourcesFilter` - Filter for resource subscription handlers +- `AddUnsubscribeFromResourcesFilter` - Filter for resource unsubscription handlers +- `AddSetLoggingLevelFilter` - Filter for logging level handlers + +## Usage + +Filters are functions that take a handler and return a new handler, allowing you to wrap the original handler with additional functionality: + +```csharp +services.AddMcpServer() + .WithListToolsHandler(async (context, cancellationToken) => + { + // Your base handler logic + return new ListToolsResult { Tools = GetTools() }; + }) + .AddListToolsFilter(next => async (context, cancellationToken) => + { + // Pre-processing logic + Console.WriteLine("Before handler execution"); + + var result = await next(context, cancellationToken); + + // Post-processing logic + Console.WriteLine("After handler execution"); + return result; + }); +``` + +## Filter Execution Order + +```csharp +services.AddMcpServer() + .WithListToolsHandler(baseHandler) + .AddListToolsFilter(filter1) // Executes first (outermost) + .AddListToolsFilter(filter2) // Executes second + .AddListToolsFilter(filter3); // Executes third (closest to handler) +``` + +Execution flow: `filter1 -> filter2 -> filter3 -> baseHandler -> filter3 -> filter2 -> filter1` + +## Common Use Cases + +### Logging +```csharp +.AddListToolsFilter(next => async (context, cancellationToken) => +{ + Console.WriteLine($"Processing request from {context.Meta.ProgressToken}"); + var result = await next(context, cancellationToken); + Console.WriteLine($"Returning {result.Tools?.Count ?? 0} tools"); + return result; +}); +``` + +### Error Handling +```csharp +.AddCallToolFilter(next => async (context, cancellationToken) => +{ + try + { + return await next(context, cancellationToken); + } + catch (Exception ex) + { + return new CallToolResult + { + Content = new[] { new TextContent { Type = "text", Text = $"Error: {ex.Message}" } }, + IsError = true + }; + } +}); +``` + +### Performance Monitoring +```csharp +.AddListToolsFilter(next => async (context, cancellationToken) => +{ + var stopwatch = Stopwatch.StartNew(); + var result = await next(context, cancellationToken); + stopwatch.Stop(); + Console.WriteLine($"Handler took {stopwatch.ElapsedMilliseconds}ms"); + return result; +}); +``` + +### Caching +```csharp +.AddListResourcesFilter(next => async (context, cancellationToken) => +{ + var cacheKey = $"resources:{context.Params.Cursor}"; + if (cache.TryGetValue(cacheKey, out var cached)) + return cached; + + var result = await next(context, cancellationToken); + cache.Set(cacheKey, result, TimeSpan.FromMinutes(5)); + return result; +}); +``` + +## Built-in Authorization Filters + +When using the ASP.NET Core integration (`ModelContextProtocol.AspNetCore`), authorization filters are automatically configured to support `[Authorize]` and `[AllowAnonymous]` attributes on MCP server tools, prompts, and resources. + +### Authorization Attributes Support + +The MCP server automatically respects the following authorization attributes: + +- **`[Authorize]`** - Requires authentication for access +- **`[Authorize(Roles = "RoleName")]`** - Requires specific roles +- **`[Authorize(Policy = "PolicyName")]`** - Requires specific authorization policies +- **`[AllowAnonymous]`** - Explicitly allows anonymous access (overrides `[Authorize]`) + +### Tool Authorization + +Tools can be decorated with authorization attributes to control access: + +```csharp +[McpServerToolType] +public class WeatherTools +{ + [McpServerTool, Description("Gets public weather data")] + public static string GetWeather(string location) + { + return $"Weather for {location}: Sunny, 25°C"; + } + + [McpServerTool, Description("Gets detailed weather forecast")] + [Authorize] // Requires authentication + public static string GetDetailedForecast(string location) + { + return $"Detailed forecast for {location}: ..."; + } + + [McpServerTool, Description("Manages weather alerts")] + [Authorize(Roles = "Admin")] // Requires Admin role + public static string ManageWeatherAlerts(string alertType) + { + return $"Managing alert: {alertType}"; + } +} +``` + +### Class-Level Authorization + +You can apply authorization at the class level, which affects all tools in the class: + +```csharp +[McpServerToolType] +[Authorize] // All tools require authentication +public class RestrictedTools +{ + [McpServerTool, Description("Restricted tool accessible to authenticated users")] + public static string RestrictedOperation() + { + return "Restricted operation completed"; + } + + [McpServerTool, Description("Public tool accessible to anonymous users")] + [AllowAnonymous] // Overrides class-level [Authorize] + public static string PublicOperation() + { + return "Public operation completed"; + } +} +``` + +### How Authorization Filters Work + +The authorization filters work differently for list operations versus individual operations: + +#### List Operations (ListTools, ListPrompts, ListResources) +For list operations, the filters automatically remove unauthorized items from the results. Users only see tools, prompts, or resources they have permission to access. + +#### Individual Operations (CallTool, GetPrompt, ReadResource) +For individual operations, the filters return authorization errors when access is denied: + +- **Tools**: Returns a `CallToolResult` with `IsError = true` and an error message +- **Prompts**: Throws an `McpException` with "Access forbidden" message +- **Resources**: Throws an `McpException` with "Access forbidden" message + +### Setup Requirements + +To use authorization features, you must configure authentication and authorization in your ASP.NET Core application: + +```csharp +var builder = WebApplication.CreateBuilder(args); + +builder.Services.AddAuthentication("Bearer") + .AddJwtBearer(options => { /* JWT configuration */ }) + .AddMcp(options => { /* Resource metadata configuration */ }); +builder.Services.AddAuthorization(); + +builder.Services.AddMcpServer() + .WithHttpTransport() + .WithTools() + .AddCallToolFilter(next => async (context, cancellationToken) => + { + // Custom call tool logic + return await next(context, cancellationToken); + }); + +var app = builder.Build(); + +app.MapMcp(); +app.Run(); +``` + +### Custom Authorization Filters + +You can also create custom authorization filters using the filter methods: + +```csharp +.AddCallToolFilter(next => async (context, cancellationToken) => +{ + // Custom authorization logic + if (context.User?.Identity?.IsAuthenticated != true) + { + return new CallToolResult + { + Content = [new TextContent { Text = "Custom: Authentication required" }], + IsError = true + }; + } + + return await next(context, cancellationToken); +}); +``` + +### RequestContext + +Within filters, you have access to: + +- `context.User` - The current user's `ClaimsPrincipal` +- `context.Services` - The request's service provider for resolving authorization services +- `context.MatchedPrimitive` - The matched tool/prompt/resource with its metadata including authorization attributes via `context.MatchedPrimitive.Metadata` diff --git a/docs/concepts/toc.yml b/docs/concepts/toc.yml index 939f21fc..2f7c930f 100644 --- a/docs/concepts/toc.yml +++ b/docs/concepts/toc.yml @@ -13,3 +13,7 @@ items: items: - name: Logging uid: logging +- name: Server Features + items: + - name: Filters + uid: filters \ No newline at end of file diff --git a/samples/AspNetCoreMcpServer/Properties/launchSettings.json b/samples/AspNetCoreMcpServer/Properties/launchSettings.json index a5b8a22f..6670029e 100644 --- a/samples/AspNetCoreMcpServer/Properties/launchSettings.json +++ b/samples/AspNetCoreMcpServer/Properties/launchSettings.json @@ -7,7 +7,7 @@ "applicationUrl": "http://localhost:3001", "environmentVariables": { "ASPNETCORE_ENVIRONMENT": "Development", - "OTEL_SERVICE_NAME": "aspnetcore-mcp-server", + "OTEL_SERVICE_NAME": "aspnetcore-mcp-server" } }, "https": { @@ -16,7 +16,7 @@ "applicationUrl": "https://localhost:7133;http://localhost:3001", "environmentVariables": { "ASPNETCORE_ENVIRONMENT": "Development", - "OTEL_SERVICE_NAME": "aspnetcore-mcp-server", + "OTEL_SERVICE_NAME": "aspnetcore-mcp-server" } } } diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs new file mode 100644 index 00000000..7d2c30f2 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs @@ -0,0 +1,222 @@ +using System.Security.Claims; +using Microsoft.AspNetCore.Authorization; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Evaluates authorization policies from endpoint metadata. +/// +internal sealed class AuthorizationFilterSetup(IAuthorizationPolicyProvider? policyProvider = null) : IConfigureOptions +{ + public void Configure(McpServerOptions options) + { + ConfigureListToolsFilter(options); + ConfigureCallToolFilter(options); + + ConfigureListResourcesFilter(options); + ConfigureListResourceTemplatesFilter(options); + ConfigureReadResourceFilter(options); + + ConfigureListPromptsFilter(options); + ConfigureGetPromptFilter(options); + } + + private void ConfigureListToolsFilter(McpServerOptions options) + { + options.Filters.ListToolsFilters.Add(next => async (context, cancellationToken) => + { + var result = await next(context, cancellationToken); + await FilterAuthorizedItemsAsync( + result.Tools, static tool => tool.McpServerTool, + context.User, context.Services, context); + return result; + }); + } + + private void ConfigureCallToolFilter(McpServerOptions options) + { + options.Filters.CallToolFilters.Add(next => async (context, cancellationToken) => + { + var authResult = await GetAuthorizationResultAsync(context.User, context.MatchedPrimitive, context.Services, context); + if (!authResult.Succeeded) + { + return new CallToolResult + { + Content = [new TextContentBlock { Text = "Access forbidden: This tool requires authorization." }], + IsError = true + }; + } + + return await next(context, cancellationToken); + }); + } + + private void ConfigureListResourcesFilter(McpServerOptions options) + { + options.Filters.ListResourcesFilters.Add(next => async (context, cancellationToken) => + { + var result = await next(context, cancellationToken); + await FilterAuthorizedItemsAsync( + result.Resources, static resource => resource.McpServerResource, + context.User, context.Services, context); + return result; + }); + } + + private void ConfigureListResourceTemplatesFilter(McpServerOptions options) + { + options.Filters.ListResourceTemplatesFilters.Add(next => async (context, cancellationToken) => + { + var result = await next(context, cancellationToken); + await FilterAuthorizedItemsAsync( + result.ResourceTemplates, static resourceTemplate => resourceTemplate.McpServerResource, + context.User, context.Services, context); + return result; + }); + } + + private void ConfigureReadResourceFilter(McpServerOptions options) + { + options.Filters.ReadResourceFilters.Add(next => async (context, cancellationToken) => + { + var authResult = await GetAuthorizationResultAsync(context.User, context.MatchedPrimitive, context.Services, context); + if (!authResult.Succeeded) + { + throw new McpException("Access forbidden: This resource requires authorization.", McpErrorCode.InvalidRequest); + } + + return await next(context, cancellationToken); + }); + } + + private void ConfigureListPromptsFilter(McpServerOptions options) + { + options.Filters.ListPromptsFilters.Add(next => async (context, cancellationToken) => + { + var result = await next(context, cancellationToken); + await FilterAuthorizedItemsAsync( + result.Prompts, static prompt => prompt.McpServerPrompt, + context.User, context.Services, context); + return result; + }); + } + + private void ConfigureGetPromptFilter(McpServerOptions options) + { + options.Filters.GetPromptFilters.Add(next => async (context, cancellationToken) => + { + var authResult = await GetAuthorizationResultAsync(context.User, context.MatchedPrimitive, context.Services, context); + if (!authResult.Succeeded) + { + throw new McpException("Access forbidden: This prompt requires authorization.", McpErrorCode.InvalidRequest); + } + + return await next(context, cancellationToken); + }); + } + + /// + /// Filters a collection of items based on authorization policies in their metadata. + /// For list operations where we need to filter results by authorization. + /// + private async ValueTask FilterAuthorizedItemsAsync(IList items, Func primitiveSelector, + ClaimsPrincipal? user, IServiceProvider? requestServices, object context) + { + for (int i = items.Count - 1; i >= 0; i--) + { + var authorizationResult = await GetAuthorizationResultAsync( + user, primitiveSelector(items[i]), requestServices, context); + + if (!authorizationResult.Succeeded) + { + items.RemoveAt(i); + } + } + } + + private async ValueTask GetAuthorizationResultAsync( + ClaimsPrincipal? user, IMcpServerPrimitive? primitive, IServiceProvider? requestServices, object context) + { + // If no primitive was found for this request or there is IAllowAnonymous metadata anywhere on the class or method, + // the request should go through as normal. + if (primitive is null || primitive.Metadata.Any(static m => m is IAllowAnonymous)) + { + return AuthorizationResult.Success(); + } + + // There are no [Authorize] style attributes applied to the method or containing class. Any fallback policies + // have already been enforced at the HTTP request level by the ASP.NET Core authorization middleware. + if (!primitive.Metadata.Any(static m => m is IAuthorizeData or AuthorizationPolicy or IAuthorizationRequirementData)) + { + return AuthorizationResult.Success(); + } + + if (policyProvider is null) + { + throw new InvalidOperationException($"You must call AddAuthorization() because an authorization related attribute was found on {primitive.Id}"); + } + + // TODO: Cache policy lookup. We would probably use a singleton (not-static) ConditionalWeakTable. + var policy = await CombineAsync(policyProvider, primitive.Metadata); + if (policy is null) + { + return AuthorizationResult.Success(); + } + + if (requestServices is null) + { + // The IAuthorizationPolicyProvider service must be non-null to get to this line, so it's very unexpected for RequestContext.Services to not be set. + throw new InvalidOperationException("RequestContext.Services is not set! The IMcpServer must be initialized with a non-null IServiceProvider."); + } + + // ASP.NET Core's AuthorizationMiddleware resolves the IAuthorizationService from scoped request services, so we do the same. + var authService = requestServices.GetRequiredService(); + return await authService.AuthorizeAsync(user ?? new ClaimsPrincipal(new ClaimsIdentity()), context, policy); + } + + /// + /// Combines authorization policies and requirements from endpoint metadata without considering . + /// + /// The authorization policy provider. + /// The endpoint metadata collection. + /// The combined authorization policy, or null if no authorization is required. + private static async ValueTask CombineAsync(IAuthorizationPolicyProvider policyProvider, IReadOnlyList endpointMetadata) + { + // https://github.com/dotnet/aspnetcore/issues/63365 tracks adding this as public API to AuthorizationPolicy itself. + // Copied from https://github.com/dotnet/aspnetcore/blob/9f2977bf9cfb539820983bda3bedf81c8cda9f20/src/Security/Authorization/Policy/src/AuthorizationMiddleware.cs#L116-L138 + var authorizeData = endpointMetadata.OfType(); + var policies = endpointMetadata.OfType(); + + var policy = await AuthorizationPolicy.CombineAsync(policyProvider, authorizeData, policies); + + AuthorizationPolicyBuilder? reqPolicyBuilder = null; + + foreach (var m in endpointMetadata) + { + if (m is not IAuthorizationRequirementData requirementData) + { + continue; + } + + reqPolicyBuilder ??= new AuthorizationPolicyBuilder(); + foreach (var requirement in requirementData.GetRequirements()) + { + reqPolicyBuilder.AddRequirements(requirement); + } + } + + if (reqPolicyBuilder is null) + { + return policy; + } + + // Combine policy with requirements or just use requirements if no policy + return (policy is null) + ? reqPolicyBuilder.Build() + : AuthorizationPolicy.Combine(policy, reqPolicyBuilder.Build()); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs index 2d6b29fd..70835a83 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs @@ -1,4 +1,5 @@ using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Options; using ModelContextProtocol.AspNetCore; using ModelContextProtocol.Server; @@ -29,6 +30,9 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder builder.Services.AddHostedService(); builder.Services.AddDataProtection(); + // Register authorization filter setup for automatic filter configuration + builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton, AuthorizationFilterSetup>()); + if (configureOptions is not null) { builder.Services.Configure(configureOptions); diff --git a/src/ModelContextProtocol.AspNetCore/SseHandler.cs b/src/ModelContextProtocol.AspNetCore/SseHandler.cs index 6ed72fb6..fffdd45e 100644 --- a/src/ModelContextProtocol.AspNetCore/SseHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/SseHandler.cs @@ -2,7 +2,6 @@ using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using System.Collections.Concurrent; using System.Diagnostics; @@ -97,7 +96,7 @@ public async Task HandleMessageRequestAsync(HttpContext context) return; } - var message = (JsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)), context.RequestAborted); + var message = await StreamableHttpHandler.ReadJsonRpcMessageAsync(context); if (message is null) { await Results.BadRequest("No message in request body.").ExecuteAsync(context); diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index bfbd805d..d3db7e96 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -8,7 +8,6 @@ using ModelContextProtocol.AspNetCore.Stateless; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; -using System.IO.Pipelines; using System.Security.Claims; using System.Security.Cryptography; using System.Text.Json; @@ -26,6 +25,8 @@ internal sealed class StreamableHttpHandler( IServiceProvider applicationServices) { private const string McpSessionIdHeaderName = "Mcp-Session-Id"; + + private static readonly JsonTypeInfo s_messageTypeInfo = GetRequiredJsonTypeInfo(); private static readonly JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); public HttpServerTransportOptions HttpServerTransportOptions => httpServerTransportOptions.Value; @@ -55,8 +56,17 @@ await WriteJsonRpcErrorAsync(context, await using var _ = await session.AcquireReferenceAsync(context.RequestAborted); + var message = await ReadJsonRpcMessageAsync(context); + if (message is null) + { + await WriteJsonRpcErrorAsync(context, + "Bad Request: The POST body did not contain a valid JSON-RPC message.", + StatusCodes.Status400BadRequest); + return; + } + InitializeSseResponse(context); - var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), context.RequestAborted); + var wroteResponse = await session.Transport.HandlePostRequest(message, context.Response.Body, context.RequestAborted); if (!wroteResponse) { // We wound up writing nothing, so there should be no Content-Type response header. @@ -264,6 +274,24 @@ internal static string MakeNewSessionId() return WebEncoders.Base64UrlEncode(buffer); } + internal static async Task ReadJsonRpcMessageAsync(HttpContext context) + { + // Implementation for reading a JSON-RPC message from the request body + var message = await context.Request.ReadFromJsonAsync(s_messageTypeInfo, context.RequestAborted); + + if (context.User?.Identity?.IsAuthenticated ?? false && message is not null) + { + // We get weird CS0131 errors only on the Windows build GitHub Action if we use "message?.Context = ..." + // https://productionresultssa0.blob.core.windows.net/actions-results/f2218319-0fdd-473b-891d-06e5a4a0f826/workflow-job-run-98901492-cf7c-5406-85d9-0f7057e0516f/logs/job/job-logs.txt?rsct=text%2Fplain&se=2025-08-26T16%3A06%3A31Z&sig=RvEQo6DgrpDUW9mnbgDvf6FVDAAoHKzk9rsDdcPxOhw%3D&ske=2025-08-27T03%3A39%3A43Z&skoid=ca7593d4-ee42-46cd-af88-8b886a2f84eb&sks=b&skt=2025-08-26T15%3A39%3A43Z&sktid=398a6654-997b-47e9-b12b-9515b896b4de&skv=2025-05-05&sp=r&spr=https&sr=b&st=2025-08-26T15%3A56%3A26Z&sv=2025-05-05 + message!.Context = new() + { + User = context.User, + }; + } + + return message; + } + private void ScheduleStatelessSessionIdWrite(HttpContext context, StreamableHttpServerTransport transport) { transport.OnInitRequestReceived = initRequestParams => @@ -304,17 +332,11 @@ internal static Task RunSessionAsync(HttpContext httpContext, IMcpServer session return null; } - private static JsonTypeInfo GetRequiredJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T)); + internal static JsonTypeInfo GetRequiredJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T)); private static bool MatchesApplicationJsonMediaType(MediaTypeHeaderValue acceptHeaderValue) => acceptHeaderValue.MatchesMediaType("application/json"); private static bool MatchesTextEventStreamMediaType(MediaTypeHeaderValue acceptHeaderValue) => acceptHeaderValue.MatchesMediaType("text/event-stream"); - - private sealed class HttpDuplexPipe(HttpContext context) : IDuplexPipe - { - public PipeReader Input => context.Request.BodyReader; - public PipeWriter Output => context.Response.BodyWriter; - } } diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs index ffeafada..7c8a3195 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs @@ -100,15 +100,17 @@ public async ValueTask DisposeAsync() try { - await _disposeCts.CancelAsync(); - try { + // Dispose transport first to complete the incoming MessageReader gracefully and avoid a potentially unnecessary OCE. + await transport.DisposeAsync(); + await _disposeCts.CancelAsync(); + await ServerRunTask; } finally { - await DisposeServerThenTransportAsync(); + await server.DisposeAsync(); } } catch (OperationCanceledException) @@ -124,18 +126,6 @@ public async ValueTask DisposeAsync() } } - private async ValueTask DisposeServerThenTransportAsync() - { - try - { - await server.DisposeAsync(); - } - finally - { - await transport.DisposeAsync(); - } - } - private sealed class UnreferenceDisposable(StreamableHttpSession session) : IAsyncDisposable { public ValueTask DisposeAsync() diff --git a/src/ModelContextProtocol.Core/McpSession.cs b/src/ModelContextProtocol.Core/McpSession.cs index da954205..75215fee 100644 --- a/src/ModelContextProtocol.Core/McpSession.cs +++ b/src/ModelContextProtocol.Core/McpSession.cs @@ -116,14 +116,14 @@ public async Task ProcessMessagesAsync(CancellationToken cancellationToken) LogMessageRead(EndpointName, message.GetType().Name); // Fire and forget the message handling to avoid blocking the transport. - if (message.ExecutionContext is null) + if (message.Context?.ExecutionContext is null) { _ = ProcessMessageAsync(); } else { // Flow the execution context from the HTTP request corresponding to this message if provided. - ExecutionContext.Run(message.ExecutionContext, _ => _ = ProcessMessageAsync(), null); + ExecutionContext.Run(message.Context.ExecutionContext, _ => _ = ProcessMessageAsync(), null); } async Task ProcessMessageAsync() @@ -176,13 +176,15 @@ ex is OperationCanceledException && Message = "An error occurred.", }; - await SendMessageAsync(new JsonRpcError + var errorMessage = new JsonRpcError { Id = request.Id, JsonRpc = "2.0", Error = detail, - RelatedTransport = request.RelatedTransport, - }, cancellationToken).ConfigureAwait(false); + Context = new JsonRpcMessageContext { RelatedTransport = request.Context?.RelatedTransport }, + }; + + await SendMessageAsync(errorMessage, cancellationToken).ConfigureAwait(false); } else if (ex is not OperationCanceledException) { @@ -329,7 +331,7 @@ await SendMessageAsync(new JsonRpcResponse { Id = request.Id, Result = result, - RelatedTransport = request.RelatedTransport, + Context = request.Context, }, cancellationToken).ConfigureAwait(false); return result; @@ -349,7 +351,7 @@ private CancellationTokenRegistration RegisterCancellation(CancellationToken can { Method = NotificationMethods.CancelledNotification, Params = JsonSerializer.SerializeToNode(new CancelledNotificationParams { RequestId = state.Item2.Id }, McpJsonUtilities.JsonContext.Default.CancelledNotificationParams), - RelatedTransport = state.Item2.RelatedTransport, + Context = new JsonRpcMessageContext { RelatedTransport = state.Item2.Context?.RelatedTransport }, }); }, Tuple.Create(this, request)); } @@ -527,7 +529,7 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can // Streamable HTTP transport where the specification states that the server SHOULD include JSON-RPC responses in // the HTTP response body for the POST request containing the corresponding JSON-RPC request. private Task SendToRelatedTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken) - => (message.RelatedTransport ?? _transport).SendMessageAsync(message, cancellationToken); + => (message.Context?.RelatedTransport ?? _transport).SendMessageAsync(message, cancellationToken); private static CancelledNotificationParams? GetCancelledNotificationParams(JsonNode? notificationParams) { diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs index b3176937..ae15453d 100644 --- a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs @@ -1,5 +1,6 @@ using ModelContextProtocol.Server; using System.ComponentModel; +using System.Security.Claims; using System.Text.Json; using System.Text.Json.Serialization; @@ -29,28 +30,21 @@ private protected JsonRpcMessage() public string JsonRpc { get; init; } = "2.0"; /// - /// Gets or sets the transport the was received on or should be sent over. + /// Gets or sets the contextual information for this JSON-RPC message. /// /// - /// This is used to support the Streamable HTTP transport where the specification states that the server - /// SHOULD include JSON-RPC responses in the HTTP response body for the POST request containing - /// the corresponding JSON-RPC request. It may be for other transports. + /// This property contains transport-specific and runtime context information that accompanies + /// JSON-RPC messages but is not serialized as part of the JSON-RPC payload. This includes + /// transport references, execution context, and authenticated user information. /// - [JsonIgnore] - public ITransport? RelatedTransport { get; set; } - - /// - /// Gets or sets the that should be used to run any handlers - /// /// - /// This is used to support the Streamable HTTP transport in its default stateful mode. In this mode, - /// the outlives the initial HTTP request context it was created on, and new - /// JSON-RPC messages can originate from future HTTP requests. This allows the transport to flow the - /// context with the JSON-RPC message. This is particularly useful for enabling IHttpContextAccessor - /// in tool calls. + /// This property should only be set when implementing a custom + /// that needs to pass additional per-message context or to pass a + /// to + /// or . /// [JsonIgnore] - public ExecutionContext? ExecutionContext { get; set; } + public JsonRpcMessageContext? Context { get; set; } /// /// Provides a for messages, diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs new file mode 100644 index 00000000..30b6745a --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs @@ -0,0 +1,61 @@ +using ModelContextProtocol.Server; +using System.Security.Claims; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// Contains contextual information for JSON-RPC messages that is not part of the JSON-RPC protocol specification. +/// +/// +/// This class holds transport-specific and runtime context information that accompanies JSON-RPC messages +/// but is not serialized as part of the JSON-RPC payload. This includes transport references, execution context, +/// and authenticated user information. +/// +public class JsonRpcMessageContext +{ + /// + /// Gets or sets the transport the was received on or should be sent over. + /// + /// + /// This is used to support the Streamable HTTP transport where the specification states that the server + /// SHOULD include JSON-RPC responses in the HTTP response body for the POST request containing + /// the corresponding JSON-RPC request. It may be for other transports. + /// + public ITransport? RelatedTransport { get; set; } + + /// + /// Gets or sets the that should be used to run any handlers + /// + /// + /// This is used to support the Streamable HTTP transport in its default stateful mode. In this mode, + /// the outlives the initial HTTP request context it was created on, and new + /// JSON-RPC messages can originate from future HTTP requests. This allows the transport to flow the + /// context with the JSON-RPC message. This is particularly useful for enabling IHttpContextAccessor + /// in tool calls. + /// + public ExecutionContext? ExecutionContext { get; set; } + + /// + /// Gets or sets the authenticated user associated with this JSON-RPC message. + /// + /// + /// + /// This property contains the representing the authenticated user + /// who initiated this JSON-RPC message. This enables request handlers to access user identity + /// and authorization information without requiring dependency on HTTP context accessors + /// or other HTTP-specific abstractions. + /// + /// + /// The user information is automatically populated by the transport layer when processing + /// incoming HTTP requests in ASP.NET Core scenarios. For other transport types or scenarios + /// where user authentication is not applicable, this property may be . + /// + /// + /// This property is particularly useful in the Streamable HTTP transport where JSON-RPC messages + /// may outlive the original HTTP request context, allowing user identity to be preserved + /// throughout the message processing pipeline. + /// + /// + public ClaimsPrincipal? User { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcRequest.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcRequest.cs index ed6c8982..e80b25f4 100644 --- a/src/ModelContextProtocol.Core/Protocol/JsonRpcRequest.cs +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcRequest.cs @@ -9,7 +9,7 @@ namespace ModelContextProtocol.Protocol; /// /// Requests are messages that require a response from the receiver. Each request includes a unique ID /// that will be included in the corresponding response message (either a success response or an error). -/// +/// /// The receiver of a request message is expected to execute the specified method with the provided parameters /// and return either a with the result, or a /// if the method execution fails. @@ -36,7 +36,7 @@ internal JsonRpcRequest WithId(RequestId id) Id = id, Method = Method, Params = Params, - RelatedTransport = RelatedTransport, + Context = Context, }; } } diff --git a/src/ModelContextProtocol.Core/Protocol/Prompt.cs b/src/ModelContextProtocol.Core/Protocol/Prompt.cs index 1a500406..fcd3053f 100644 --- a/src/ModelContextProtocol.Core/Protocol/Prompt.cs +++ b/src/ModelContextProtocol.Core/Protocol/Prompt.cs @@ -1,5 +1,6 @@ using System.Text.Json.Nodes; using System.Text.Json.Serialization; +using ModelContextProtocol.Server; namespace ModelContextProtocol.Protocol; @@ -59,4 +60,10 @@ public sealed class Prompt : IBaseMetadata /// [JsonPropertyName("_meta")] public JsonObject? Meta { get; set; } + + /// + /// Gets or sets the callable server prompt corresponding to this metadata if any. + /// + [JsonIgnore] + public McpServerPrompt? McpServerPrompt { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/Resource.cs b/src/ModelContextProtocol.Core/Protocol/Resource.cs index 63dce7fd..1b8a0e9c 100644 --- a/src/ModelContextProtocol.Core/Protocol/Resource.cs +++ b/src/ModelContextProtocol.Core/Protocol/Resource.cs @@ -1,5 +1,6 @@ using System.Text.Json.Nodes; using System.Text.Json.Serialization; +using ModelContextProtocol.Server; namespace ModelContextProtocol.Protocol; @@ -87,4 +88,10 @@ public sealed class Resource : IBaseMetadata /// [JsonPropertyName("_meta")] public JsonObject? Meta { get; init; } + + /// + /// Gets or sets the callable server resource corresponding to this metadata if any. + /// + [JsonIgnore] + public McpServerResource? McpServerResource { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/ResourceTemplate.cs b/src/ModelContextProtocol.Core/Protocol/ResourceTemplate.cs index d2959d18..f0f29498 100644 --- a/src/ModelContextProtocol.Core/Protocol/ResourceTemplate.cs +++ b/src/ModelContextProtocol.Core/Protocol/ResourceTemplate.cs @@ -1,5 +1,6 @@ using System.Text.Json.Nodes; using System.Text.Json.Serialization; +using ModelContextProtocol.Server; namespace ModelContextProtocol.Protocol; @@ -84,6 +85,12 @@ public sealed class ResourceTemplate : IBaseMetadata [JsonIgnore] public bool IsTemplated => UriTemplate.Contains('{'); + /// + /// Gets or sets the callable server resource corresponding to this metadata if any. + /// + [JsonIgnore] + public McpServerResource? McpServerResource { get; set; } + /// Converts the into a . /// A if is ; otherwise, . public Resource? AsResource() @@ -102,6 +109,7 @@ public sealed class ResourceTemplate : IBaseMetadata MimeType = MimeType, Annotations = Annotations, Meta = Meta, + McpServerResource = McpServerResource, }; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Protocol/Tool.cs b/src/ModelContextProtocol.Core/Protocol/Tool.cs index c09598ca..1c471669 100644 --- a/src/ModelContextProtocol.Core/Protocol/Tool.cs +++ b/src/ModelContextProtocol.Core/Protocol/Tool.cs @@ -1,6 +1,7 @@ using System.Text.Json; using System.Text.Json.Nodes; using System.Text.Json.Serialization; +using ModelContextProtocol.Server; namespace ModelContextProtocol.Protocol; @@ -43,7 +44,7 @@ public sealed class Tool : IBaseMetadata /// if an invalid schema is provided. /// /// - /// The schema typically defines the properties (parameters) that the tool accepts, + /// The schema typically defines the properties (parameters) that the tool accepts, /// their types, and which ones are required. This helps AI models understand /// how to structure their calls to the tool. /// @@ -52,9 +53,9 @@ public sealed class Tool : IBaseMetadata /// /// [JsonPropertyName("inputSchema")] - public JsonElement InputSchema - { - get => field; + public JsonElement InputSchema + { + get => field; set { if (!McpJsonUtilities.IsValidMcpToolSchema(value)) @@ -114,4 +115,10 @@ public JsonElement? OutputSchema /// [JsonPropertyName("_meta")] public JsonObject? Meta { get; set; } + + /// + /// Gets or sets the callable server tool corresponding to this metadata if any. + /// + [JsonIgnore] + public McpServerTool? McpServerTool { get; set; } } diff --git a/src/ModelContextProtocol.Core/RequestHandlers.cs b/src/ModelContextProtocol.Core/RequestHandlers.cs index 854a4bdd..0c2b54fa 100644 --- a/src/ModelContextProtocol.Core/RequestHandlers.cs +++ b/src/ModelContextProtocol.Core/RequestHandlers.cs @@ -23,13 +23,13 @@ internal sealed class RequestHandlers : Dictionary /// - /// The handler function receives the deserialized request object and a cancellation token, and should return - /// a response object that will be serialized back to the client. + /// The handler function receives the deserialized request object, the full JSON-RPC request, and a cancellation token, + /// and should return a response object that will be serialized back to the client. /// /// public void Set( string method, - Func> handler, + Func> handler, JsonTypeInfo requestTypeInfo, JsonTypeInfo responseTypeInfo) { @@ -41,7 +41,7 @@ public void Set( this[method] = async (request, cancellationToken) => { TRequest? typedRequest = JsonSerializer.Deserialize(request.Params, requestTypeInfo); - object? result = await handler(typedRequest, request.RelatedTransport, cancellationToken).ConfigureAwait(false); + object? result = await handler(typedRequest, request, cancellationToken).ConfigureAwait(false); return JsonSerializer.SerializeToNode(result, responseTypeInfo); }; } diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs index d651d7ee..ef068c55 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs @@ -11,6 +11,7 @@ namespace ModelContextProtocol.Server; /// Provides an that's implemented via an . internal sealed class AIFunctionMcpServerPrompt : McpServerPrompt { + private readonly IReadOnlyList _metadata; /// /// Creates an instance for a method, specified via a instance. /// @@ -136,7 +137,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( Arguments = args, }; - return new AIFunctionMcpServerPrompt(function, prompt); + return new AIFunctionMcpServerPrompt(function, prompt, options?.Metadata ?? []); } private static McpServerPromptCreateOptions DeriveOptions(MethodInfo method, McpServerPromptCreateOptions? options) @@ -154,6 +155,9 @@ private static McpServerPromptCreateOptions DeriveOptions(MethodInfo method, Mcp newOptions.Description ??= descAttr.Description; } + // Set metadata if not already provided + newOptions.Metadata ??= AIFunctionMcpServerTool.CreateMetadata(method); + return newOptions; } @@ -161,15 +165,20 @@ private static McpServerPromptCreateOptions DeriveOptions(MethodInfo method, Mcp internal AIFunction AIFunction { get; } /// Initializes a new instance of the class. - private AIFunctionMcpServerPrompt(AIFunction function, Prompt prompt) + private AIFunctionMcpServerPrompt(AIFunction function, Prompt prompt, IReadOnlyList metadata) { AIFunction = function; ProtocolPrompt = prompt; + ProtocolPrompt.McpServerPrompt = this; + _metadata = metadata; } /// public override Prompt ProtocolPrompt { get; } + /// + public override IReadOnlyList Metadata => _metadata; + /// public override async ValueTask GetAsync( RequestContext request, CancellationToken cancellationToken = default) @@ -177,7 +186,7 @@ public override async ValueTask GetAsync( Throw.IfNull(request); cancellationToken.ThrowIfCancellationRequested(); - request.Services = new RequestServiceProvider(request, request.Services); + request.Services = new RequestServiceProvider(request); AIFunctionArguments arguments = new() { Services = request.Services }; if (request.Params?.Arguments is { } argDict) diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs index a8b0d248..69b8deb8 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using System.Globalization; using System.Reflection; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; using System.Text.RegularExpressions; @@ -17,6 +18,7 @@ internal sealed class AIFunctionMcpServerResource : McpServerResource { private readonly Regex? _uriParser; private readonly string[] _templateVariableNames = []; + private readonly IReadOnlyList _metadata; /// /// Creates an instance for a method, specified via a instance. @@ -218,7 +220,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( MimeType = options?.MimeType ?? "application/octet-stream", }; - return new AIFunctionMcpServerResource(function, resource); + return new AIFunctionMcpServerResource(function, resource, options?.Metadata ?? []); } private static McpServerResourceCreateOptions DeriveOptions(MemberInfo member, McpServerResourceCreateOptions? options) @@ -238,6 +240,12 @@ private static McpServerResourceCreateOptions DeriveOptions(MemberInfo member, M newOptions.Description ??= descAttr.Description; } + // Set metadata if not already provided and the member is a MethodInfo + if (member is MethodInfo method) + { + newOptions.Metadata ??= AIFunctionMcpServerTool.CreateMetadata(method); + } + return newOptions; } @@ -270,11 +278,13 @@ private static string DeriveUriTemplate(string name, AIFunction function) internal AIFunction AIFunction { get; } /// Initializes a new instance of the class. - private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resourceTemplate) + private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resourceTemplate, IReadOnlyList metadata) { AIFunction = function; ProtocolResourceTemplate = resourceTemplate; + ProtocolResourceTemplate.McpServerResource = this; ProtocolResource = resourceTemplate.AsResource(); + _metadata = metadata; if (ProtocolResource is null) { @@ -289,6 +299,9 @@ private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resour /// public override Resource? ProtocolResource { get; } + /// + public override IReadOnlyList Metadata => _metadata; + /// public override async ValueTask ReadAsync( RequestContext request, CancellationToken cancellationToken = default) @@ -316,7 +329,7 @@ private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resour } // Build up the arguments for the AIFunction call, including all of the name/value pairs from the URI. - request.Services = new RequestServiceProvider(request, request.Services); + request.Services = new RequestServiceProvider(request); AIFunctionArguments arguments = new() { Services = request.Services }; // For templates, populate the arguments from the URI template. diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs index 664ede5a..cb475848 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs @@ -1,7 +1,5 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Protocol; using System.ComponentModel; using System.Diagnostics; @@ -15,8 +13,8 @@ namespace ModelContextProtocol.Server; /// Provides an that's implemented via an . internal sealed partial class AIFunctionMcpServerTool : McpServerTool { - private readonly ILogger _logger; private readonly bool _structuredOutputRequiresWrapping; + private readonly IReadOnlyList _metadata; /// /// Creates an instance for a method, specified via a instance. @@ -26,7 +24,7 @@ internal sealed partial class AIFunctionMcpServerTool : McpServerTool McpServerToolCreateOptions? options) { Throw.IfNull(method); - + options = DeriveOptions(method.Method, options); return Create(method.Method, method.Target, options); @@ -146,7 +144,7 @@ options.OpenWorld is not null || } } - return new AIFunctionMcpServerTool(function, tool, options?.Services, structuredOutputRequiresWrapping); + return new AIFunctionMcpServerTool(function, tool, options?.Services, structuredOutputRequiresWrapping, options?.Metadata ?? []); } private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpServerToolCreateOptions? options) @@ -186,6 +184,9 @@ private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpSe newOptions.Description ??= descAttr.Description; } + // Set metadata if not already provided + newOptions.Metadata ??= CreateMetadata(method); + return newOptions; } @@ -193,17 +194,22 @@ private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpSe internal AIFunction AIFunction { get; } /// Initializes a new instance of the class. - private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider? serviceProvider, bool structuredOutputRequiresWrapping) + private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider? serviceProvider, bool structuredOutputRequiresWrapping, IReadOnlyList metadata) { AIFunction = function; ProtocolTool = tool; - _logger = serviceProvider?.GetService()?.CreateLogger() ?? (ILogger)NullLogger.Instance; + ProtocolTool.McpServerTool = this; + _structuredOutputRequiresWrapping = structuredOutputRequiresWrapping; + _metadata = metadata; } /// public override Tool ProtocolTool { get; } + /// + public override IReadOnlyList Metadata => _metadata; + /// public override async ValueTask InvokeAsync( RequestContext request, CancellationToken cancellationToken = default) @@ -211,7 +217,7 @@ public override async ValueTask InvokeAsync( Throw.IfNull(request); cancellationToken.ThrowIfCancellationRequested(); - request.Services = new RequestServiceProvider(request, request.Services); + request.Services = new RequestServiceProvider(request); AIFunctionArguments arguments = new() { Services = request.Services }; if (request.Params?.Arguments is { } argDict) @@ -223,24 +229,7 @@ public override async ValueTask InvokeAsync( } object? result; - try - { - result = await AIFunction.InvokeAsync(arguments, cancellationToken).ConfigureAwait(false); - } - catch (Exception e) when (e is not OperationCanceledException) - { - ToolCallError(request.Params?.Name ?? string.Empty, e); - - string errorMessage = e is McpException ? - $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : - $"An error occurred invoking '{request.Params?.Name}'."; - - return new() - { - IsError = true, - Content = [new TextContentBlock { Text = errorMessage }], - }; - } + result = await AIFunction.InvokeAsync(arguments, cancellationToken).ConfigureAwait(false); JsonNode? structuredContent = CreateStructuredResponse(result); return result switch @@ -257,27 +246,27 @@ public override async ValueTask InvokeAsync( Content = [], StructuredContent = structuredContent, }, - + string text => new() { Content = [new TextContentBlock { Text = text }], StructuredContent = structuredContent, }, - + ContentBlock content => new() { Content = [content], StructuredContent = structuredContent, }, - + IEnumerable contentItems => ConvertAIContentEnumerableToCallToolResult(contentItems, structuredContent), - + IEnumerable contents => new() { Content = [.. contents], StructuredContent = structuredContent, }, - + CallToolResult callToolResponse => callToolResponse, _ => new() @@ -336,6 +325,26 @@ static bool IsAsyncMethod(MethodInfo method) } } + /// Creates metadata from attributes on the specified method and its declaring class, with the MethodInfo as the first item. + internal static IReadOnlyList CreateMetadata(MethodInfo method) + { + // Add the MethodInfo to the start of the metadata similar to what RouteEndpointDataSource does for minimal endpoints. + List metadata = [method]; + + // Add class-level attributes first, since those are less specific. + if (method.DeclaringType is not null) + { + metadata.AddRange(method.DeclaringType.GetCustomAttributes()); + } + + // Add method-level attributes second, since those are more specific. + // When metadata conflicts, later metadata usually takes precedence with exceptions for metadata like + // IAllowAnonymous which always take precedence over IAuthorizeData no matter the order. + metadata.AddRange(method.GetCustomAttributes()); + + return metadata.AsReadOnly(); + } + /// Regex that flags runs of characters other than ASCII digits or letters. #if NET [GeneratedRegex("[^0-9A-Za-z]+")] @@ -446,7 +455,4 @@ private static CallToolResult ConvertAIContentEnumerableToCallToolResult(IEnumer IsError = allErrorContent && hasAny }; } - - [LoggerMessage(Level = LogLevel.Error, Message = "\"{ToolName}\" threw an unhandled exception.")] - private partial void ToolCallError(string toolName, Exception exception); } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs index d286d1ef..78346c39 100644 --- a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs +++ b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs @@ -22,15 +22,25 @@ internal sealed class DestinationBoundMcpServer(McpServer server, ITransport? tr public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { - Debug.Assert(message.RelatedTransport is null); - message.RelatedTransport = transport; + if (message.Context is not null) + { + throw new ArgumentException("Only transports can provide a JsonRpcMessageContext."); + } + + message.Context = new JsonRpcMessageContext(); + message.Context.RelatedTransport = transport; return server.SendMessageAsync(message, cancellationToken); } public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) { - Debug.Assert(request.RelatedTransport is null); - request.RelatedTransport = transport; + if (request.Context is not null) + { + throw new ArgumentException("Only transports can provide a JsonRpcMessageContext."); + } + + request.Context = new JsonRpcMessageContext(); + request.Context.RelatedTransport = transport; return server.SendRequestAsync(request, cancellationToken); } } diff --git a/src/ModelContextProtocol.Core/Server/IMcpServerPrimitive.cs b/src/ModelContextProtocol.Core/Server/IMcpServerPrimitive.cs index 597fdec9..f3ec6221 100644 --- a/src/ModelContextProtocol.Core/Server/IMcpServerPrimitive.cs +++ b/src/ModelContextProtocol.Core/Server/IMcpServerPrimitive.cs @@ -7,4 +7,13 @@ public interface IMcpServerPrimitive { /// Gets the unique identifier of the primitive. string Id { get; } + + /// + /// Gets the metadata for this primitive instance. + /// + /// + /// Contains attributes from the associated MethodInfo and declaring class (if any), + /// with class-level attributes appearing before method-level attributes. + /// + IReadOnlyList Metadata { get; } } diff --git a/src/ModelContextProtocol.Core/Server/McpServer.cs b/src/ModelContextProtocol.Core/Server/McpServer.cs index 6c5858f9..0056b1ae 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.cs @@ -7,7 +7,7 @@ namespace ModelContextProtocol.Server; /// -internal sealed class McpServer : McpEndpoint, IMcpServer +internal sealed partial class McpServer : McpEndpoint, IMcpServer { internal static Implementation DefaultImplementation { get; } = new() { @@ -195,9 +195,12 @@ private void ConfigureCompletion(McpServerOptions options) return; } + var completeHandler = completionsCapability.CompleteHandler ?? (static async (_, __) => new CompleteResult()); + completeHandler = BuildFilterPipeline(completeHandler, options.Filters.CompleteFilters); + ServerCapabilities.Completions = new() { - CompleteHandler = completionsCapability.CompleteHandler ?? (static async (_, __) => new CompleteResult()) + CompleteHandler = completeHandler }; SetHandler( @@ -279,30 +282,14 @@ await originalListResourceTemplatesHandler(request, cancellationToken).Configure var originalReadResourceHandler = readResourceHandler; readResourceHandler = async (request, cancellationToken) => { - if (request.Params?.Uri is string uri) + if (request.MatchedPrimitive is McpServerResource matchedResource) { - // First try an O(1) lookup by exact match. - if (resources.TryGetPrimitive(uri, out var resource)) + if (await matchedResource.ReadAsync(request, cancellationToken).ConfigureAwait(false) is { } result) { - if (await resource.ReadAsync(request, cancellationToken).ConfigureAwait(false) is { } result) - { - return result; - } - } - - // Fall back to an O(N) lookup, trying to match against each URI template. - // The number of templates is controlled by the server developer, and the number is expected to be - // not terribly large. If that changes, this can be tweaked to enable a more efficient lookup. - foreach (var resourceTemplate in resources) - { - if (await resourceTemplate.ReadAsync(request, cancellationToken).ConfigureAwait(false) is { } result) - { - return result; - } + return result; } } - // Finally fall back to the handler. return await originalReadResourceHandler(request, cancellationToken).ConfigureAwait(false); }; @@ -312,6 +299,43 @@ await originalListResourceTemplatesHandler(request, cancellationToken).Configure // subscribe = true; } + listResourcesHandler = BuildFilterPipeline(listResourcesHandler, options.Filters.ListResourcesFilters); + listResourceTemplatesHandler = BuildFilterPipeline(listResourceTemplatesHandler, options.Filters.ListResourceTemplatesFilters); + readResourceHandler = BuildFilterPipeline(readResourceHandler, options.Filters.ReadResourceFilters, handler => + async (request, cancellationToken) => + { + // Initial handler that sets MatchedPrimitive + if (request.Params?.Uri is { } uri && resources is not null) + { + // First try an O(1) lookup by exact match. + if (resources.TryGetPrimitive(uri, out var resource)) + { + request.MatchedPrimitive = resource; + } + else + { + // Fall back to an O(N) lookup, trying to match against each URI template. + // The number of templates is controlled by the server developer, and the number is expected to be + // not terribly large. If that changes, this can be tweaked to enable a more efficient lookup. + foreach (var resourceTemplate in resources) + { + // Check if this template would handle the request by testing if ReadAsync would succeed + if (resourceTemplate.IsTemplated) + { + // This is a simplified check - a more robust implementation would match the URI pattern + // For now, we'll let the actual handler attempt the match + request.MatchedPrimitive = resourceTemplate; + break; + } + } + } + } + + return await handler(request, cancellationToken).ConfigureAwait(false); + }); + subscribeHandler = BuildFilterPipeline(subscribeHandler, options.Filters.SubscribeToResourcesFilters); + unsubscribeHandler = BuildFilterPipeline(unsubscribeHandler, options.Filters.UnsubscribeFromResourcesFilters); + ServerCapabilities.Resources.ListResourcesHandler = listResourcesHandler; ServerCapabilities.Resources.ListResourceTemplatesHandler = listResourceTemplatesHandler; ServerCapabilities.Resources.ReadResourceHandler = readResourceHandler; @@ -390,8 +414,7 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals var originalGetPromptHandler = getPromptHandler; getPromptHandler = (request, cancellationToken) => { - if (request.Params is not null && - prompts.TryGetPrimitive(request.Params.Name, out var prompt)) + if (request.MatchedPrimitive is McpServerPrompt prompt) { return prompt.GetAsync(request, cancellationToken); } @@ -402,6 +425,20 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals listChanged = true; } + listPromptsHandler = BuildFilterPipeline(listPromptsHandler, options.Filters.ListPromptsFilters); + getPromptHandler = BuildFilterPipeline(getPromptHandler, options.Filters.GetPromptFilters, handler => + (request, cancellationToken) => + { + // Initial handler that sets MatchedPrimitive + if (request.Params?.Name is { } promptName && prompts is not null && + prompts.TryGetPrimitive(promptName, out var prompt)) + { + request.MatchedPrimitive = prompt; + } + + return handler(request, cancellationToken); + }); + ServerCapabilities.Prompts.ListPromptsHandler = listPromptsHandler; ServerCapabilities.Prompts.GetPromptHandler = getPromptHandler; ServerCapabilities.Prompts.PromptCollection = prompts; @@ -458,8 +495,7 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) var originalCallToolHandler = callToolHandler; callToolHandler = (request, cancellationToken) => { - if (request.Params is not null && - tools.TryGetPrimitive(request.Params.Name, out var tool)) + if (request.MatchedPrimitive is McpServerTool tool) { return tool.InvokeAsync(request, cancellationToken); } @@ -470,6 +506,51 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) listChanged = true; } + listToolsHandler = BuildFilterPipeline(listToolsHandler, options.Filters.ListToolsFilters); + callToolHandler = BuildFilterPipeline(callToolHandler, options.Filters.CallToolFilters, handler => + (request, cancellationToken) => + { + // Initial handler that sets MatchedPrimitive + if (request.Params?.Name is { } toolName && tools is not null && + tools.TryGetPrimitive(toolName, out var tool)) + { + request.MatchedPrimitive = tool; + } + + return handler(request, cancellationToken); + }, handler => + async (request, cancellationToken) => + { + // Final handler that provides exception handling only for tool execution + // Only wrap tool execution in try-catch, not tool resolution + if (request.MatchedPrimitive is McpServerTool) + { + try + { + return await handler(request, cancellationToken).ConfigureAwait(false); + } + catch (Exception e) when (e is not OperationCanceledException) + { + ToolCallError(request.Params?.Name ?? string.Empty, e); + + string errorMessage = e is McpException ? + $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : + $"An error occurred invoking '{request.Params?.Name}'."; + + return new() + { + IsError = true, + Content = [new TextContentBlock { Text = errorMessage }], + }; + } + } + else + { + // For unmatched tools, let exceptions bubble up as protocol errors + return await handler(request, cancellationToken).ConfigureAwait(false); + } + }); + ServerCapabilities.Tools.ListToolsHandler = listToolsHandler; ServerCapabilities.Tools.CallToolHandler = callToolHandler; ServerCapabilities.Tools.ToolCollection = tools; @@ -493,12 +574,18 @@ private void ConfigureLogging(McpServerOptions options) // We don't require that the handler be provided, as we always store the provided log level to the server. var setLoggingLevelHandler = options.Capabilities?.Logging?.SetLoggingLevelHandler; + // Apply filters to the handler + if (setLoggingLevelHandler is not null) + { + setLoggingLevelHandler = BuildFilterPipeline(setLoggingLevelHandler, options.Filters.SetLoggingLevelFilters); + } + ServerCapabilities.Logging = new(); ServerCapabilities.Logging.SetLoggingLevelHandler = setLoggingLevelHandler; RequestHandlers.Set( RequestMethods.LoggingSetLevel, - (request, destinationTransport, cancellationToken) => + (request, jsonRpcRequest, cancellationToken) => { // Store the provided level. if (request is not null) @@ -514,7 +601,7 @@ private void ConfigureLogging(McpServerOptions options) // If a handler was provided, now delegate to it. if (setLoggingLevelHandler is not null) { - return InvokeHandlerAsync(setLoggingLevelHandler, request, destinationTransport, cancellationToken); + return InvokeHandlerAsync(setLoggingLevelHandler, request, jsonRpcRequest, cancellationToken); } // Otherwise, consider it handled. @@ -527,23 +614,24 @@ private void ConfigureLogging(McpServerOptions options) private ValueTask InvokeHandlerAsync( Func, CancellationToken, ValueTask> handler, TParams? args, - ITransport? destinationTransport = null, + JsonRpcRequest jsonRpcRequest, CancellationToken cancellationToken = default) { return _servicesScopePerRequest ? - InvokeScopedAsync(handler, args, cancellationToken) : - handler(new(new DestinationBoundMcpServer(this, destinationTransport)) { Params = args }, cancellationToken); + InvokeScopedAsync(handler, args, jsonRpcRequest, cancellationToken) : + handler(new(new DestinationBoundMcpServer(this, jsonRpcRequest.Context?.RelatedTransport), jsonRpcRequest) { Params = args }, cancellationToken); async ValueTask InvokeScopedAsync( Func, CancellationToken, ValueTask> handler, TParams? args, + JsonRpcRequest jsonRpcRequest, CancellationToken cancellationToken) { var scope = Services?.GetService()?.CreateAsyncScope(); try { return await handler( - new RequestContext(new DestinationBoundMcpServer(this, destinationTransport)) + new RequestContext(new DestinationBoundMcpServer(this, jsonRpcRequest.Context?.RelatedTransport), jsonRpcRequest) { Services = scope?.ServiceProvider ?? Services, Params = args @@ -566,12 +654,33 @@ private void SetHandler( JsonTypeInfo requestTypeInfo, JsonTypeInfo responseTypeInfo) { - RequestHandlers.Set(method, - (request, destinationTransport, cancellationToken) => - InvokeHandlerAsync(handler, request, destinationTransport, cancellationToken), + RequestHandlers.Set(method, + (request, jsonRpcRequest, cancellationToken) => + InvokeHandlerAsync(handler, request, jsonRpcRequest, cancellationToken), requestTypeInfo, responseTypeInfo); } + private static THandler BuildFilterPipeline( + THandler baseHandler, List> filters, + Func? initialHandler = null, + Func? finalHandler = null) + { + THandler current = baseHandler; + if (finalHandler is not null) + { + current = finalHandler(current); + } + for (int i = filters.Count - 1; i >= 0; i--) + { + current = filters[i](current); + } + if (initialHandler is not null) + { + current = initialHandler(current); + } + return current; + } + private void UpdateEndpointNameWithClientInfo() { if (ClientInfo is null) @@ -594,4 +703,7 @@ internal static LoggingLevel ToLoggingLevel(LogLevel level) => LogLevel.Critical => Protocol.LoggingLevel.Critical, _ => Protocol.LoggingLevel.Emergency, }; + + [LoggerMessage(Level = LogLevel.Error, Message = "\"{ToolName}\" threw an unhandled exception.")] + private partial void ToolCallError(string toolName, Exception exception); } diff --git a/src/ModelContextProtocol.Core/Server/McpServerFilters.cs b/src/ModelContextProtocol.Core/Server/McpServerFilters.cs new file mode 100644 index 00000000..d15154dd --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpServerFilters.cs @@ -0,0 +1,161 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server; + +/// +/// Provides filter collections for MCP server handlers. +/// +/// +/// This class contains collections of filters that can be applied to various MCP server handlers. +/// This allows for middleware-style composition where filters can perform actions before and after the inner handler. +/// +public sealed class McpServerFilters +{ + /// + /// Gets the filters for the list tools handler pipeline. + /// + /// + /// + /// These filters wrap handlers that return a list of available tools when requested by a client. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more tools. + /// + /// + /// These filters work alongside any tools defined in the collection. + /// Tools from both sources will be combined when returning results to clients. + /// + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> ListToolsFilters { get; } = new(); + + /// + /// Gets the filters for the call tool handler pipeline. + /// + /// + /// These filters wrap handlers that are invoked when a client makes a call to a tool that isn't found in the collection. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to execute the requested tool and return appropriate results. + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> CallToolFilters { get; } = new(); + + /// + /// Gets the filters for the list prompts handler pipeline. + /// + /// + /// + /// These filters wrap handlers that return a list of available prompts when requested by a client. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more prompts. + /// + /// + /// These filters work alongside any prompts defined in the collection. + /// Prompts from both sources will be combined when returning results to clients. + /// + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> ListPromptsFilters { get; } = new(); + + /// + /// Gets the filters for the get prompt handler pipeline. + /// + /// + /// These filters wrap handlers that are invoked when a client requests details for a specific prompt that isn't found in the collection. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to fetch or generate the requested prompt and return appropriate results. + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> GetPromptFilters { get; } = new(); + + /// + /// Gets the filters for the list resource templates handler pipeline. + /// + /// + /// These filters wrap handlers that return a list of available resource templates when requested by a client. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more resource templates. + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> ListResourceTemplatesFilters { get; } = new(); + + /// + /// Gets the filters for the list resources handler pipeline. + /// + /// + /// These filters wrap handlers that return a list of available resources when requested by a client. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more resources. + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> ListResourcesFilters { get; } = new(); + + /// + /// Gets the filters for the read resource handler pipeline. + /// + /// + /// These filters wrap handlers that are invoked when a client requests the content of a specific resource identified by its URI. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to locate and retrieve the requested resource. + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> ReadResourceFilters { get; } = new(); + + /// + /// Gets the filters for the complete handler pipeline. + /// + /// + /// These filters wrap handlers that provide auto-completion suggestions for prompt arguments or resource references in the Model Context Protocol. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler processes auto-completion requests, returning a list of suggestions based on the + /// reference type and current argument value. + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> CompleteFilters { get; } = new(); + + /// + /// Gets the filters for the subscribe to resources handler pipeline. + /// + /// + /// + /// These filters wrap handlers that are invoked when a client wants to receive notifications about changes to specific resources or resource patterns. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to register the client's interest in the specified resources + /// and set up the necessary infrastructure to send notifications when those resources change. + /// + /// + /// After a successful subscription, the server should send resource change notifications to the client + /// whenever a relevant resource is created, updated, or deleted. + /// + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> SubscribeToResourcesFilters { get; } = new(); + + /// + /// Gets the filters for the unsubscribe from resources handler pipeline. + /// + /// + /// + /// These filters wrap handlers that are invoked when a client wants to stop receiving notifications about previously subscribed resources. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to remove the client's subscriptions to the specified resources + /// and clean up any associated resources. + /// + /// + /// After a successful unsubscription, the server should no longer send resource change notifications + /// to the client for the specified resources. + /// + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> UnsubscribeFromResourcesFilters { get; } = new(); + + /// + /// Gets the filters for the set logging level handler pipeline. + /// + /// + /// + /// These filters wrap handlers that process requests from clients. When set, it enables + /// clients to control which log messages they receive by specifying a minimum severity threshold. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. + /// + /// + /// After handling a level change request, the server typically begins sending log messages + /// at or above the specified level to the client as notifications/message notifications. + /// + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> SetLoggingLevelFilters { get; } = new(); +} diff --git a/src/ModelContextProtocol.Core/Server/McpServerOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerOptions.cs index 8c50a9b5..1c981b77 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerOptions.cs @@ -79,4 +79,14 @@ public sealed class McpServerOptions /// /// public Implementation? KnownClientInfo { get; set; } + + /// + /// Gets the filter collections for MCP server handlers. + /// + /// + /// This property provides access to filter collections that can be used to modify the behavior + /// of various MCP server handlers. Filters are applied in reverse order, so the last filter + /// added will be the outermost (first to execute). + /// + public McpServerFilters Filters { get; } = new(); } diff --git a/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs b/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs index 68874df3..74627879 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs @@ -20,7 +20,7 @@ namespace ModelContextProtocol.Server; /// /// /// Most commonly, instances are created using the static methods. -/// These methods enable creating an for a method, specified via a or +/// These methods enable creating an for a method, specified via a or /// , and are what are used implicitly by WithPromptsFromAssembly and WithPrompts. The methods /// create instances capable of working with a large variety of .NET method signatures, automatically handling /// how parameters are marshaled into the method from the JSON received from the MCP client, and how the return value is marshaled back @@ -61,15 +61,15 @@ namespace ModelContextProtocol.Server; /// /// /// -/// When the is constructed, it may be passed an via +/// When the is constructed, it may be passed an via /// . Any parameter that can be satisfied by that -/// according to will be resolved from the provided to +/// according to will be resolved from the provided to /// rather than from the argument collection. /// /// /// /// -/// Any parameter attributed with will similarly be resolved from the +/// Any parameter attributed with will similarly be resolved from the /// provided to rather than from the argument collection. /// /// @@ -80,7 +80,7 @@ namespace ModelContextProtocol.Server; /// /// /// In general, the data supplied via the 's dictionary is passed along from the caller and -/// should thus be considered unvalidated and untrusted. To provide validated and trusted data to the invocation of the prompt, consider having +/// should thus be considered unvalidated and untrusted. To provide validated and trusted data to the invocation of the prompt, consider having /// the prompt be an instance method, referring to data stored in the instance, or using an instance or parameters resolved from the /// to provide data to the method. /// @@ -128,6 +128,15 @@ protected McpServerPrompt() /// public abstract Prompt ProtocolPrompt { get; } + /// + /// Gets the metadata for this prompt instance. + /// + /// + /// Contains attributes from the associated MethodInfo and declaring class (if any), + /// with class-level attributes appearing before method-level attributes. + /// + public abstract IReadOnlyList Metadata { get; } + /// /// Gets the prompt, rendering it with the provided request parameters and returning the prompt result. /// @@ -170,7 +179,7 @@ public static McpServerPrompt Create( /// is . /// is an instance method but is . public static McpServerPrompt Create( - MethodInfo method, + MethodInfo method, object? target = null, McpServerPromptCreateOptions? options = null) => AIFunctionMcpServerPrompt.Create(method, target, options); diff --git a/src/ModelContextProtocol.Core/Server/McpServerPromptCreateOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerPromptCreateOptions.cs index 95d712ff..1853b0f1 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerPromptCreateOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerPromptCreateOptions.cs @@ -68,6 +68,15 @@ public sealed class McpServerPromptCreateOptions /// public AIJsonSchemaCreateOptions? SchemaCreateOptions { get; set; } + /// + /// Gets or sets the metadata associated with the prompt. + /// + /// + /// Metadata includes information such as the attributes extracted from the method and its declaring class. + /// If not provided, metadata will be automatically generated for methods created via reflection. + /// + public IReadOnlyList? Metadata { get; set; } + /// /// Creates a shallow clone of the current instance. /// @@ -80,5 +89,6 @@ internal McpServerPromptCreateOptions Clone() => Description = Description, SerializerOptions = SerializerOptions, SchemaCreateOptions = SchemaCreateOptions, + Metadata = Metadata, }; } diff --git a/src/ModelContextProtocol.Core/Server/McpServerResource.cs b/src/ModelContextProtocol.Core/Server/McpServerResource.cs index 8e42d3e1..9508cda0 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerResource.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerResource.cs @@ -11,13 +11,13 @@ namespace ModelContextProtocol.Server; /// /// /// is an abstract base class that represents an MCP resource for use in the server (as opposed -/// to or , which provide the protocol representations of a resource). Instances of +/// to or , which provide the protocol representations of a resource). Instances of /// can be added into a to be picked up automatically when /// is used to create an , or added into a . /// /// /// Most commonly, instances are created using the static methods. -/// These methods enable creating an for a method, specified via a or +/// These methods enable creating an for a method, specified via a or /// , and are what are used implicitly by WithResourcesFromAssembly and /// . The methods /// create instances capable of working with a large variety of .NET method signatures, automatically handling @@ -62,15 +62,15 @@ namespace ModelContextProtocol.Server; /// /// /// -/// When the is constructed, it may be passed an via +/// When the is constructed, it may be passed an via /// . Any parameter that can be satisfied by that -/// according to will be resolved from the provided to the +/// according to will be resolved from the provided to the /// resource invocation rather than from the argument collection. /// /// /// /// -/// Any parameter attributed with will similarly be resolved from the +/// Any parameter attributed with will similarly be resolved from the /// provided to the resource invocation rather than from the argument collection. /// /// @@ -149,6 +149,15 @@ protected McpServerResource() /// public virtual Resource? ProtocolResource => ProtocolResourceTemplate.AsResource(); + /// + /// Gets the metadata for this resource instance. + /// + /// + /// Contains attributes from the associated MethodInfo and declaring class (if any), + /// with class-level attributes appearing before method-level attributes. + /// + public abstract IReadOnlyList Metadata { get; } + /// /// Gets the resource, rendering it with the provided request parameters and returning the resource result. /// @@ -192,7 +201,7 @@ public static McpServerResource Create( /// is . /// is an instance method but is . public static McpServerResource Create( - MethodInfo method, + MethodInfo method, object? target = null, McpServerResourceCreateOptions? options = null) => AIFunctionMcpServerResource.Create(method, target, options); diff --git a/src/ModelContextProtocol.Core/Server/McpServerResourceCreateOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerResourceCreateOptions.cs index 24051a7f..2d6b66b3 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerResourceCreateOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerResourceCreateOptions.cs @@ -83,6 +83,15 @@ public sealed class McpServerResourceCreateOptions /// public AIJsonSchemaCreateOptions? SchemaCreateOptions { get; set; } + /// + /// Gets or sets the metadata associated with the resource. + /// + /// + /// Metadata includes information such as attributes extracted from the method and its declaring class. + /// If not provided, metadata will be automatically generated for methods created via reflection. + /// + public IReadOnlyList? Metadata { get; set; } + /// /// Creates a shallow clone of the current instance. /// @@ -97,5 +106,6 @@ internal McpServerResourceCreateOptions Clone() => MimeType = MimeType, SerializerOptions = SerializerOptions, SchemaCreateOptions = SchemaCreateOptions, + Metadata = Metadata, }; } diff --git a/src/ModelContextProtocol.Core/Server/McpServerTool.cs b/src/ModelContextProtocol.Core/Server/McpServerTool.cs index e3958271..baddf88f 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerTool.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerTool.cs @@ -20,7 +20,7 @@ namespace ModelContextProtocol.Server; /// /// /// Most commonly, instances are created using the static methods. -/// These methods enable creating an for a method, specified via a or +/// These methods enable creating an for a method, specified via a or /// , and are what are used implicitly by WithToolsFromAssembly and WithTools. The methods /// create instances capable of working with a large variety of .NET method signatures, automatically handling /// how parameters are marshaled into the method from the JSON received from the MCP client, and how the return value is marshaled back @@ -56,22 +56,22 @@ namespace ModelContextProtocol.Server; /// /// parameters accepting values /// are not included in the JSON schema and are bound to an instance manufactured -/// to forward progress notifications from the tool to the client. If the client included a in their request, +/// to forward progress notifications from the tool to the client. If the client included a in their request, /// progress reports issued to this instance will propagate to the client as notifications with /// that token. If the client did not include a , the instance will ignore any progress reports issued to it. /// /// /// /// -/// When the is constructed, it may be passed an via +/// When the is constructed, it may be passed an via /// . Any parameter that can be satisfied by that -/// according to will not be included in the generated JSON schema and will be resolved +/// according to will not be included in the generated JSON schema and will be resolved /// from the provided to rather than from the argument collection. /// /// /// /// -/// Any parameter attributed with will similarly be resolved from the +/// Any parameter attributed with will similarly be resolved from the /// provided to rather than from the argument /// collection, and will not be included in the generated JSON schema. /// @@ -79,13 +79,13 @@ namespace ModelContextProtocol.Server; /// /// /// -/// All other parameters are deserialized from the s in the dictionary, -/// using the supplied in , or if none was provided, +/// All other parameters are deserialized from the s in the dictionary, +/// using the supplied in , or if none was provided, /// using . /// /// /// In general, the data supplied via the 's dictionary is passed along from the caller and -/// should thus be considered unvalidated and untrusted. To provide validated and trusted data to the invocation of the tool, consider having +/// should thus be considered unvalidated and untrusted. To provide validated and trusted data to the invocation of the tool, consider having /// the tool be an instance method, referring to data stored in the instance, or using an instance or parameters resolved from the /// to provide data to the method. /// @@ -141,6 +141,15 @@ protected McpServerTool() /// Gets the protocol type for this instance. public abstract Tool ProtocolTool { get; } + /// + /// Gets the metadata for this tool instance. + /// + /// + /// Contains attributes from the associated MethodInfo and declaring class (if any), + /// with class-level attributes appearing before method-level attributes. + /// + public abstract IReadOnlyList Metadata { get; } + /// Invokes the . /// The request information resulting in the invocation of this tool. /// The to monitor for cancellation requests. The default is . @@ -172,7 +181,7 @@ public static McpServerTool Create( /// is . /// is an instance method but is . public static McpServerTool Create( - MethodInfo method, + MethodInfo method, object? target = null, McpServerToolCreateOptions? options = null) => AIFunctionMcpServerTool.Create(method, target, options); diff --git a/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs index bdb4ecb8..d18af8c0 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs @@ -80,7 +80,7 @@ public sealed class McpServerToolCreateOptions public bool? Destructive { get; set; } /// - /// Gets or sets whether calling the tool repeatedly with the same arguments + /// Gets or sets whether calling the tool repeatedly with the same arguments /// will have no additional effect on its environment. /// /// @@ -155,6 +155,15 @@ public sealed class McpServerToolCreateOptions /// public AIJsonSchemaCreateOptions? SchemaCreateOptions { get; set; } + /// + /// Gets or sets the metadata associated with the tool. + /// + /// + /// Metadata includes information such as attributes extracted from the method and its declaring class. + /// If not provided, metadata will be automatically generated for methods created via reflection. + /// + public IReadOnlyList? Metadata { get; set; } + /// /// Creates a shallow clone of the current instance. /// @@ -172,5 +181,6 @@ internal McpServerToolCreateOptions Clone() => UseStructuredContent = UseStructuredContent, SerializerOptions = SerializerOptions, SchemaCreateOptions = SchemaCreateOptions, + Metadata = Metadata, }; } diff --git a/src/ModelContextProtocol.Core/Server/RequestContext.cs b/src/ModelContextProtocol.Core/Server/RequestContext.cs index b0ea9d99..8af9f666 100644 --- a/src/ModelContextProtocol.Core/Server/RequestContext.cs +++ b/src/ModelContextProtocol.Core/Server/RequestContext.cs @@ -1,3 +1,6 @@ +using System.Security.Claims; +using ModelContextProtocol.Protocol; + namespace ModelContextProtocol.Server; /// @@ -15,19 +18,23 @@ public sealed class RequestContext private IMcpServer _server; /// - /// Initializes a new instance of the class with the specified server. + /// Initializes a new instance of the class with the specified server and JSON-RPC request. /// /// The server with which this instance is associated. - public RequestContext(IMcpServer server) + /// The JSON-RPC request associated with this context. + public RequestContext(IMcpServer server, JsonRpcRequest jsonRpcRequest) { Throw.IfNull(server); + Throw.IfNull(jsonRpcRequest); _server = server; + JsonRpcRequest = jsonRpcRequest; Services = server.Services; + User = jsonRpcRequest.Context?.User; } /// Gets or sets the server with which this instance is associated. - public IMcpServer Server + public IMcpServer Server { get => _server; set @@ -46,6 +53,23 @@ public IMcpServer Server /// public IServiceProvider? Services { get; set; } + /// Gets or sets the user associated with this request. + public ClaimsPrincipal? User { get; set; } + /// Gets or sets the parameters associated with this request. public TParams? Params { get; set; } + + /// + /// Gets or sets the primitive that matched the request. + /// + public IMcpServerPrimitive? MatchedPrimitive { get; set; } + + /// + /// Gets the JSON-RPC request associated with this context. + /// + /// + /// This property provides access to the complete JSON-RPC request that initiated this handler invocation, + /// including the method name, parameters, request ID, and associated transport and user information. + /// + public JsonRpcRequest JsonRpcRequest { get; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs b/src/ModelContextProtocol.Core/Server/RequestServiceProvider.cs similarity index 69% rename from src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs rename to src/ModelContextProtocol.Core/Server/RequestServiceProvider.cs index 3372072f..38af614c 100644 --- a/src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs +++ b/src/ModelContextProtocol.Core/Server/RequestServiceProvider.cs @@ -1,16 +1,17 @@ using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol; +using System.Security.Claims; namespace ModelContextProtocol.Server; /// Augments a service provider with additional request-related services. -internal sealed class RequestServiceProvider( - RequestContext request, IServiceProvider? innerServices) : - IServiceProvider, IKeyedServiceProvider, - IServiceProviderIsService, IServiceProviderIsKeyedService, +internal sealed class RequestServiceProvider(RequestContext request) : + IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, IDisposable, IAsyncDisposable where TRequestParams : RequestParams { + private readonly IServiceProvider? _innerServices = request.Services; + /// Gets the request associated with this instance. public RequestContext Request => request; @@ -18,7 +19,8 @@ internal sealed class RequestServiceProvider( public static bool IsAugmentedWith(Type serviceType) => serviceType == typeof(RequestContext) || serviceType == typeof(IMcpServer) || - serviceType == typeof(IProgress); + serviceType == typeof(IProgress) || + serviceType == typeof(ClaimsPrincipal); /// public object? GetService(Type serviceType) => @@ -26,22 +28,23 @@ public static bool IsAugmentedWith(Type serviceType) => serviceType == typeof(IMcpServer) ? request.Server : serviceType == typeof(IProgress) ? (request.Params?.ProgressToken is { } progressToken ? new TokenProgress(request.Server, progressToken) : NullProgress.Instance) : - innerServices?.GetService(serviceType); + serviceType == typeof(ClaimsPrincipal) ? request.User : + _innerServices?.GetService(serviceType); /// public bool IsService(Type serviceType) => IsAugmentedWith(serviceType) || - (innerServices as IServiceProviderIsService)?.IsService(serviceType) is true; + (_innerServices as IServiceProviderIsService)?.IsService(serviceType) is true; /// public bool IsKeyedService(Type serviceType, object? serviceKey) => (serviceKey is null && IsService(serviceType)) || - (innerServices as IServiceProviderIsKeyedService)?.IsKeyedService(serviceType, serviceKey) is true; + (_innerServices as IServiceProviderIsKeyedService)?.IsKeyedService(serviceType, serviceKey) is true; /// public object? GetKeyedService(Type serviceType, object? serviceKey) => serviceKey is null ? GetService(serviceType) : - (innerServices as IKeyedServiceProvider)?.GetKeyedService(serviceType, serviceKey); + (_innerServices as IKeyedServiceProvider)?.GetKeyedService(serviceType, serviceKey); /// public object GetRequiredKeyedService(Type serviceType, object? serviceKey) => @@ -50,9 +53,9 @@ public object GetRequiredKeyedService(Type serviceType, object? serviceKey) => /// public void Dispose() => - (innerServices as IDisposable)?.Dispose(); + (_innerServices as IDisposable)?.Dispose(); /// public ValueTask DisposeAsync() => - innerServices is IAsyncDisposable asyncDisposable ? asyncDisposable.DisposeAsync() : default; + _innerServices is IAsyncDisposable asyncDisposable ? asyncDisposable.DisposeAsync() : default; } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs index 438421f2..8941e4ed 100644 --- a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs +++ b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs @@ -1,4 +1,5 @@ using ModelContextProtocol.Protocol; +using System.Security.Claims; using System.Threading.Channels; namespace ModelContextProtocol.Server; @@ -9,7 +10,7 @@ namespace ModelContextProtocol.Server; /// /// /// This transport provides one-way communication from server to client using the SSE protocol over HTTP, -/// while receiving client messages through a separate mechanism. It writes messages as +/// while receiving client messages through a separate mechanism. It writes messages as /// SSE events to a response stream, typically associated with an HTTP response. /// /// @@ -41,7 +42,7 @@ public sealed class SseResponseStreamTransport(Stream sseResponseStream, string? /// /// The to monitor for cancellation requests. The default is . /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream. - public async Task RunAsync(CancellationToken cancellationToken) + public async Task RunAsync(CancellationToken cancellationToken = default) { _isConnected = true; await _sseWriter.WriteAllAsync(sseResponseStream, cancellationToken).ConfigureAwait(false); @@ -64,6 +65,7 @@ public async ValueTask DisposeAsync() /// public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { + Throw.IfNull(message); await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); } @@ -76,8 +78,8 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can /// Thrown when there is an attempt to process a message before calling . /// /// - /// This method is the entry point for processing client-to-server communication in the SSE transport model. - /// While the SSE protocol itself is unidirectional (server to client), this method allows bidirectional + /// This method is the entry point for processing client-to-server communication in the SSE transport model. + /// While the SSE protocol itself is unidirectional (server to client), this method allows bidirectional /// communication by handling HTTP POST requests sent to the message endpoint. /// /// @@ -85,11 +87,11 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can /// process the message and make it available to the MCP server via the channel. /// /// - /// This method validates that the transport is connected before processing the message, ensuring proper - /// sequencing of operations in the transport lifecycle. + /// If an authenticated sent the message, that can be included in the . + /// No other part of the context should be set. /// /// - public async Task OnMessageReceivedAsync(JsonRpcMessage message, CancellationToken cancellationToken) + public async Task OnMessageReceivedAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { Throw.IfNull(message); diff --git a/src/ModelContextProtocol.Core/Server/SseWriter.cs b/src/ModelContextProtocol.Core/Server/SseWriter.cs index 18571e2c..4fb7feaf 100644 --- a/src/ModelContextProtocol.Core/Server/SseWriter.cs +++ b/src/ModelContextProtocol.Core/Server/SseWriter.cs @@ -26,6 +26,8 @@ internal sealed class SseWriter(string? messageEndpoint = null, BoundedChannelOp public Task WriteAllAsync(Stream sseResponseStream, CancellationToken cancellationToken) { + Throw.IfNull(sseResponseStream); + // When messageEndpoint is set, the very first SSE event isn't really an IJsonRpcMessage, but there's no API to write a single // item of a different type, so we fib and special-case the "endpoint" event type in the formatter. if (messageEndpoint is not null && !_messages.Writer.TryWrite(new SseItem(null, "endpoint"))) diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs index 9d225caa..1992939d 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs @@ -1,7 +1,9 @@ using ModelContextProtocol.Protocol; +using System.Diagnostics; using System.IO.Pipelines; using System.Net.ServerSentEvents; using System.Runtime.CompilerServices; +using System.Security.Claims; using System.Text.Json; using System.Threading.Channels; @@ -9,14 +11,14 @@ namespace ModelContextProtocol.Server; /// /// Handles processing the request/response body pairs for the Streamable HTTP transport. -/// This is typically used via . +/// This is typically used via . /// -internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport parentTransport, IDuplexPipe httpBodies) : ITransport +internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport parentTransport, Stream responseStream) : ITransport { private readonly SseWriter _sseWriter = new(); private RequestId _pendingRequest; - public ChannelReader MessageReader => throw new NotSupportedException("JsonRpcMessage.RelatedTransport should only be used for sending messages."); + public ChannelReader MessageReader => throw new NotSupportedException("JsonRpcMessage.Context.RelatedTransport should only be used for sending messages."); string? ITransport.SessionId => parentTransport.SessionId; @@ -25,11 +27,31 @@ internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport /// False, if nothing was written because the request body did not contain any messages to respond to. /// The HTTP application should typically respond with an empty "202 Accepted" response in this scenario. /// - public async ValueTask RunAsync(CancellationToken cancellationToken) + public async ValueTask HandlePostAsync(JsonRpcMessage message, CancellationToken cancellationToken) { - var message = await JsonSerializer.DeserializeAsync(httpBodies.Input.AsStream(), - McpJsonUtilities.JsonContext.Default.JsonRpcMessage, cancellationToken).ConfigureAwait(false); - await OnMessageReceivedAsync(message, cancellationToken).ConfigureAwait(false); + Debug.Assert(_pendingRequest.Id is null); + + if (message is JsonRpcRequest request) + { + _pendingRequest = request.Id; + + // Invoke the initialize request callback if applicable. + if (parentTransport.OnInitRequestReceived is { } onInitRequest && request.Method == RequestMethods.Initialize) + { + var initializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams); + await onInitRequest(initializeRequest).ConfigureAwait(false); + } + } + + message.Context ??= new JsonRpcMessageContext(); + message.Context.RelatedTransport = this; + + if (parentTransport.FlowExecutionContextFromRequests) + { + message.Context.ExecutionContext = ExecutionContext.Capture(); + } + + await parentTransport.MessageWriter.WriteAsync(message, cancellationToken).ConfigureAwait(false); if (_pendingRequest.Id is null) { @@ -37,12 +59,14 @@ public async ValueTask RunAsync(CancellationToken cancellationToken) } _sseWriter.MessageFilter = StopOnFinalResponseFilter; - await _sseWriter.WriteAllAsync(httpBodies.Output.AsStream(), cancellationToken).ConfigureAwait(false); + await _sseWriter.WriteAllAsync(responseStream, cancellationToken).ConfigureAwait(false); return true; } public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { + Throw.IfNull(message); + if (parentTransport.Stateless && message is JsonRpcRequest) { throw new InvalidOperationException("Server to client requests are not supported in stateless mode."); @@ -69,33 +93,4 @@ public async ValueTask DisposeAsync() } } } - - private async ValueTask OnMessageReceivedAsync(JsonRpcMessage? message, CancellationToken cancellationToken) - { - if (message is null) - { - throw new InvalidOperationException("Received invalid null message."); - } - - if (message is JsonRpcRequest request) - { - _pendingRequest = request.Id; - - // Invoke the initialize request callback if applicable. - if (parentTransport.OnInitRequestReceived is { } onInitRequest && request.Method == RequestMethods.Initialize) - { - var initializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams); - await onInitRequest(initializeRequest).ConfigureAwait(false); - } - } - - message.RelatedTransport = this; - - if (parentTransport.FlowExecutionContextFromRequests) - { - message.ExecutionContext = ExecutionContext.Capture(); - } - - await parentTransport.MessageWriter.WriteAsync(message, cancellationToken).ConfigureAwait(false); - } } diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index b63c8a65..57283e9a 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -1,5 +1,6 @@ using ModelContextProtocol.Protocol; using System.IO.Pipelines; +using System.Security.Claims; using System.Threading.Channels; namespace ModelContextProtocol.Server; @@ -49,8 +50,8 @@ public sealed class StreamableHttpServerTransport : ITransport public bool Stateless { get; init; } /// - /// Gets a value indicating whether the execution context should flow from the calls to - /// to the corresponding emitted by the . + /// Gets a value indicating whether the execution context should flow from the calls to + /// to the corresponding property contained in the instances returned by the . /// /// /// Defaults to . @@ -75,8 +76,10 @@ public sealed class StreamableHttpServerTransport : ITransport /// The response stream to write MCP JSON-RPC messages as SSE events to. /// The to monitor for cancellation requests. The default is . /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream. - public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken cancellationToken) + public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken cancellationToken = default) { + Throw.IfNull(sseResponseStream); + if (Stateless) { throw new InvalidOperationException("GET requests are not supported in stateless mode."); @@ -96,23 +99,33 @@ public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken c /// and other correlated messages are sent back to the client directly in response /// to the that initiated the message. /// - /// The duplex pipe facilitates the reading and writing of HTTP request and response data. - /// This token allows for the operation to be canceled if needed. + /// The JSON-RPC message received from the client via the POST request body. + /// This token allows for the operation to be canceled if needed. The default is . + /// The POST response body to write MCP JSON-RPC messages to. /// /// True, if data was written to the response body. /// False, if nothing was written because the request body did not contain any messages to respond to. /// The HTTP application should typically respond with an empty "202 Accepted" response in this scenario. /// - public async Task HandlePostRequest(IDuplexPipe httpBodies, CancellationToken cancellationToken) + /// + /// If 's an authenticated sent the message, that can be included in the . + /// No other part of the context should be set. + /// + public async Task HandlePostRequest(JsonRpcMessage message, Stream responseStream, CancellationToken cancellationToken = default) { + Throw.IfNull(message); + Throw.IfNull(responseStream); + using var postCts = CancellationTokenSource.CreateLinkedTokenSource(_disposeCts.Token, cancellationToken); - await using var postTransport = new StreamableHttpPostTransport(this, httpBodies); - return await postTransport.RunAsync(postCts.Token).ConfigureAwait(false); + await using var postTransport = new StreamableHttpPostTransport(this, responseStream); + return await postTransport.HandlePostAsync(message, postCts.Token).ConfigureAwait(false); } /// public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { + Throw.IfNull(message); + if (Stateless) { throw new InvalidOperationException("Unsolicited server to client messages are not supported in stateless mode."); @@ -126,6 +139,7 @@ public async ValueTask DisposeAsync() { try { + _incomingChannel.Writer.TryComplete(); await _disposeCts.CancelAsync(); } finally diff --git a/src/ModelContextProtocol/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/McpServerBuilderExtensions.cs index d925b24f..db1b029d 100644 --- a/src/ModelContextProtocol/McpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol/McpServerBuilderExtensions.cs @@ -707,6 +707,278 @@ public static IMcpServerBuilder WithSetLoggingLevelHandler(this IMcpServerBuilde } #endregion + #region Filters + /// + /// Adds a filter to the list resource templates handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that return a list of available resource templates when requested by a client. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more resource templates. + /// + /// + public static IMcpServerBuilder AddListResourceTemplatesFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.ListResourceTemplatesFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the list tools handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that return a list of available tools when requested by a client. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more tools. + /// + /// + /// This filter works alongside any tools defined in the collection. + /// Tools from both sources will be combined when returning results to clients. + /// + /// + public static IMcpServerBuilder AddListToolsFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.ListToolsFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the call tool handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that are invoked when a client makes a call to a tool that isn't found in the collection. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to execute the requested tool and return appropriate results. + /// + /// + public static IMcpServerBuilder AddCallToolFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.CallToolFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the list prompts handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that return a list of available prompts when requested by a client. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more prompts. + /// + /// + /// This filter works alongside any prompts defined in the collection. + /// Prompts from both sources will be combined when returning results to clients. + /// + /// + public static IMcpServerBuilder AddListPromptsFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.ListPromptsFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the get prompt handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that are invoked when a client requests details for a specific prompt that isn't found in the collection. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to fetch or generate the requested prompt and return appropriate results. + /// + /// + public static IMcpServerBuilder AddGetPromptFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.GetPromptFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the list resources handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that return a list of available resources when requested by a client. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more resources. + /// + /// + public static IMcpServerBuilder AddListResourcesFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.ListResourcesFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the read resource handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that are invoked when a client requests the content of a specific resource identified by its URI. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to locate and retrieve the requested resource. + /// + /// + public static IMcpServerBuilder AddReadResourceFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.ReadResourceFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the complete handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that provide auto-completion suggestions for prompt arguments or resource references in the Model Context Protocol. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler processes auto-completion requests, returning a list of suggestions based on the + /// reference type and current argument value. + /// + /// + public static IMcpServerBuilder AddCompleteFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.CompleteFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the subscribe to resources handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that are invoked when a client wants to receive notifications about changes to specific resources or resource patterns. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to register the client's interest in the specified resources + /// and set up the necessary infrastructure to send notifications when those resources change. + /// + /// + /// After a successful subscription, the server should send resource change notifications to the client + /// whenever a relevant resource is created, updated, or deleted. + /// + /// + public static IMcpServerBuilder AddSubscribeToResourcesFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.SubscribeToResourcesFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the unsubscribe from resources handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that are invoked when a client wants to stop receiving notifications about previously subscribed resources. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to remove the client's subscriptions to the specified resources + /// and clean up any associated resources. + /// + /// + /// After a successful unsubscription, the server should no longer send resource change notifications + /// to the client for the specified resources. + /// + /// + public static IMcpServerBuilder AddUnsubscribeFromResourcesFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.UnsubscribeFromResourcesFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the set logging level handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that process requests from clients. When set, it enables + /// clients to control which log messages they receive by specifying a minimum severity threshold. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. + /// + /// + /// After handling a level change request, the server typically begins sending log messages + /// at or above the specified level to the client as notifications/message notifications. + /// + /// + public static IMcpServerBuilder AddSetLoggingLevelFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.SetLoggingLevelFilters.Add(filter)); + return builder; + } + #endregion + #region Transports /// /// Adds a server transport that uses standard input (stdin) and standard output (stdout) for communication. diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs new file mode 100644 index 00000000..8c173d89 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs @@ -0,0 +1,374 @@ +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.ComponentModel; +using System.Security.Claims; + +namespace ModelContextProtocol.AspNetCore.Tests; + +/// +/// Tests for MCP authorization functionality with [Authorize], [AllowAnonymous] and role-based authorization. +/// +public class AuthorizeAttributeTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) +{ + private async Task ConnectAsync() + { + await using var transport = new SseClientTransport(new SseClientTransportOptions + { + Endpoint = new("http://localhost:5000"), + }, HttpClient, LoggerFactory); + + return await McpClientFactory.CreateAsync(transport, cancellationToken: TestContext.Current.CancellationToken, loggerFactory: LoggerFactory); + } + + [Fact] + public async Task Authorize_Tool_RequiresAuthentication() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools()); + + var client = await ConnectAsync(); + var result = await client.CallToolAsync( + "authorized_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + // Should return error because tool requires authorization but user is anonymous + Assert.True(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Access forbidden: This tool requires authorization.", content.Text); + } + + [Fact] + public async Task ClassLevelAuthorize_Tool_RequiresAuthentication() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools()); + + var client = await ConnectAsync(); + var result = await client.CallToolAsync( + "anonymous_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Anonymous: test", content.Text); + } + + [Fact] + public async Task AllowAnonymous_Tool_AllowsAnonymousAccess() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools()); + + var client = await ConnectAsync(); + var result = await client.CallToolAsync( + "anonymous_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Anonymous: test", content.Text); + } + + [Fact] + public async Task Authorize_Tool_AllowsAuthenticatedUser() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "TestUser"); + + var client = await ConnectAsync(); + var result = await client.CallToolAsync( + "authorized_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Authorized: test", content.Text); + } + + [Fact] + public async Task AuthorizeWithRoles_Tool_RequiresAdminRole() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "TestUser", "User"); + + var client = await ConnectAsync(); + var result = await client.CallToolAsync( + "admin_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + // Should return error because tool requires Admin role but user only has User role + Assert.True(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Access forbidden: This tool requires authorization.", content.Text); + } + + [Fact] + public async Task AuthorizeWithRoles_Tool_AllowsAdminUser() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "AdminUser", "Admin"); + + var client = await ConnectAsync(); + var result = await client.CallToolAsync( + "admin_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Admin: test", content.Text); + } + + [Fact] + public async Task ListTools_Anonymous_OnlyReturnsAnonymousTools() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools()); + + var client = await ConnectAsync(); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + Assert.Single(tools); + Assert.Equal("anonymous_tool", tools[0].Name); + } + + [Fact] + public async Task ListTools_AuthenticatedUser_ReturnsAuthorizedTools() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "TestUser"); + + var client = await ConnectAsync(); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Authenticated user should see anonymous and basic authorized tools, but not admin-only tools + Assert.Equal(2, tools.Count); + var toolNames = tools.Select(t => t.Name).OrderBy(n => n).ToList(); + Assert.Equal(["anonymous_tool", "authorized_tool"], toolNames); + } + + [Fact] + public async Task ListTools_AdminUser_ReturnsAllTools() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "AdminUser", "Admin"); + + var client = await ConnectAsync(); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Admin user should see all tools + Assert.Equal(3, tools.Count); + var toolNames = tools.Select(t => t.Name).OrderBy(n => n).ToList(); + Assert.Equal(["admin_tool", "anonymous_tool", "authorized_tool"], toolNames); + } + + [Fact] + public async Task ListTools_UserRole_DoesNotReturnAdminTools() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "TestUser", "User"); + + var client = await ConnectAsync(); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // User with User role should not see admin-only tools + Assert.Equal(2, tools.Count); + var toolNames = tools.Select(t => t.Name).OrderBy(n => n).ToList(); + Assert.Equal(["anonymous_tool", "authorized_tool"], toolNames); + } + + [Fact] + public async Task Authorize_Prompt_RequiresAuthentication() + { + await using var app = await StartServerWithAuth(builder => builder.WithPrompts()); + + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.GetPromptAsync( + "authorized_prompt", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): Access forbidden: This prompt requires authorization.", exception.Message); + Assert.Equal(McpErrorCode.InvalidRequest, exception.ErrorCode); + } + + [Fact] + public async Task Authorize_Prompt_AllowsAuthenticatedUser() + { + await using var app = await StartServerWithAuth(builder => builder.WithPrompts(), "TestUser"); + + var client = await ConnectAsync(); + var result = await client.GetPromptAsync( + "authorized_prompt", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + var message = Assert.Single(result.Messages); + Assert.Equal(Role.User, message.Role); + var content = Assert.IsType(message.Content); + Assert.Equal("Authorized prompt: test", content.Text); + } + + [Fact] + public async Task ListPrompts_Anonymous_OnlyReturnsAnonymousPrompts() + { + await using var app = await StartServerWithAuth(builder => builder.WithPrompts()); + + var client = await ConnectAsync(); + var prompts = await client.ListPromptsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Anonymous user should only see prompts marked with [AllowAnonymous] + Assert.Single(prompts); + Assert.Equal("anonymous_prompt", prompts[0].Name); + } + + [Fact] + public async Task Authorize_Resource_RequiresAuthentication() + { + await using var app = await StartServerWithAuth(builder => builder.WithResources()); + + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.ReadResourceAsync( + "resource://authorized", + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): Access forbidden: This resource requires authorization.", exception.Message); + Assert.Equal(McpErrorCode.InvalidRequest, exception.ErrorCode); + } + + [Fact] + public async Task Authorize_Resource_AllowsAuthenticatedUser() + { + await using var app = await StartServerWithAuth(builder => builder.WithResources(), "TestUser"); + + var client = await ConnectAsync(); + var result = await client.ReadResourceAsync( + "resource://authorized", + cancellationToken: TestContext.Current.CancellationToken); + + var content = Assert.Single(result.Contents.OfType()); + Assert.Equal("Authorized resource content", content.Text); + } + + [Fact] + public async Task ListResources_Anonymous_OnlyReturnsAnonymousResources() + { + await using var app = await StartServerWithAuth(builder => builder.WithResources()); + + var client = await ConnectAsync(); + var resources = await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken); + + Assert.Single(resources); + Assert.Equal("resource://anonymous", resources[0].Uri); + } + + private async Task StartServerWithAuth(Action configure, string? userName = null, params string[] roles) + { + var builder = Builder.Services.AddMcpServer().WithHttpTransport(); + configure(builder); + Builder.Services.AddAuthorization(); + + var app = Builder.Build(); + + if (userName is not null) + { + app.Use(next => + { + return async context => + { + context.User = CreateUser(userName, roles); + await next(context); + }; + }); + } + + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + return app; + } + + private ClaimsPrincipal CreateUser(string name, params string[] roles) + => new ClaimsPrincipal(new ClaimsIdentity( + [new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name), ..roles.Select(role => new Claim("role", role))], + "TestAuthType", "name", "role")); + + [McpServerToolType] + private class AuthorizationTestTools + { + [McpServerTool, Description("A tool that allows anonymous access.")] + public static string AnonymousTool(string message) + { + return $"Anonymous: {message}"; + } + + [McpServerTool, Description("A tool that requires authorization.")] + [Authorize] + public static string AuthorizedTool(string message) + { + return $"Authorized: {message}"; + } + + [McpServerTool, Description("A tool that requires Admin role.")] + [Authorize(Roles = "Admin")] + public static string AdminTool(string message) + { + return $"Admin: {message}"; + } + } + + [McpServerToolType] + [Authorize] + private class AllowAnonymousTestTools + { + [McpServerTool, Description("A tool that allows anonymous access.")] + [AllowAnonymous] + public static string AnonymousTool(string message) + { + return $"Anonymous: {message}"; + } + + [McpServerTool, Description("A tool that requires authorization.")] + public static string AuthorizedTool(string message) + { + return $"Authorized: {message}"; + } + } + + [McpServerPromptType] + private class AuthorizationTestPrompts + { + [McpServerPrompt, Description("A prompt that allows anonymous access.")] + public static string AnonymousPrompt(string message) + { + return $"Anonymous prompt: {message}"; + } + + [McpServerPrompt, Description("A prompt that requires authorization.")] + [Authorize] + public static string AuthorizedPrompt(string message) + { + return $"Authorized prompt: {message}"; + } + } + + [McpServerResourceType] + private class AuthorizationTestResources + { + [McpServerResource(UriTemplate = "resource://anonymous"), Description("A resource that allows anonymous access.")] + public static string AnonymousResource() + { + return "Anonymous resource content"; + } + + [McpServerResource(UriTemplate = "resource://authorized"), Description("A resource that requires authorization.")] + [Authorize] + public static string AuthorizedResource() + { + return "Authorized resource content"; + } + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs index f3162130..72830407 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs @@ -13,7 +13,7 @@ public class MapMcpSseTests(ITestOutputHelper outputHelper) : MapMcpTests(output [InlineData("/mcp/secondary")] public async Task Allows_Customizing_Route(string pattern) { - Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless); + Builder.Services.AddMcpServer().WithHttpTransport(); await using var app = Builder.Build(); app.MapMcp(pattern); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 4d0d7356..0d867c8f 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -111,6 +111,35 @@ public async Task Messages_FromNewUser_AreRejected() Assert.Equal(HttpStatusCode.Forbidden, httpRequestException.StatusCode); } + [Fact] + public async Task ClaimsPrincipal_CanBeInjectedIntoToolMethod() + { + Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools(); + Builder.Services.AddHttpContextAccessor(); + + await using var app = Builder.Build(); + + app.Use(next => async context => + { + context.User = CreateUser("TestUser"); + await next(context); + }); + + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var client = await ConnectAsync(); + + var response = await client.CallToolAsync( + "echo_claims_principal", + new Dictionary() { ["message"] = "Hello world!" }, + cancellationToken: TestContext.Current.CancellationToken); + + var content = Assert.Single(response.Content.OfType()); + Assert.Equal("TestUser: Hello world!", content.Text); + } + [Fact] public async Task Sampling_DoesNotCloseStream_Prematurely() { @@ -200,6 +229,17 @@ public string EchoWithUserName(string message) } } + [McpServerToolType] + protected class ClaimsPrincipalTools + { + [McpServerTool, Description("Echoes the input back to the client with the user name from ClaimsPrincipal.")] + public string EchoClaimsPrincipal(ClaimsPrincipal? user, string message) + { + var userName = user?.Identity?.Name ?? "anonymous"; + return $"{userName}: {message}"; + } + } + [McpServerToolType] private class SamplingRegressionTools { diff --git a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs index ec1c8510..d9b699b9 100644 --- a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs +++ b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs @@ -20,7 +20,8 @@ public ClientServerTestBase(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { ServiceCollection sc = new(); - sc.AddSingleton(LoggerFactory); + sc.AddLogging(); + sc.AddSingleton(XunitLoggerProvider); _builder = sc .AddMcpServer() .WithStreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream()); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs new file mode 100644 index 00000000..6a7d0044 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs @@ -0,0 +1,314 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; + +namespace ModelContextProtocol.Tests.Configuration; + +public class McpServerBuilderExtensionsFilterTests : ClientServerTestBase +{ + public McpServerBuilderExtensionsFilterTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + } + + private MockLoggerProvider _mockLoggerProvider = new(); + + private static ILogger GetLogger(IServiceProvider? services, string categoryName) + { + var loggerFactory = services?.GetRequiredService() ?? throw new InvalidOperationException("LoggerFactory not available"); + return loggerFactory.CreateLogger(categoryName); + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + mcpServerBuilder + .AddListResourceTemplatesFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListResourceTemplatesFilter"); + logger.LogInformation("ListResourceTemplatesFilter executed"); + return await next(request, cancellationToken); + }) + .AddListToolsFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListToolsFilter"); + logger.LogInformation("ListToolsFilter executed"); + return await next(request, cancellationToken); + }) + .AddListToolsFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListToolsOrder1"); + logger.LogInformation("ListToolsOrder1 before"); + var result = await next(request, cancellationToken); + logger.LogInformation("ListToolsOrder1 after"); + return result; + }) + .AddListToolsFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListToolsOrder2"); + logger.LogInformation("ListToolsOrder2 before"); + var result = await next(request, cancellationToken); + logger.LogInformation("ListToolsOrder2 after"); + return result; + }) + .AddCallToolFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "CallToolFilter"); + var primitiveId = request.MatchedPrimitive?.Id ?? "unknown"; + logger.LogInformation($"CallToolFilter executed for tool: {primitiveId}"); + return await next(request, cancellationToken); + }) + .AddListPromptsFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListPromptsFilter"); + logger.LogInformation("ListPromptsFilter executed"); + return await next(request, cancellationToken); + }) + .AddGetPromptFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "GetPromptFilter"); + var primitiveId = request.MatchedPrimitive?.Id ?? "unknown"; + logger.LogInformation($"GetPromptFilter executed for prompt: {primitiveId}"); + return await next(request, cancellationToken); + }) + .AddListResourcesFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListResourcesFilter"); + logger.LogInformation("ListResourcesFilter executed"); + return await next(request, cancellationToken); + }) + .AddReadResourceFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ReadResourceFilter"); + var primitiveId = request.MatchedPrimitive?.Id ?? "unknown"; + logger.LogInformation($"ReadResourceFilter executed for resource: {primitiveId}"); + return await next(request, cancellationToken); + }) + .AddCompleteFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "CompleteFilter"); + logger.LogInformation("CompleteFilter executed"); + return await next(request, cancellationToken); + }) + .AddSubscribeToResourcesFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "SubscribeToResourcesFilter"); + logger.LogInformation("SubscribeToResourcesFilter executed"); + return await next(request, cancellationToken); + }) + .AddUnsubscribeFromResourcesFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "UnsubscribeFromResourcesFilter"); + logger.LogInformation("UnsubscribeFromResourcesFilter executed"); + return await next(request, cancellationToken); + }) + .AddSetLoggingLevelFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "SetLoggingLevelFilter"); + logger.LogInformation("SetLoggingLevelFilter executed"); + return await next(request, cancellationToken); + }) + .WithTools() + .WithPrompts() + .WithResources() + .WithSetLoggingLevelHandler(async (request, cancellationToken) => new EmptyResult()) + .WithListResourceTemplatesHandler(async (request, cancellationToken) => new ListResourceTemplatesResult + { + ResourceTemplates = [new() { Name = "test", UriTemplate = "test://resource/{id}" }] + }) + .WithCompleteHandler(async (request, cancellationToken) => new CompleteResult + { + Completion = new() { Values = ["test"] } + }); + + services.AddSingleton(_mockLoggerProvider); + } + + [Fact] + public async Task AddListResourceTemplatesFilter_Logs_When_ListResourceTemplates_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.ListResourceTemplatesAsync(cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "ListResourceTemplatesFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("ListResourceTemplatesFilter", logMessage.Category); + } + + [Fact] + public async Task AddListToolsFilter_Logs_When_ListTools_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "ListToolsFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("ListToolsFilter", logMessage.Category); + } + + [Fact] + public async Task AddCallToolFilter_Logs_When_CallTool_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.CallToolAsync("test_tool_method", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "CallToolFilter executed for tool: test_tool_method"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("CallToolFilter", logMessage.Category); + } + + [Fact] + public async Task AddListPromptsFilter_Logs_When_ListPrompts_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.ListPromptsAsync(cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "ListPromptsFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("ListPromptsFilter", logMessage.Category); + } + + [Fact] + public async Task AddGetPromptFilter_Logs_When_GetPrompt_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.GetPromptAsync("test_prompt_method", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "GetPromptFilter executed for prompt: test_prompt_method"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("GetPromptFilter", logMessage.Category); + } + + [Fact] + public async Task AddListResourcesFilter_Logs_When_ListResources_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "ListResourcesFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("ListResourcesFilter", logMessage.Category); + } + + [Fact] + public async Task AddReadResourceFilter_Logs_When_ReadResource_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.ReadResourceAsync("test://resource/123", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "ReadResourceFilter executed for resource: test://resource/{id}"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("ReadResourceFilter", logMessage.Category); + } + + [Fact] + public async Task AddCompleteFilter_Logs_When_Complete_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + var reference = new PromptReference { Name = "test_prompt_method" }; + await client.CompleteAsync(reference, "argument", "value", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "CompleteFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("CompleteFilter", logMessage.Category); + } + + [Fact] + public async Task AddSubscribeToResourcesFilter_Logs_When_SubscribeToResources_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.SubscribeToResourceAsync("test://resource/123", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "SubscribeToResourcesFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("SubscribeToResourcesFilter", logMessage.Category); + } + + [Fact] + public async Task AddUnsubscribeFromResourcesFilter_Logs_When_UnsubscribeFromResources_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.UnsubscribeFromResourceAsync("test://resource/123", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "UnsubscribeFromResourcesFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("UnsubscribeFromResourcesFilter", logMessage.Category); + } + + [Fact] + public async Task AddSetLoggingLevelFilter_Logs_When_SetLoggingLevel_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.SetLoggingLevel(LoggingLevel.Info, cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "SetLoggingLevelFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("SetLoggingLevelFilter", logMessage.Category); + } + + [Fact] + public async Task AddListToolsFilter_Multiple_Filters_Log_In_Expected_Order() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + var logMessages = _mockLoggerProvider.LogMessages + .Where(m => m.Category.StartsWith("ListToolsOrder")) + .Select(m => m.Message); + + Assert.Collection(logMessages, + m => Assert.Equal("ListToolsOrder1 before", m), + m => Assert.Equal("ListToolsOrder2 before", m), + m => Assert.Equal("ListToolsOrder2 after", m), + m => Assert.Equal("ListToolsOrder1 after", m) + ); + } + + [McpServerToolType] + public sealed class TestTool + { + [McpServerTool] + public static string TestToolMethod() + { + return "test result"; + } + } + + [McpServerPromptType] + public sealed class TestPrompt + { + [McpServerPrompt] + public static Task TestPromptMethod() + { + return Task.FromResult(new GetPromptResult + { + Description = "Test prompt", + Messages = [new() { Role = Role.User, Content = new TextContentBlock { Text = "Test" } }] + }); + } + } + + [McpServerResourceType] + public sealed class TestResource + { + [McpServerResource(UriTemplate = "test://resource/{id}")] + public static string TestResourceMethod(string id) + { + return $"Test resource for ID: {id}"; + } + } +} diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 35f833d5..82a2b6b6 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -5,6 +5,7 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; using System.Collections.Concurrent; using System.ComponentModel; using System.IO.Pipelines; @@ -22,6 +23,8 @@ public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) { } + private MockLoggerProvider _mockLoggerProvider = new(); + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) { mcpServerBuilder @@ -107,6 +110,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer .WithTools(serializerOptions: BuilderToolsJsonContext.Default.Options); services.AddSingleton(new ObjectWithId()); + services.AddSingleton(_mockLoggerProvider); } [Fact] @@ -155,8 +159,8 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T await using (var client = await McpClientFactory.CreateAsync( new StreamClientTransport( - serverInput: stdinPipe.Writer.AsStream(), - serverOutput: stdoutPipe.Reader.AsStream(), + serverInput: stdinPipe.Writer.AsStream(), + serverOutput: stdoutPipe.Reader.AsStream(), LoggerFactory), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)) @@ -230,7 +234,7 @@ public async Task Can_Call_Registered_Tool() var result = await client.CallToolAsync( "echo", - new Dictionary() { ["message"] = "Peter" }, + new Dictionary() { ["message"] = "Peter" }, cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(result); @@ -351,14 +355,14 @@ public async Task Can_Call_Registered_Tool_With_Instance_Method() string random1 = parts[0][0]; string random2 = parts[1][0]; Assert.NotEqual(random1, random2); - + string id1 = parts[0][1]; string id2 = parts[1][1]; Assert.Equal(id1, id2); } [Fact] - public async Task Returns_IsError_Content_When_Tool_Fails() + public async Task Returns_IsError_Content_And_Logs_Error_When_Tool_Fails() { await using IMcpClient client = await CreateMcpClientForServer(); @@ -370,6 +374,11 @@ public async Task Returns_IsError_Content_When_Tool_Fails() Assert.NotNull(result.Content); Assert.NotEmpty(result.Content); Assert.Contains("An error occurred", (result.Content[0] as TextContentBlock)?.Text); + + var errorLog = Assert.Single(_mockLoggerProvider.LogMessages, m => m.LogLevel == LogLevel.Error); + Assert.Equal($"\"throw_exception\" threw an unhandled exception.", errorLog.Message); + Assert.IsType(errorLog.Exception); + Assert.Equal("Test error", errorLog.Exception.Message); } [Fact] diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs index 39e9b72f..307e086a 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Primitives; using ModelContextProtocol.Protocol; @@ -15,6 +15,16 @@ namespace ModelContextProtocol.Tests.Server; public class McpServerPromptTests { + private static JsonRpcRequest CreateTestJsonRpcRequest() + { + return new JsonRpcRequest + { + Id = new RequestId("test-id"), + Method = "test/method", + Params = null + }; + } + public McpServerPromptTests() { #if !NET @@ -46,7 +56,7 @@ public async Task SupportsIMcpServer() Assert.DoesNotContain("server", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); var result = await prompt.GetAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.NotNull(result.Messages); @@ -75,7 +85,7 @@ public async Task SupportsCtorInjection() }, new() { Services = services }); var result = await prompt.GetAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.NotNull(result.Messages); @@ -125,11 +135,11 @@ public async Task SupportsServiceFromDI() Assert.DoesNotContain("actualMyService", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); await Assert.ThrowsAnyAsync(async () => await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken)); var result = await prompt.GetAsync( - new RequestContext(new Mock().Object) { Services = services }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Services = services }, TestContext.Current.CancellationToken); Assert.Equal("Hello", Assert.IsType(result.Messages[0].Content).Text); } @@ -150,7 +160,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services }); var result = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("Hello", Assert.IsType(result.Messages[0].Content).Text); } @@ -163,7 +173,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() _ => new DisposablePromptType()); var result = await prompt1.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("disposals:1", Assert.IsType(result.Messages[0].Content).Text); } @@ -176,7 +186,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets() _ => new AsyncDisposablePromptType()); var result = await prompt1.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("asyncDisposals:1", Assert.IsType(result.Messages[0].Content).Text); } @@ -189,7 +199,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable _ => new AsyncDisposableAndDisposablePromptType()); var result = await prompt1.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("disposals:0, asyncDisposals:1", Assert.IsType(result.Messages[0].Content).Text); } @@ -205,7 +215,7 @@ public async Task CanReturnGetPromptResult() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Same(expected, actual); @@ -222,7 +232,7 @@ public async Task CanReturnText() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -248,7 +258,7 @@ public async Task CanReturnPromptMessage() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -260,7 +270,7 @@ public async Task CanReturnPromptMessage() [Fact] public async Task CanReturnPromptMessages() { - IList expected = + IList expected = [ new() { @@ -280,7 +290,7 @@ public async Task CanReturnPromptMessages() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -307,7 +317,7 @@ public async Task CanReturnChatMessage() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -339,7 +349,7 @@ public async Task CanReturnChatMessages() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -360,7 +370,7 @@ public async Task ThrowsForNullReturn() }); await Assert.ThrowsAsync(async () => await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken)); } @@ -373,7 +383,7 @@ public async Task ThrowsForUnexpectedTypeReturn() }); await Assert.ThrowsAsync(async () => await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken)); } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs index 011c4f2b..df0b6537 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; @@ -11,6 +11,16 @@ namespace ModelContextProtocol.Tests.Server; public partial class McpServerResourceTests { + private static JsonRpcRequest CreateTestJsonRpcRequest() + { + return new JsonRpcRequest + { + Id = new RequestId("test-id"), + Method = "test/method", + Params = null + }; + } + public McpServerResourceTests() { #if !NET @@ -138,7 +148,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create(() => "42", new() { Name = Name }); Assert.Equal("resource://mcp/Hello", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = "resource://mcp/Hello" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -146,7 +156,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((IMcpServer server) => "42", new() { Name = Name }); Assert.Equal("resource://mcp/Hello", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = "resource://mcp/Hello" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -154,7 +164,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((string arg1) => arg1, new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?arg1}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?arg1=wOrLd" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?arg1=wOrLd" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("wOrLd", ((TextResourceContents)result.Contents[0]).Text); @@ -162,7 +172,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((string arg1, string? arg2 = null) => arg1 + arg2, new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?arg1,arg2}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?arg1=wo&arg2=rld" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?arg1=wo&arg2=rld" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("world", ((TextResourceContents)result.Contents[0]).Text); @@ -170,7 +180,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((object a1, bool a2, char a3, byte a4, sbyte a5) => a1.ToString() + a2 + a3 + a4 + a5, new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=hi&a2=true&a3=s&a4=12&a5=34" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=hi&a2=true&a3=s&a4=12&a5=34" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("hiTrues1234", ((TextResourceContents)result.Contents[0]).Text); @@ -178,7 +188,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((ushort a1, short a2, uint a3, int a4, ulong a5) => (a1 + a2 + a3 + a4 + (long)a5).ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=10&a2=20&a3=30&a4=40&a5=50" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=10&a2=20&a3=30&a4=40&a5=50" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("150", ((TextResourceContents)result.Contents[0]).Text); @@ -186,7 +196,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((long a1, float a2, double a3, decimal a4, TimeSpan a5) => a5.ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=1&a2=2&a3=3&a4=4&a5=5" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=1&a2=2&a3=3&a4=4&a5=5" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("5.00:00:00", ((TextResourceContents)result.Contents[0]).Text); @@ -194,7 +204,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((DateTime a1, DateTimeOffset a2, Uri a3, Guid a4, Version a5) => a4.ToString("N") + a5, new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1={DateTime.UtcNow:r}&a2={DateTimeOffset.UtcNow:r}&a3=http%3A%2F%2Ftest&a4=14e5f43d-0d41-47d6-8207-8249cf669e41&a5=1.2.3.4" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1={DateTime.UtcNow:r}&a2={DateTimeOffset.UtcNow:r}&a3=http%3A%2F%2Ftest&a4=14e5f43d-0d41-47d6-8207-8249cf669e41&a5=1.2.3.4" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("14e5f43d0d4147d682078249cf669e411.2.3.4", ((TextResourceContents)result.Contents[0]).Text); @@ -203,7 +213,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((Half a2, Int128 a3, UInt128 a4, IntPtr a5) => (a3 + (Int128)a4 + a5).ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a2=1.0&a3=3&a4=4&a5=5" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a2=1.0&a3=3&a4=4&a5=5" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("12", ((TextResourceContents)result.Contents[0]).Text); @@ -211,7 +221,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((UIntPtr a1, DateOnly a2, TimeOnly a3) => a1.ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=123&a2=0001-02-03&a3=01%3A02%3A03" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=123&a2=0001-02-03&a3=01%3A02%3A03" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("123", ((TextResourceContents)result.Contents[0]).Text); @@ -220,7 +230,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((bool? a2, char? a3, byte? a4, sbyte? a5) => a2?.ToString() + a3 + a4 + a5, new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a2=true&a3=s&a4=12&a5=34" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a2=true&a3=s&a4=12&a5=34" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("Trues1234", ((TextResourceContents)result.Contents[0]).Text); @@ -228,7 +238,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((ushort? a1, short? a2, uint? a3, int? a4, ulong? a5) => (a1 + a2 + a3 + a4 + (long?)a5).ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=10&a2=20&a3=30&a4=40&a5=50" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=10&a2=20&a3=30&a4=40&a5=50" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("150", ((TextResourceContents)result.Contents[0]).Text); @@ -236,7 +246,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((long? a1, float? a2, double? a3, decimal? a4, TimeSpan? a5) => a5?.ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=1&a2=2&a3=3&a4=4&a5=5" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=1&a2=2&a3=3&a4=4&a5=5" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("5.00:00:00", ((TextResourceContents)result.Contents[0]).Text); @@ -244,7 +254,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((DateTime? a1, DateTimeOffset? a2, Guid? a4) => a4?.ToString("N"), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a4}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1={DateTime.UtcNow:r}&a2={DateTimeOffset.UtcNow:r}&a4=14e5f43d-0d41-47d6-8207-8249cf669e41" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1={DateTime.UtcNow:r}&a2={DateTimeOffset.UtcNow:r}&a4=14e5f43d-0d41-47d6-8207-8249cf669e41" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("14e5f43d0d4147d682078249cf669e41", ((TextResourceContents)result.Contents[0]).Text); @@ -253,7 +263,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((Half? a2, Int128? a3, UInt128? a4, IntPtr? a5) => (a3 + (Int128?)a4 + a5).ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a2=1.0&a3=3&a4=4&a5=5" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a2=1.0&a3=3&a4=4&a5=5" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("12", ((TextResourceContents)result.Contents[0]).Text); @@ -261,7 +271,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((UIntPtr? a1, DateOnly? a2, TimeOnly? a3) => a1?.ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=123&a2=0001-02-03&a3=01%3A02%3A03" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=123&a2=0001-02-03&a3=01%3A02%3A03" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("123", ((TextResourceContents)result.Contents[0]).Text); @@ -277,7 +287,7 @@ public async Task UriTemplate_NonMatchingUri_ReturnsNull(string uri) McpServerResource t = McpServerResource.Create((string arg1) => arg1, new() { Name = "Hello" }); Assert.Equal("resource://mcp/Hello{?arg1}", t.ProtocolResourceTemplate.UriTemplate); Assert.Null(await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = uri } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = uri } }, TestContext.Current.CancellationToken)); } @@ -288,7 +298,7 @@ public async Task UriTemplate_IsHostCaseInsensitive(string actualUri, string que { McpServerResource t = McpServerResource.Create(() => "resource", new() { UriTemplate = actualUri }); Assert.NotNull(await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = queriedUri } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = queriedUri } }, TestContext.Current.CancellationToken)); } @@ -317,7 +327,7 @@ public async Task UriTemplate_MissingParameter_Throws(string uri) McpServerResource t = McpServerResource.Create((string arg1, int arg2) => arg1, new() { Name = "Hello" }); Assert.Equal("resource://mcp/Hello{?arg1,arg2}", t.ProtocolResourceTemplate.UriTemplate); await Assert.ThrowsAsync(async () => await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = uri } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = uri } }, TestContext.Current.CancellationToken)); } @@ -330,25 +340,25 @@ public async Task UriTemplate_MissingOptionalParameter_Succeeds() ReadResourceResult? result; result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg1=first" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello?arg1=first" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("first", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg2=42" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello?arg2=42" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg1=first&arg2=42" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello?arg1=first&arg2=42" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("first42", ((TextResourceContents)result.Contents[0]).Text); @@ -366,7 +376,7 @@ public async Task SupportsIMcpServer() }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -393,7 +403,7 @@ public async Task SupportsCtorInjection() }, new() { Services = services }); var result = await tool.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "https://something" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "https://something" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.NotNull(result.Contents); @@ -470,11 +480,11 @@ public async Task SupportsServiceFromDI(ServiceLifetime injectedArgumentLifetime Mock mockServer = new(); await Assert.ThrowsAnyAsync(async () => await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken)); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Services = services, Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Services = services, Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -496,7 +506,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services, Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -512,7 +522,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() _ => new DisposableResourceType()); var result = await resource1.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "test://static/resource/instanceMethod" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "test://static/resource/instanceMethod" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("0", ((TextResourceContents)result.Contents[0]).Text); @@ -530,7 +540,7 @@ public async Task CanReturnReadResult() return new ReadResourceResult { Contents = new List { new TextResourceContents { Text = "hello" } } }; }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -547,7 +557,7 @@ public async Task CanReturnResourceContents() return new TextResourceContents { Text = "hello" }; }, new() { Name = "Test", SerializerOptions = JsonContext6.Default.Options }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -568,7 +578,7 @@ public async Task CanReturnCollectionOfResourceContents() ]; }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal(2, result.Contents.Count); @@ -586,7 +596,7 @@ public async Task CanReturnString() return "42"; }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -603,7 +613,7 @@ public async Task CanReturnCollectionOfStrings() return new List { "42", "43" }; }, new() { Name = "Test", SerializerOptions = JsonContext6.Default.Options }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal(2, result.Contents.Count); @@ -621,7 +631,7 @@ public async Task CanReturnDataContent() return new DataContent(new byte[] { 0, 1, 2 }, "application/octet-stream"); }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -643,7 +653,7 @@ public async Task CanReturnCollectionOfAIContent() }; }, new() { Name = "Test", SerializerOptions = JsonContext6.Default.Options }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal(2, result.Contents.Count); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index f961eef3..ca2ab783 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -1,10 +1,8 @@ -using Json.Schema; +using Json.Schema; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; using Moq; using System.Reflection; using System.Runtime.InteropServices; @@ -18,6 +16,16 @@ namespace ModelContextProtocol.Tests.Server; public partial class McpServerToolTests { + private static JsonRpcRequest CreateTestJsonRpcRequest() + { + return new JsonRpcRequest + { + Id = new RequestId("test-id"), + Method = "test/method", + Params = null + }; + } + public McpServerToolTests() { #if !NET @@ -53,7 +61,7 @@ public async Task SupportsIMcpServer() Assert.DoesNotContain("server", JsonSerializer.Serialize(tool.ProtocolTool.InputSchema, McpJsonUtilities.DefaultOptions)); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("42", (result.Content[0] as TextContentBlock)?.Text); } @@ -79,7 +87,7 @@ public async Task SupportsCtorInjection() }, new() { Services = services }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.NotNull(result.Content); @@ -156,13 +164,14 @@ public async Task SupportsServiceFromDI(ServiceLifetime injectedArgumentLifetime Mock mockServer = new(); - var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), - TestContext.Current.CancellationToken); - Assert.True(result.IsError); + var ex = await Assert.ThrowsAsync(async () => await tool.InvokeAsync( + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), + TestContext.Current.CancellationToken)); - result = await tool.InvokeAsync( - new RequestContext(mockServer.Object) { Services = services }, + mockServer.SetupGet(s => s.Services).Returns(services); + + var result = await tool.InvokeAsync( + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Services = services }, TestContext.Current.CancellationToken); Assert.Equal("42", (result.Content[0] as TextContentBlock)?.Text); } @@ -183,7 +192,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services }); var result = await tool.InvokeAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("42", (result.Content[0] as TextContentBlock)?.Text); } @@ -198,7 +207,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() options); var result = await tool1.InvokeAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("""{"disposals":1}""", (result.Content[0] as TextContentBlock)?.Text); } @@ -213,7 +222,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets() options); var result = await tool1.InvokeAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("""{"asyncDisposals":1}""", (result.Content[0] as TextContentBlock)?.Text); } @@ -232,7 +241,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable options); var result = await tool1.InvokeAsync( - new RequestContext(new Mock().Object) { Services = services }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Services = services }, TestContext.Current.CancellationToken); Assert.Equal("""{"asyncDisposals":1,"disposals":0}""", (result.Content[0] as TextContentBlock)?.Text); } @@ -253,7 +262,7 @@ public async Task CanReturnCollectionOfAIContent() }, new() { SerializerOptions = JsonContext2.Default.Options }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal(3, result.Content.Count); @@ -287,7 +296,7 @@ public async Task CanReturnSingleAIContent(string data, string type) }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Single(result.Content); @@ -323,7 +332,7 @@ public async Task CanReturnNullAIContent() return (string?)null; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Empty(result.Content); } @@ -338,7 +347,7 @@ public async Task CanReturnString() return "42"; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Single(result.Content); Assert.Equal("42", Assert.IsType(result.Content[0]).Text); @@ -354,7 +363,7 @@ public async Task CanReturnCollectionOfStrings() return new List { "42", "43" }; }, new() { SerializerOptions = JsonContext2.Default.Options }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Single(result.Content); Assert.Equal("""["42","43"]""", Assert.IsType(result.Content[0]).Text); @@ -370,7 +379,7 @@ public async Task CanReturnMcpContent() return new TextContentBlock { Text = "42" }; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Single(result.Content); Assert.Equal("42", Assert.IsType(result.Content[0]).Text); @@ -386,12 +395,12 @@ public async Task CanReturnCollectionOfMcpContent() Assert.Same(mockServer.Object, server); return (IList) [ - new TextContentBlock { Text = "42" }, - new ImageContentBlock { Data = "1234", MimeType = "image/png" } + new TextContentBlock { Text = "42" }, + new ImageContentBlock { Data = "1234", MimeType = "image/png" } ]; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal(2, result.Content.Count); Assert.Equal("42", Assert.IsType(result.Content[0]).Text); @@ -414,7 +423,7 @@ public async Task CanReturnCallToolResult() return response; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Same(response, result); @@ -447,45 +456,6 @@ public async Task SupportsSchemaCreateOptions() ); } - [Fact] - public async Task ToolCallError_LogsErrorMessage() - { - // Arrange - var mockLoggerProvider = new MockLoggerProvider(); - var loggerFactory = new LoggerFactory(new[] { mockLoggerProvider }); - var services = new ServiceCollection(); - services.AddSingleton(loggerFactory); - var serviceProvider = services.BuildServiceProvider(); - - var toolName = "tool-that-throws"; - var exceptionMessage = "Test exception message"; - - McpServerTool tool = McpServerTool.Create(() => - { - throw new InvalidOperationException(exceptionMessage); - }, new() { Name = toolName, Services = serviceProvider }); - - var mockServer = new Mock(); - var request = new RequestContext(mockServer.Object) - { - Params = new CallToolRequestParams { Name = toolName }, - Services = serviceProvider - }; - - // Act - var result = await tool.InvokeAsync(request, TestContext.Current.CancellationToken); - - // Assert - Assert.True(result.IsError); - Assert.Single(result.Content); - Assert.Equal($"An error occurred invoking '{toolName}'.", Assert.IsType(result.Content[0]).Text); - - var errorLog = Assert.Single(mockLoggerProvider.LogMessages, m => m.LogLevel == LogLevel.Error); - Assert.Equal($"\"{toolName}\" threw an unhandled exception.", errorLog.Message); - Assert.IsType(errorLog.Exception); - Assert.Equal(exceptionMessage, errorLog.Exception.Message); - } - [Theory] [MemberData(nameof(StructuredOutput_ReturnsExpectedSchema_Inputs))] public async Task StructuredOutput_Enabled_ReturnsExpectedSchema(T value) @@ -493,7 +463,7 @@ public async Task StructuredOutput_Enabled_ReturnsExpectedSchema(T value) JsonSerializerOptions options = new() { TypeInfoResolver = new DefaultJsonTypeInfoResolver() }; McpServerTool tool = McpServerTool.Create(() => value, new() { Name = "tool", UseStructuredContent = true, SerializerOptions = options }); var mockServer = new Mock(); - var request = new RequestContext(mockServer.Object) + var request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, }; @@ -511,7 +481,7 @@ public async Task StructuredOutput_Enabled_VoidReturningTools_ReturnsExpectedSch { McpServerTool tool = McpServerTool.Create(() => { }); var mockServer = new Mock(); - var request = new RequestContext(mockServer.Object) + var request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, }; @@ -522,7 +492,7 @@ public async Task StructuredOutput_Enabled_VoidReturningTools_ReturnsExpectedSch Assert.Null(result.StructuredContent); tool = McpServerTool.Create(() => Task.CompletedTask); - request = new RequestContext(mockServer.Object) + request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, }; @@ -533,7 +503,7 @@ public async Task StructuredOutput_Enabled_VoidReturningTools_ReturnsExpectedSch Assert.Null(result.StructuredContent); tool = McpServerTool.Create(() => default(ValueTask)); - request = new RequestContext(mockServer.Object) + request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, }; @@ -551,7 +521,7 @@ public async Task StructuredOutput_Disabled_ReturnsExpectedSchema(T value) JsonSerializerOptions options = new() { TypeInfoResolver = new DefaultJsonTypeInfoResolver() }; McpServerTool tool = McpServerTool.Create(() => value, new() { UseStructuredContent = false, SerializerOptions = options }); var mockServer = new Mock(); - var request = new RequestContext(mockServer.Object) + var request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, }; @@ -592,7 +562,7 @@ public static IEnumerable StructuredOutput_ReturnsExpectedSchema_Input yield return new object[] { new() }; yield return new object[] { new List { "item1", "item2" } }; yield return new object[] { new Dictionary { ["key1"] = 1, ["key2"] = 2 } }; - yield return new object[] { new Person("John", 27) }; + yield return new object[] { new Person("John", 27) }; } private sealed class MyService;