diff --git a/samples/DynamicToolFiltering/.dockerignore b/samples/DynamicToolFiltering/.dockerignore new file mode 100644 index 00000000..3433a5f3 --- /dev/null +++ b/samples/DynamicToolFiltering/.dockerignore @@ -0,0 +1,70 @@ +# Build artifacts +bin/ +obj/ +*.dll +*.exe +*.pdb +out/ + +# Development files +.vs/ +.vscode/ +*.user +*.suo +.idea/ + +# Logs and temporary files +logs/ +*.log +*.tmp +*.temp + +# OS files +.DS_Store +Thumbs.db +*.swp +*.swo +*~ + +# Git +.git/ +.gitignore +.gitattributes + +# Documentation (exclude to reduce image size) +docs/ +*.md +LICENSE + +# Test files and results +tests/ +TestResults/ +coverage/ +*.http + +# Node modules (if any) +node_modules/ +npm-debug.log* + +# Docker files (don't include in context) +Dockerfile* +docker-compose*.yml +.dockerignore + +# Scripts (exclude deployment scripts) +scripts/ +*.sh +*.ps1 + +# Environment files (may contain secrets) +.env* +appsettings.Development.json +appsettings.Local.json + +# Data and cache directories +data/ +cache/ + +# Monitoring configuration +monitoring/ +nginx/ \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Authorization/Filters/BusinessLogicFilter.cs b/samples/DynamicToolFiltering/Authorization/Filters/BusinessLogicFilter.cs new file mode 100644 index 00000000..60a402bb --- /dev/null +++ b/samples/DynamicToolFiltering/Authorization/Filters/BusinessLogicFilter.cs @@ -0,0 +1,311 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server.Authorization; +using DynamicToolFiltering.Configuration; +using DynamicToolFiltering.Services; +using Microsoft.Extensions.Options; +using System.Security.Claims; +using System.Text.RegularExpressions; + +namespace DynamicToolFiltering.Authorization.Filters; + +/// +/// Business logic filter that implements complex business rules including feature flags, +/// quota management, and environment-based restrictions. +/// +public class BusinessLogicFilter : IToolFilter +{ + private readonly BusinessLogicFilteringOptions _options; + private readonly IFeatureFlagService _featureFlagService; + private readonly IQuotaService _quotaService; + private readonly IWebHostEnvironment _environment; + private readonly ILogger _logger; + + public BusinessLogicFilter( + IOptions options, + IFeatureFlagService featureFlagService, + IQuotaService quotaService, + IWebHostEnvironment environment, + ILogger logger) + { + _options = options.Value.BusinessLogic; + _featureFlagService = featureFlagService; + _quotaService = quotaService; + _environment = environment; + _logger = logger; + } + + public int Priority => _options.Priority; + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + return true; + } + + var canAccess = await CanAccessToolAsync(tool.Name, context, cancellationToken); + + _logger.LogDebug("Tool inclusion check for {ToolName}: CanAccess: {CanAccess}", tool.Name, canAccess); + + return canAccess; + } + + public async Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + return AuthorizationResult.Allow("Business logic filtering disabled"); + } + + // Check environment restrictions first + var environmentCheck = CheckEnvironmentRestrictions(toolName); + if (!environmentCheck.IsAuthorized) + { + return environmentCheck; + } + + // Check feature flags + var featureFlagCheck = await CheckFeatureFlagsAsync(toolName, context, cancellationToken); + if (!featureFlagCheck.IsAuthorized) + { + return featureFlagCheck; + } + + // Check quota limits + var quotaCheck = await CheckQuotaLimitsAsync(toolName, context, cancellationToken); + if (!quotaCheck.IsAuthorized) + { + return quotaCheck; + } + + _logger.LogDebug("Tool execution authorized by business logic filter: {ToolName}", toolName); + return AuthorizationResult.Allow($"Tool '{toolName}' passes all business logic checks"); + } + + private async Task CanAccessToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken) + { + // Check environment restrictions + if (!CheckEnvironmentRestrictions(toolName).IsAuthorized) + { + return false; + } + + // Check feature flags + if (_options.FeatureFlags.Enabled) + { + var featureFlag = GetFeatureFlagForTool(toolName); + if (featureFlag != null) + { + var userId = GetUserId(context); + var isEnabled = await _featureFlagService.IsEnabledAsync(featureFlag, userId, cancellationToken); + if (!isEnabled) + { + return false; + } + } + } + + // Check quota availability (basic check for visibility) + if (_options.QuotaManagement.Enabled) + { + var userId = GetUserId(context); + var userRole = GetUserRole(context); + var hasQuota = await _quotaService.HasAvailableQuotaAsync(userId, userRole, toolName, cancellationToken); + if (!hasQuota) + { + return false; + } + } + + return true; + } + + private AuthorizationResult CheckEnvironmentRestrictions(string toolName) + { + if (!_options.EnvironmentRestrictions.Enabled) + { + return AuthorizationResult.Allow("Environment restrictions disabled"); + } + + var environmentName = _environment.EnvironmentName; + + // Check production restrictions + if (string.Equals(environmentName, "Production", StringComparison.OrdinalIgnoreCase)) + { + if (IsToolMatched(toolName, _options.EnvironmentRestrictions.ProductionRestrictedTools)) + { + var reason = $"Tool '{toolName}' is restricted in production environment"; + + _logger.LogWarning("Tool execution denied in production: {ToolName}", toolName); + + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "Environment", + ("realm", "mcp-api"), + ("environment", environmentName), + ("restriction", "production_restricted")); + + return AuthorizationResult.DenyWithChallenge(reason, challenge); + } + } + + // Check development-only tools + if (!string.Equals(environmentName, "Development", StringComparison.OrdinalIgnoreCase)) + { + if (IsToolMatched(toolName, _options.EnvironmentRestrictions.DevelopmentOnlyTools)) + { + var reason = $"Tool '{toolName}' is only available in development environment"; + + _logger.LogWarning("Tool execution denied - development only: {ToolName}", toolName); + + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "Environment", + ("realm", "mcp-api"), + ("environment", environmentName), + ("restriction", "development_only")); + + return AuthorizationResult.DenyWithChallenge(reason, challenge); + } + } + + return AuthorizationResult.Allow("Environment restrictions passed"); + } + + private async Task CheckFeatureFlagsAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken) + { + if (!_options.FeatureFlags.Enabled) + { + return AuthorizationResult.Allow("Feature flags disabled"); + } + + var featureFlag = GetFeatureFlagForTool(toolName); + if (featureFlag == null) + { + return AuthorizationResult.Allow("No feature flag required"); + } + + var userId = GetUserId(context); + var isEnabled = await _featureFlagService.IsEnabledAsync(featureFlag, userId, cancellationToken); + + if (!isEnabled) + { + var reason = $"Tool '{toolName}' is disabled by feature flag '{featureFlag}'"; + + _logger.LogWarning("Tool execution denied by feature flag: {ToolName}, Flag: {FeatureFlag}", toolName, featureFlag); + + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "FeatureFlag", + ("realm", "mcp-api"), + ("feature_flag", featureFlag), + ("tool_name", toolName)); + + return AuthorizationResult.DenyWithChallenge(reason, challenge); + } + + return AuthorizationResult.Allow($"Feature flag '{featureFlag}' enabled"); + } + + private async Task CheckQuotaLimitsAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken) + { + if (!_options.QuotaManagement.Enabled) + { + return AuthorizationResult.Allow("Quota management disabled"); + } + + var userId = GetUserId(context); + var userRole = GetUserRole(context); + + // Check if user has available quota + var hasQuota = await _quotaService.HasAvailableQuotaAsync(userId, userRole, toolName, cancellationToken); + if (!hasQuota) + { + var currentUsage = await _quotaService.GetCurrentUsageAsync(userId, cancellationToken); + var quotaLimit = await _quotaService.GetQuotaLimitAsync(userId, userRole, cancellationToken); + + var reason = $"Quota exceeded for tool '{toolName}'. Usage: {currentUsage}/{quotaLimit}"; + + _logger.LogWarning("Tool execution denied - quota exceeded: {ToolName}, User: {UserId}, Usage: {Usage}/{Limit}", + toolName, userId, currentUsage, quotaLimit); + + var resetDate = await _quotaService.GetQuotaResetDateAsync(userId, cancellationToken); + + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "Quota", + ("realm", "mcp-api"), + ("current_usage", currentUsage.ToString()), + ("quota_limit", quotaLimit.ToString()), + ("reset_date", resetDate.ToString("O")), + ("tool_name", toolName)); + + return AuthorizationResult.DenyWithChallenge(reason, challenge); + } + + // Consume quota for this operation + var quotaCost = GetQuotaCost(toolName); + await _quotaService.ConsumeQuotaAsync(userId, toolName, quotaCost, cancellationToken); + + var remainingQuota = await _quotaService.GetRemainingQuotaAsync(userId, userRole, cancellationToken); + + _logger.LogDebug("Quota consumed for tool: {ToolName}, User: {UserId}, Cost: {Cost}, Remaining: {Remaining}", + toolName, userId, quotaCost, remainingQuota); + + return AuthorizationResult.Allow($"Quota available. Cost: {quotaCost}, Remaining: {remainingQuota}"); + } + + private string GetUserId(ToolAuthorizationContext context) + { + return context.User?.FindFirst(ClaimTypes.NameIdentifier)?.Value + ?? context.User?.FindFirst("sub")?.Value + ?? context.User?.FindFirst("user_id")?.Value + ?? "anonymous"; + } + + private string GetUserRole(ToolAuthorizationContext context) + { + return context.User?.FindFirst(ClaimTypes.Role)?.Value + ?? context.User?.FindFirst("role")?.Value + ?? (context.User?.Identity?.IsAuthenticated == true ? "user" : "guest"); + } + + private string? GetFeatureFlagForTool(string toolName) + { + foreach (var mapping in _options.FeatureFlags.ToolFeatureMapping) + { + if (IsPatternMatch(mapping.Key, toolName)) + { + return mapping.Value; + } + } + + return null; + } + + private int GetQuotaCost(string toolName) + { + foreach (var mapping in _options.QuotaManagement.ToolQuotaCosts) + { + if (IsPatternMatch(mapping.Key, toolName)) + { + return mapping.Value; + } + } + + return 1; // Default cost + } + + private bool IsToolMatched(string toolName, string[] patterns) + { + return patterns.Any(pattern => IsPatternMatch(pattern, toolName)); + } + + private static bool IsPatternMatch(string pattern, string toolName) + { + if (pattern == "*") + { + return true; + } + + // Convert glob pattern to regex + var regexPattern = "^" + Regex.Escape(pattern).Replace("\\*", ".*") + "$"; + return Regex.IsMatch(toolName, regexPattern, RegexOptions.IgnoreCase); + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Authorization/Filters/RateLimitingToolFilter.cs b/samples/DynamicToolFiltering/Authorization/Filters/RateLimitingToolFilter.cs new file mode 100644 index 00000000..90fd4861 --- /dev/null +++ b/samples/DynamicToolFiltering/Authorization/Filters/RateLimitingToolFilter.cs @@ -0,0 +1,177 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server.Authorization; +using DynamicToolFiltering.Configuration; +using DynamicToolFiltering.Services; +using Microsoft.Extensions.Options; +using System.Security.Claims; +using System.Text.RegularExpressions; + +namespace DynamicToolFiltering.Authorization.Filters; + +/// +/// Rate limiting tool filter that implements quota and rate limiting functionality. +/// Supports both role-based and tool-specific rate limits with sliding or fixed windows. +/// +public class RateLimitingToolFilter : IToolFilter +{ + private readonly RateLimitingOptions _options; + private readonly IRateLimitingService _rateLimitingService; + private readonly ILogger _logger; + + public RateLimitingToolFilter( + IOptions options, + IRateLimitingService rateLimitingService, + ILogger logger) + { + _options = options.Value.RateLimiting; + _rateLimitingService = rateLimitingService; + _logger = logger; + } + + public int Priority => _options.Priority; + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // Rate limiting doesn't affect tool visibility, only execution + return Task.FromResult(true); + } + + public async Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + return AuthorizationResult.Allow("Rate limiting disabled"); + } + + var userId = GetUserId(context); + var userRole = GetUserRole(context); + + // Get applicable rate limit for this user/tool combination + var rateLimit = GetApplicableRateLimit(toolName, userRole); + + if (rateLimit == -1) + { + // Unlimited access + _logger.LogDebug("Tool execution authorized (unlimited): {ToolName} for user {UserId}", toolName, userId); + return AuthorizationResult.Allow($"User has unlimited access to tool '{toolName}'"); + } + + // Check current usage + var windowStart = GetWindowStart(); + var currentUsage = await _rateLimitingService.GetUsageCountAsync(userId, toolName, windowStart, cancellationToken); + + if (currentUsage >= rateLimit) + { + var reason = $"Rate limit exceeded for tool '{toolName}'. Limit: {rateLimit} requests per {_options.WindowMinutes} minutes. Current usage: {currentUsage}"; + + _logger.LogWarning("Rate limit exceeded: {ToolName} for user {UserId}. Limit: {Limit}, Current: {Current}", + toolName, userId, rateLimit, currentUsage); + + // Calculate reset time + var resetTime = windowStart.AddMinutes(_options.WindowMinutes); + + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "RateLimit", + ("realm", "mcp-api"), + ("limit", rateLimit.ToString()), + ("remaining", Math.Max(0, rateLimit - currentUsage).ToString()), + ("reset_time", resetTime.ToString("O")), + ("window_minutes", _options.WindowMinutes.ToString())); + + return AuthorizationResult.DenyWithChallenge(reason, challenge); + } + + // Record the usage + await _rateLimitingService.RecordUsageAsync(userId, toolName, DateTime.UtcNow, cancellationToken); + + var remaining = rateLimit - currentUsage - 1; // -1 for the current request + + _logger.LogDebug("Tool execution authorized: {ToolName} for user {UserId}. Remaining: {Remaining}/{Limit}", + toolName, userId, remaining, rateLimit); + + return AuthorizationResult.Allow($"Tool '{toolName}' execution authorized. Remaining: {remaining}/{rateLimit}"); + } + + private string GetUserId(ToolAuthorizationContext context) + { + // Try to get user ID from claims + var userId = context.User?.FindFirst(ClaimTypes.NameIdentifier)?.Value + ?? context.User?.FindFirst("sub")?.Value + ?? context.User?.FindFirst("user_id")?.Value; + + if (!string.IsNullOrEmpty(userId)) + { + return userId; + } + + // For anonymous users, use a combination of IP and user agent as identifier + var clientInfo = context.Session?.ClientInfo?.Name ?? "unknown"; + return $"anonymous_{clientInfo.GetHashCode():X8}"; + } + + private string GetUserRole(ToolAuthorizationContext context) + { + var role = context.User?.FindFirst(ClaimTypes.Role)?.Value + ?? context.User?.FindFirst("role")?.Value; + + if (!string.IsNullOrEmpty(role)) + { + return role; + } + + // Default role based on authentication status + return context.User?.Identity?.IsAuthenticated == true ? "user" : "guest"; + } + + private int GetApplicableRateLimit(string toolName, string userRole) + { + // Check for tool-specific limits first (these override role limits) + foreach (var toolLimit in _options.ToolLimits) + { + if (IsPatternMatch(toolLimit.Key, toolName)) + { + return toolLimit.Value; + } + } + + // Fall back to role-based limits + if (_options.RoleLimits.TryGetValue(userRole, out var roleLimit)) + { + return roleLimit; + } + + // Default to guest limits if role not found + return _options.RoleLimits.TryGetValue("guest", out var guestLimit) ? guestLimit : 10; + } + + private DateTime GetWindowStart() + { + var now = DateTime.UtcNow; + + if (_options.UseSlidingWindow) + { + // Sliding window: go back WindowMinutes from now + return now.AddMinutes(-_options.WindowMinutes); + } + else + { + // Fixed window: align to window boundaries + var windowMinutes = _options.WindowMinutes; + var minutesSinceEpoch = (long)(now - DateTime.UnixEpoch).TotalMinutes; + var windowStart = minutesSinceEpoch - (minutesSinceEpoch % windowMinutes); + return DateTime.UnixEpoch.AddMinutes(windowStart); + } + } + + private static bool IsPatternMatch(string pattern, string toolName) + { + if (pattern == "*") + { + return true; + } + + // Convert glob pattern to regex + var regexPattern = "^" + Regex.Escape(pattern).Replace("\\*", ".*") + "$"; + return Regex.IsMatch(toolName, regexPattern, RegexOptions.IgnoreCase); + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Authorization/Filters/RoleBasedToolFilter.cs b/samples/DynamicToolFiltering/Authorization/Filters/RoleBasedToolFilter.cs new file mode 100644 index 00000000..019afcf2 --- /dev/null +++ b/samples/DynamicToolFiltering/Authorization/Filters/RoleBasedToolFilter.cs @@ -0,0 +1,254 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server.Authorization; +using DynamicToolFiltering.Configuration; +using Microsoft.Extensions.Options; +using System.Security.Claims; +using System.Text.RegularExpressions; + +namespace DynamicToolFiltering.Authorization.Filters; + +/// +/// Role-based tool filter that restricts access based on user roles. +/// Supports hierarchical roles and pattern-based tool matching. +/// +/// ARCHITECTURAL DECISION RECORD (ADR-001): +/// ======================================== +/// Decision: Implement hierarchical role-based access control with pattern matching +/// +/// Context: +/// - Need to control tool access based on user roles (guest < user < premium < admin < super_admin) +/// - Different tools require different permission levels +/// - Must support both exact tool name matching and pattern-based matching (e.g., "admin_*") +/// - Should be configurable and extensible for new roles/tools +/// +/// Decision Drivers: +/// 1. Security: Principle of least privilege - users should only access tools appropriate for their role +/// 2. Scalability: Pattern matching reduces configuration overhead for large tool sets +/// 3. Flexibility: Hierarchical roles allow role inheritance (admin can use user tools) +/// 4. Performance: Role checking should be fast (Priority 100 - after rate limiting but before scope checking) +/// +/// Implementation Details: +/// - Uses Claims-based authentication to extract user roles +/// - Supports multiple roles per user (for flexibility) +/// - Pattern matching with wildcards (*, prefix matching) +/// - Configurable role hierarchy and tool mappings +/// - Detailed logging for audit and debugging +/// +/// Consequences: +/// + Simple to understand and configure +/// + Efficient for common use cases +/// + Follows standard RBAC patterns +/// - Requires careful role hierarchy design +/// - Pattern matching could become complex with many tools +/// +/// Alternatives Considered: +/// 1. Attribute-based access control (ABAC) - Too complex for initial implementation +/// 2. Simple boolean permissions - Not flexible enough for hierarchical access +/// 3. External authorization service - Adds complexity and latency +/// +public class RoleBasedToolFilter : IToolFilter +{ + private readonly RoleBasedFilteringOptions _options; + private readonly ILogger _logger; + + public RoleBasedToolFilter(IOptions options, ILogger logger) + { + _options = options.Value.RoleBased; + _logger = logger; + } + + /// + /// Filter execution priority. Lower numbers execute first. + /// Priority 100 places this after rate limiting (50) but before scope checking (150). + /// + /// DESIGN DECISION: Role-based filtering occurs early in the pipeline because: + /// 1. It's fast to execute (simple claim lookup) + /// 2. It can quickly filter out unauthorized tools + /// 3. It reduces load on downstream filters + /// + public int Priority => _options.Priority; + + /// + /// Determines if a tool should be visible to the user based on their roles. + /// This method implements the "fail-fast" principle - if a user doesn't have + /// the required role, the tool won't appear in their tool list. + /// + /// DESIGN DECISION: Tool visibility vs execution separation + /// - Visibility check is more permissive to allow discovery + /// - Execution check is more restrictive for security + /// - This provides better UX while maintaining security + /// + /// The tool to check for visibility + /// The authorization context containing user information + /// Cancellation token for async operations + /// True if the tool should be visible to the user + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // PERFORMANCE OPTIMIZATION: Early exit if filtering is disabled + // This avoids unnecessary processing when role-based filtering is turned off + if (!_options.Enabled) + { + return Task.FromResult(true); + } + + // SECURITY PRINCIPLE: Extract and validate user roles from claims + // Uses the standard Claims-based authentication model from ASP.NET Core + var userRoles = GetUserRoles(context); + var requiredRoles = GetRequiredRoles(tool.Name); + + // AUTHORIZATION LOGIC: Check if user has any of the required roles + // Uses hierarchical role checking - higher roles can access lower-level tools + var hasAccess = HasRequiredRole(userRoles, requiredRoles); + + // AUDIT LOGGING: Detailed logging for security monitoring and debugging + // Logs both successful access and denials for security analysis + _logger.LogDebug("Tool inclusion check for {ToolName}: User roles [{UserRoles}], Required roles [{RequiredRoles}], HasAccess: {HasAccess}", + tool.Name, string.Join(", ", userRoles), string.Join(", ", requiredRoles), hasAccess); + + return Task.FromResult(hasAccess); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + return Task.FromResult(AuthorizationResult.Allow("Role-based filtering disabled")); + } + + var userRoles = GetUserRoles(context); + var requiredRoles = GetRequiredRoles(toolName); + + if (HasRequiredRole(userRoles, requiredRoles)) + { + _logger.LogDebug("Tool execution authorized for {ToolName}: User has required role", toolName); + return Task.FromResult(AuthorizationResult.Allow($"User has required role for tool '{toolName}'")); + } + + var reason = $"Tool '{toolName}' requires role(s): {string.Join(" or ", requiredRoles)}. User has role(s): {string.Join(", ", userRoles)}"; + + _logger.LogWarning("Tool execution denied for {ToolName}: {Reason}", toolName, reason); + + // Create a role-based challenge + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "Role", + ("realm", "mcp-api"), + ("required_roles", string.Join(",", requiredRoles)), + ("user_roles", string.Join(",", userRoles))); + + return Task.FromResult(AuthorizationResult.DenyWithChallenge(reason, challenge)); + } + + private List GetUserRoles(ToolAuthorizationContext context) + { + var roles = new List(); + + // Try to get roles from claims principal + if (context.User?.Identity?.IsAuthenticated == true) + { + roles.AddRange(context.User.Claims + .Where(c => c.Type == _options.RoleClaimType || c.Type == ClaimTypes.Role) + .Select(c => c.Value)); + } + + // If no roles found and user is not authenticated, assign guest role + if (roles.Count == 0 && context.User?.Identity?.IsAuthenticated != true) + { + roles.Add("guest"); + } + + // If no roles found but user is authenticated, assign default user role + if (roles.Count == 0 && context.User?.Identity?.IsAuthenticated == true) + { + roles.Add("user"); + } + + return roles; + } + + private List GetRequiredRoles(string toolName) + { + foreach (var mapping in _options.ToolRoleMapping) + { + if (IsPatternMatch(mapping.Key, toolName)) + { + return mapping.Value.ToList(); + } + } + + // Default to requiring authentication (user role or higher) + return new List { "user" }; + } + + private bool HasRequiredRole(List userRoles, List requiredRoles) + { + if (requiredRoles.Count == 0) + { + return true; // No specific role required + } + + if (_options.UseHierarchicalRoles) + { + return HasHierarchicalRole(userRoles, requiredRoles); + } + else + { + return userRoles.Intersect(requiredRoles, StringComparer.OrdinalIgnoreCase).Any(); + } + } + + private bool HasHierarchicalRole(List userRoles, List requiredRoles) + { + // Get the highest privilege level for user roles + var userMaxLevel = GetMaxRoleLevel(userRoles); + + // Get the minimum required privilege level + var requiredMinLevel = GetMinRoleLevel(requiredRoles); + + // User must have equal or higher privilege level + return userMaxLevel <= requiredMinLevel; // Lower index = higher privilege + } + + private int GetMaxRoleLevel(List roles) + { + var minLevel = int.MaxValue; + + foreach (var role in roles) + { + var level = Array.IndexOf(_options.RoleHierarchy, role); + if (level >= 0 && level < minLevel) + { + minLevel = level; + } + } + + return minLevel == int.MaxValue ? _options.RoleHierarchy.Length : minLevel; + } + + private int GetMinRoleLevel(List requiredRoles) + { + var maxLevel = -1; + + foreach (var role in requiredRoles) + { + var level = Array.IndexOf(_options.RoleHierarchy, role); + if (level > maxLevel) + { + maxLevel = level; + } + } + + return maxLevel == -1 ? _options.RoleHierarchy.Length - 1 : maxLevel; + } + + private static bool IsPatternMatch(string pattern, string toolName) + { + if (pattern == "*") + { + return true; + } + + // Convert glob pattern to regex + var regexPattern = "^" + Regex.Escape(pattern).Replace("\\*", ".*") + "$"; + return Regex.IsMatch(toolName, regexPattern, RegexOptions.IgnoreCase); + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Authorization/Filters/ScopeBasedToolFilter.cs b/samples/DynamicToolFiltering/Authorization/Filters/ScopeBasedToolFilter.cs new file mode 100644 index 00000000..24846773 --- /dev/null +++ b/samples/DynamicToolFiltering/Authorization/Filters/ScopeBasedToolFilter.cs @@ -0,0 +1,185 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server.Authorization; +using DynamicToolFiltering.Configuration; +using Microsoft.Extensions.Options; +using System.Security.Claims; +using System.Text.RegularExpressions; + +namespace DynamicToolFiltering.Authorization.Filters; + +/// +/// Scope-based tool filter that implements OAuth2-style scope checking. +/// Restricts tool access based on granted scopes in JWT tokens or claims. +/// +public class ScopeBasedToolFilter : IToolFilter +{ + private readonly ScopeBasedFilteringOptions _options; + private readonly ILogger _logger; + + public ScopeBasedToolFilter(IOptions options, ILogger logger) + { + _options = options.Value.ScopeBased; + _logger = logger; + } + + public int Priority => _options.Priority; + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + return Task.FromResult(true); + } + + var userScopes = GetUserScopes(context); + var requiredScopes = GetRequiredScopes(tool.Name); + + var hasAccess = HasRequiredScope(userScopes, requiredScopes); + + _logger.LogDebug("Tool inclusion check for {ToolName}: User scopes [{UserScopes}], Required scopes [{RequiredScopes}], HasAccess: {HasAccess}", + tool.Name, string.Join(", ", userScopes), string.Join(", ", requiredScopes), hasAccess); + + return Task.FromResult(hasAccess); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + return Task.FromResult(AuthorizationResult.Allow("Scope-based filtering disabled")); + } + + var userScopes = GetUserScopes(context); + var requiredScopes = GetRequiredScopes(toolName); + + if (HasRequiredScope(userScopes, requiredScopes)) + { + _logger.LogDebug("Tool execution authorized for {ToolName}: User has required scope", toolName); + return Task.FromResult(AuthorizationResult.Allow($"User has required scope for tool '{toolName}'")); + } + + var reason = $"Tool '{toolName}' requires scope(s): {string.Join(" or ", requiredScopes)}"; + + _logger.LogWarning("Tool execution denied for {ToolName}: Insufficient scope. User scopes: [{UserScopes}], Required: [{RequiredScopes}]", + toolName, string.Join(", ", userScopes), string.Join(", ", requiredScopes)); + + // Determine the most appropriate scope to request + var suggestedScope = requiredScopes.FirstOrDefault() ?? "basic:tools"; + + // Create OAuth2-style Bearer challenge with insufficient_scope error + return Task.FromResult(AuthorizationResult.DenyInsufficientScope(suggestedScope, "mcp-api")); + } + + private List GetUserScopes(ToolAuthorizationContext context) + { + var scopes = new List(); + + // Try to get scopes from claims principal + if (context.User?.Identity?.IsAuthenticated == true) + { + // Check for scope claim (OAuth2 standard) + var scopeClaims = context.User.Claims + .Where(c => c.Type == _options.ScopeClaimType) + .ToList(); + + foreach (var scopeClaim in scopeClaims) + { + // OAuth2 scopes can be space-separated in a single claim + var scopeValues = scopeClaim.Value.Split(' ', StringSplitOptions.RemoveEmptyEntries); + scopes.AddRange(scopeValues); + } + + // Also check for individual scope claims (some implementations use this pattern) + scopes.AddRange(context.User.Claims + .Where(c => c.Type.StartsWith("scope:", StringComparison.OrdinalIgnoreCase)) + .Select(c => c.Type["scope:".Length..])); + } + + // If no scopes found and user is not authenticated, assign basic public scope + if (scopes.Count == 0 && context.User?.Identity?.IsAuthenticated != true) + { + scopes.Add("basic:tools"); + } + + // If no scopes found but user is authenticated, assign basic authenticated scope + if (scopes.Count == 0 && context.User?.Identity?.IsAuthenticated == true) + { + scopes.Add("user:tools"); + } + + return scopes.Distinct(StringComparer.OrdinalIgnoreCase).ToList(); + } + + private List GetRequiredScopes(string toolName) + { + foreach (var mapping in _options.ToolScopeMapping) + { + if (IsPatternMatch(mapping.Key, toolName)) + { + return mapping.Value.ToList(); + } + } + + // Default to requiring basic tools scope + return new List { "basic:tools" }; + } + + private bool HasRequiredScope(List userScopes, List requiredScopes) + { + if (requiredScopes.Count == 0) + { + return true; // No specific scope required + } + + // User needs at least one of the required scopes + return requiredScopes.Any(requiredScope => + userScopes.Any(userScope => + IsScopeMatch(userScope, requiredScope))); + } + + private static bool IsScopeMatch(string userScope, string requiredScope) + { + // Exact match + if (string.Equals(userScope, requiredScope, StringComparison.OrdinalIgnoreCase)) + { + return true; + } + + // Hierarchical scope matching (e.g., "admin:tools" implies "user:tools") + // This implements a simple hierarchical model where broader scopes include narrower ones + var scopeHierarchy = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { "admin:tools", new[] { "admin:tools", "premium:tools", "user:tools", "read:tools", "basic:tools" } }, + { "premium:tools", new[] { "premium:tools", "user:tools", "read:tools", "basic:tools" } }, + { "user:tools", new[] { "user:tools", "read:tools", "basic:tools" } }, + { "read:tools", new[] { "read:tools", "basic:tools" } }, + { "basic:tools", new[] { "basic:tools" } } + }; + + if (scopeHierarchy.TryGetValue(userScope, out var impliedScopes)) + { + return impliedScopes.Contains(requiredScope, StringComparer.OrdinalIgnoreCase); + } + + // Wildcard matching for custom scopes (e.g., "tools:*" matches "tools:read") + if (userScope.EndsWith(":*")) + { + var scopePrefix = userScope[..^1]; // Remove the "*" + return requiredScope.StartsWith(scopePrefix, StringComparison.OrdinalIgnoreCase); + } + + return false; + } + + private static bool IsPatternMatch(string pattern, string toolName) + { + if (pattern == "*") + { + return true; + } + + // Convert glob pattern to regex + var regexPattern = "^" + Regex.Escape(pattern).Replace("\\*", ".*") + "$"; + return Regex.IsMatch(toolName, regexPattern, RegexOptions.IgnoreCase); + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Authorization/Filters/TenantIsolationFilter.cs b/samples/DynamicToolFiltering/Authorization/Filters/TenantIsolationFilter.cs new file mode 100644 index 00000000..70774c8e --- /dev/null +++ b/samples/DynamicToolFiltering/Authorization/Filters/TenantIsolationFilter.cs @@ -0,0 +1,201 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server.Authorization; +using DynamicToolFiltering.Configuration; +using Microsoft.Extensions.Options; +using System.Security.Claims; +using System.Text.RegularExpressions; + +namespace DynamicToolFiltering.Authorization.Filters; + +/// +/// Tenant isolation filter that provides multi-tenant tool access control. +/// Restricts tool access based on tenant membership and tenant-specific configurations. +/// +public class TenantIsolationFilter : IToolFilter +{ + private readonly TenantIsolationOptions _options; + private readonly ILogger _logger; + + public TenantIsolationFilter(IOptions options, ILogger logger) + { + _options = options.Value.TenantIsolation; + _logger = logger; + } + + public int Priority => _options.Priority; + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + return Task.FromResult(true); + } + + var tenantId = GetTenantId(context); + var canAccess = CanAccessTool(tool.Name, tenantId); + + _logger.LogDebug("Tool inclusion check for {ToolName}: Tenant {TenantId}, CanAccess: {CanAccess}", + tool.Name, tenantId ?? "none", canAccess); + + return Task.FromResult(canAccess); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + return Task.FromResult(AuthorizationResult.Allow("Tenant isolation disabled")); + } + + var tenantId = GetTenantId(context); + + if (string.IsNullOrEmpty(tenantId)) + { + var reason = "Tenant ID is required for tool access"; + + _logger.LogWarning("Tool execution denied: {ToolName} - No tenant ID provided", toolName); + + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "Tenant", + ("realm", "mcp-api"), + ("tenant_header", _options.TenantHeaderName), + ("tenant_claim", _options.TenantClaimType)); + + return Task.FromResult(AuthorizationResult.DenyWithChallenge(reason, challenge)); + } + + if (!_options.TenantConfigurations.TryGetValue(tenantId, out var tenantConfig)) + { + var reason = $"Unknown tenant: {tenantId}"; + + _logger.LogWarning("Tool execution denied: {ToolName} - Unknown tenant {TenantId}", toolName, tenantId); + + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "Tenant", + ("realm", "mcp-api"), + ("error", "unknown_tenant"), + ("tenant_id", tenantId)); + + return Task.FromResult(AuthorizationResult.DenyWithChallenge(reason, challenge)); + } + + if (!tenantConfig.IsActive) + { + var reason = $"Tenant {tenantId} is currently inactive"; + + _logger.LogWarning("Tool execution denied: {ToolName} - Inactive tenant {TenantId}", toolName, tenantId); + + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "Tenant", + ("realm", "mcp-api"), + ("error", "tenant_inactive"), + ("tenant_id", tenantId)); + + return Task.FromResult(AuthorizationResult.DenyWithChallenge(reason, challenge)); + } + + // Check if tool is explicitly denied for this tenant + if (IsToolDenied(toolName, tenantConfig.DeniedTools)) + { + var reason = $"Tool '{toolName}' is not available for tenant {tenantId}"; + + _logger.LogWarning("Tool execution denied: {ToolName} - Explicitly denied for tenant {TenantId}", toolName, tenantId); + + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "Tenant", + ("realm", "mcp-api"), + ("error", "tool_denied"), + ("tenant_id", tenantId), + ("tool_name", toolName)); + + return Task.FromResult(AuthorizationResult.DenyWithChallenge(reason, challenge)); + } + + // Check if tool is allowed for this tenant + if (!IsToolAllowed(toolName, tenantConfig.AllowedTools)) + { + var reason = $"Tool '{toolName}' is not in the allowed tools list for tenant {tenantId}"; + + _logger.LogWarning("Tool execution denied: {ToolName} - Not in allowed list for tenant {TenantId}", toolName, tenantId); + + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "Tenant", + ("realm", "mcp-api"), + ("error", "tool_not_allowed"), + ("tenant_id", tenantId), + ("tool_name", toolName)); + + return Task.FromResult(AuthorizationResult.DenyWithChallenge(reason, challenge)); + } + + _logger.LogDebug("Tool execution authorized for tenant: {ToolName}, Tenant: {TenantId}", toolName, tenantId); + + return Task.FromResult(AuthorizationResult.Allow($"Tool '{toolName}' is available for tenant {tenantId}")); + } + + private string? GetTenantId(ToolAuthorizationContext context) + { + // Try to get tenant ID from claims first + var tenantId = context.User?.FindFirst(_options.TenantClaimType)?.Value; + + if (!string.IsNullOrEmpty(tenantId)) + { + return tenantId; + } + + // Try to get tenant ID from HTTP headers (if available in context) + // Note: This would require extending ToolAuthorizationContext to include HTTP context + // For now, we'll rely on claims-based approach + + return null; + } + + private bool CanAccessTool(string toolName, string? tenantId) + { + if (string.IsNullOrEmpty(tenantId)) + { + return false; // No tenant, no access + } + + if (!_options.TenantConfigurations.TryGetValue(tenantId, out var tenantConfig)) + { + return false; // Unknown tenant + } + + if (!tenantConfig.IsActive) + { + return false; // Inactive tenant + } + + // Check denied tools first + if (IsToolDenied(toolName, tenantConfig.DeniedTools)) + { + return false; + } + + // Check allowed tools + return IsToolAllowed(toolName, tenantConfig.AllowedTools); + } + + private bool IsToolAllowed(string toolName, string[] allowedPatterns) + { + return allowedPatterns.Any(pattern => IsPatternMatch(pattern, toolName)); + } + + private bool IsToolDenied(string toolName, string[] deniedPatterns) + { + return deniedPatterns.Any(pattern => IsPatternMatch(pattern, toolName)); + } + + private static bool IsPatternMatch(string pattern, string toolName) + { + if (pattern == "*") + { + return true; + } + + // Convert glob pattern to regex + var regexPattern = "^" + Regex.Escape(pattern).Replace("\\*", ".*") + "$"; + return Regex.IsMatch(toolName, regexPattern, RegexOptions.IgnoreCase); + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Authorization/Filters/TimeBasedToolFilter.cs b/samples/DynamicToolFiltering/Authorization/Filters/TimeBasedToolFilter.cs new file mode 100644 index 00000000..aaf4b0cd --- /dev/null +++ b/samples/DynamicToolFiltering/Authorization/Filters/TimeBasedToolFilter.cs @@ -0,0 +1,196 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server.Authorization; +using DynamicToolFiltering.Configuration; +using Microsoft.Extensions.Options; +using System.Globalization; +using System.Text.RegularExpressions; + +namespace DynamicToolFiltering.Authorization.Filters; + +/// +/// Time-based tool filter that restricts access based on business hours and maintenance windows. +/// +public class TimeBasedToolFilter : IToolFilter +{ + private readonly TimeBasedFilteringOptions _options; + private readonly ILogger _logger; + private readonly TimeZoneInfo _timeZone; + + public TimeBasedToolFilter(IOptions options, ILogger logger) + { + _options = options.Value.TimeBased; + _logger = logger; + + try + { + _timeZone = TimeZoneInfo.FindSystemTimeZoneById(_options.TimeZone); + } + catch (TimeZoneNotFoundException) + { + _logger.LogWarning("Time zone '{TimeZone}' not found, falling back to UTC", _options.TimeZone); + _timeZone = TimeZoneInfo.Utc; + } + } + + public int Priority => _options.Priority; + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + return Task.FromResult(true); + } + + var canAccess = CanAccessTool(tool.Name); + + _logger.LogDebug("Tool inclusion check for {ToolName}: CanAccess: {CanAccess}", tool.Name, canAccess); + + return Task.FromResult(canAccess); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + return Task.FromResult(AuthorizationResult.Allow("Time-based filtering disabled")); + } + + var currentTime = TimeZoneInfo.ConvertTimeFromUtc(DateTime.UtcNow, _timeZone); + + // Check maintenance windows first (highest priority) + foreach (var maintenanceWindow in _options.MaintenanceWindows) + { + if (maintenanceWindow.IsActive && IsInMaintenanceWindow(currentTime, maintenanceWindow)) + { + if (IsToolBlocked(toolName, maintenanceWindow.BlockedTools)) + { + var reason = $"Tool '{toolName}' is blocked during maintenance window: {maintenanceWindow.Description}"; + + _logger.LogWarning("Tool execution denied during maintenance: {ToolName}, Window: {Description}", + toolName, maintenanceWindow.Description); + + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "Maintenance", + ("realm", "mcp-api"), + ("maintenance_start", maintenanceWindow.StartTime.ToString("O")), + ("maintenance_end", maintenanceWindow.EndTime.ToString("O")), + ("description", maintenanceWindow.Description)); + + return Task.FromResult(AuthorizationResult.DenyWithChallenge(reason, challenge)); + } + } + } + + // Check business hours restrictions + if (_options.BusinessHours.Enabled && IsToolRestrictedToBusinessHours(toolName)) + { + if (!IsWithinBusinessHours(currentTime)) + { + var reason = $"Tool '{toolName}' is only available during business hours: {_options.BusinessHours.StartTime}-{_options.BusinessHours.EndTime} on {string.Join(", ", _options.BusinessHours.BusinessDays)}"; + + _logger.LogWarning("Tool execution denied outside business hours: {ToolName}", toolName); + + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "BusinessHours", + ("realm", "mcp-api"), + ("business_start", _options.BusinessHours.StartTime), + ("business_end", _options.BusinessHours.EndTime), + ("business_days", string.Join(",", _options.BusinessHours.BusinessDays)), + ("current_time", currentTime.ToString("O")), + ("timezone", _timeZone.Id)); + + return Task.FromResult(AuthorizationResult.DenyWithChallenge(reason, challenge)); + } + } + + _logger.LogDebug("Tool execution authorized by time-based filter: {ToolName}", toolName); + return Task.FromResult(AuthorizationResult.Allow($"Tool '{toolName}' is available at current time")); + } + + private bool CanAccessTool(string toolName) + { + var currentTime = TimeZoneInfo.ConvertTimeFromUtc(DateTime.UtcNow, _timeZone); + + // Check maintenance windows + foreach (var maintenanceWindow in _options.MaintenanceWindows) + { + if (maintenanceWindow.IsActive && + IsInMaintenanceWindow(currentTime, maintenanceWindow) && + IsToolBlocked(toolName, maintenanceWindow.BlockedTools)) + { + return false; + } + } + + // Check business hours + if (_options.BusinessHours.Enabled && + IsToolRestrictedToBusinessHours(toolName) && + !IsWithinBusinessHours(currentTime)) + { + return false; + } + + return true; + } + + private bool IsInMaintenanceWindow(DateTime currentTime, MaintenanceWindowOptions maintenanceWindow) + { + var windowStart = TimeZoneInfo.ConvertTimeFromUtc(maintenanceWindow.StartTime, _timeZone); + var windowEnd = TimeZoneInfo.ConvertTimeFromUtc(maintenanceWindow.EndTime, _timeZone); + + return currentTime >= windowStart && currentTime <= windowEnd; + } + + private bool IsToolBlocked(string toolName, string[] blockedPatterns) + { + return blockedPatterns.Any(pattern => IsPatternMatch(pattern, toolName)); + } + + private bool IsToolRestrictedToBusinessHours(string toolName) + { + return _options.BusinessHours.RestrictedTools.Any(pattern => IsPatternMatch(pattern, toolName)); + } + + private bool IsWithinBusinessHours(DateTime currentTime) + { + // Check if current day is a business day + var currentDayName = currentTime.DayOfWeek.ToString(); + if (!_options.BusinessHours.BusinessDays.Contains(currentDayName, StringComparer.OrdinalIgnoreCase)) + { + return false; + } + + // Parse business hours + if (!TimeOnly.TryParse(_options.BusinessHours.StartTime, CultureInfo.InvariantCulture, out var startTime) || + !TimeOnly.TryParse(_options.BusinessHours.EndTime, CultureInfo.InvariantCulture, out var endTime)) + { + _logger.LogError("Invalid business hours format. Start: {StartTime}, End: {EndTime}", + _options.BusinessHours.StartTime, _options.BusinessHours.EndTime); + return false; + } + + var currentTimeOnly = TimeOnly.FromDateTime(currentTime); + + // Handle cases where end time is before start time (spans midnight) + if (endTime < startTime) + { + return currentTimeOnly >= startTime || currentTimeOnly <= endTime; + } + else + { + return currentTimeOnly >= startTime && currentTimeOnly <= endTime; + } + } + + private static bool IsPatternMatch(string pattern, string toolName) + { + if (pattern == "*") + { + return true; + } + + // Convert glob pattern to regex + var regexPattern = "^" + Regex.Escape(pattern).Replace("\\*", ".*") + "$"; + return Regex.IsMatch(toolName, regexPattern, RegexOptions.IgnoreCase); + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Configuration/FilteringOptions.cs b/samples/DynamicToolFiltering/Configuration/FilteringOptions.cs new file mode 100644 index 00000000..8ee7104f --- /dev/null +++ b/samples/DynamicToolFiltering/Configuration/FilteringOptions.cs @@ -0,0 +1,443 @@ +namespace DynamicToolFiltering.Configuration; + +/// +/// Configuration options for dynamic tool filtering system. +/// +public class FilteringOptions +{ + public const string SectionName = "Filtering"; + + /// + /// Gets or sets whether filtering is enabled globally. + /// + public bool Enabled { get; set; } = true; + + /// + /// Gets or sets the default behavior when no filters match (allow or deny). + /// + public string DefaultBehavior { get; set; } = "deny"; + + /// + /// Gets or sets role-based filtering configuration. + /// + public RoleBasedFilteringOptions RoleBased { get; set; } = new(); + + /// + /// Gets or sets time-based filtering configuration. + /// + public TimeBasedFilteringOptions TimeBased { get; set; } = new(); + + /// + /// Gets or sets scope-based filtering configuration. + /// + public ScopeBasedFilteringOptions ScopeBased { get; set; } = new(); + + /// + /// Gets or sets rate limiting configuration. + /// + public RateLimitingOptions RateLimiting { get; set; } = new(); + + /// + /// Gets or sets tenant isolation configuration. + /// + public TenantIsolationOptions TenantIsolation { get; set; } = new(); + + /// + /// Gets or sets business logic filtering configuration. + /// + public BusinessLogicFilteringOptions BusinessLogic { get; set; } = new(); +} + +/// +/// Configuration for role-based filtering. +/// +public class RoleBasedFilteringOptions +{ + /// + /// Gets or sets whether role-based filtering is enabled. + /// + public bool Enabled { get; set; } = true; + + /// + /// Gets or sets the priority of the role-based filter. + /// + public int Priority { get; set; } = 100; + + /// + /// Gets or sets the claim type that contains user roles. + /// + public string RoleClaimType { get; set; } = "role"; + + /// + /// Gets or sets the mapping of tool patterns to required roles. + /// + public Dictionary ToolRoleMapping { get; set; } = new() + { + { "admin_*", new[] { "admin", "super_admin" } }, + { "premium_*", new[] { "premium", "admin", "super_admin" } }, + { "*_user_*", new[] { "user", "premium", "admin", "super_admin" } }, + { "*", new[] { "guest", "user", "premium", "admin", "super_admin" } } + }; + + /// + /// Gets or sets whether to use hierarchical roles (admin inherits user permissions). + /// + public bool UseHierarchicalRoles { get; set; } = true; + + /// + /// Gets or sets the role hierarchy from highest to lowest privilege. + /// + public string[] RoleHierarchy { get; set; } = { "super_admin", "admin", "premium", "user", "guest" }; +} + +/// +/// Configuration for time-based filtering. +/// +public class TimeBasedFilteringOptions +{ + /// + /// Gets or sets whether time-based filtering is enabled. + /// + public bool Enabled { get; set; } = false; + + /// + /// Gets or sets the priority of the time-based filter. + /// + public int Priority { get; set; } = 200; + + /// + /// Gets or sets the timezone for time-based filtering. + /// + public string TimeZone { get; set; } = "UTC"; + + /// + /// Gets or sets business hours when certain tools are available. + /// + public BusinessHoursOptions BusinessHours { get; set; } = new(); + + /// + /// Gets or sets maintenance windows when tools are restricted. + /// + public MaintenanceWindowOptions[] MaintenanceWindows { get; set; } = Array.Empty(); +} + +/// +/// Configuration for business hours. +/// +public class BusinessHoursOptions +{ + /// + /// Gets or sets whether business hours restrictions are enabled. + /// + public bool Enabled { get; set; } = false; + + /// + /// Gets or sets the start time for business hours (24-hour format). + /// + public string StartTime { get; set; } = "09:00"; + + /// + /// Gets or sets the end time for business hours (24-hour format). + /// + public string EndTime { get; set; } = "17:00"; + + /// + /// Gets or sets the days of week for business hours. + /// + public string[] BusinessDays { get; set; } = { "Monday", "Tuesday", "Wednesday", "Thursday", "Friday" }; + + /// + /// Gets or sets tool patterns that are restricted to business hours. + /// + public string[] RestrictedTools { get; set; } = { "admin_*" }; +} + +/// +/// Configuration for maintenance windows. +/// +public class MaintenanceWindowOptions +{ + /// + /// Gets or sets the start time of the maintenance window. + /// + public DateTime StartTime { get; set; } + + /// + /// Gets or sets the end time of the maintenance window. + /// + public DateTime EndTime { get; set; } + + /// + /// Gets or sets tool patterns that are blocked during maintenance. + /// + public string[] BlockedTools { get; set; } = { "*" }; + + /// + /// Gets or sets whether this maintenance window is active. + /// + public bool IsActive { get; set; } = true; + + /// + /// Gets or sets the description of the maintenance window. + /// + public string Description { get; set; } = ""; +} + +/// +/// Configuration for scope-based filtering. +/// +public class ScopeBasedFilteringOptions +{ + /// + /// Gets or sets whether scope-based filtering is enabled. + /// + public bool Enabled { get; set; } = true; + + /// + /// Gets or sets the priority of the scope-based filter. + /// + public int Priority { get; set; } = 150; + + /// + /// Gets or sets the claim type that contains scopes. + /// + public string ScopeClaimType { get; set; } = "scope"; + + /// + /// Gets or sets the mapping of tool patterns to required scopes. + /// + public Dictionary ToolScopeMapping { get; set; } = new() + { + { "admin_*", new[] { "admin:tools" } }, + { "premium_*", new[] { "premium:tools" } }, + { "*_user_*", new[] { "user:tools" } }, + { "get_*", new[] { "read:tools" } }, + { "*", new[] { "basic:tools" } } + }; +} + +/// +/// Configuration for rate limiting. +/// +public class RateLimitingOptions +{ + /// + /// Gets or sets whether rate limiting is enabled. + /// + public bool Enabled { get; set; } = true; + + /// + /// Gets or sets the priority of the rate limiting filter. + /// + public int Priority { get; set; } = 50; + + /// + /// Gets or sets the time window for rate limiting in minutes. + /// + public int WindowMinutes { get; set; } = 60; + + /// + /// Gets or sets rate limits per user role. + /// + public Dictionary RoleLimits { get; set; } = new() + { + { "guest", 10 }, + { "user", 100 }, + { "premium", 500 }, + { "admin", 1000 }, + { "super_admin", -1 } // -1 means unlimited + }; + + /// + /// Gets or sets per-tool rate limits that override role limits. + /// + public Dictionary ToolLimits { get; set; } = new() + { + { "premium_performance_benchmark", 5 }, + { "admin_*", 50 } + }; + + /// + /// Gets or sets whether to use sliding window (true) or fixed window (false). + /// + public bool UseSlidingWindow { get; set; } = true; +} + +/// +/// Configuration for tenant isolation. +/// +public class TenantIsolationOptions +{ + /// + /// Gets or sets whether tenant isolation is enabled. + /// + public bool Enabled { get; set; } = false; + + /// + /// Gets or sets the priority of the tenant isolation filter. + /// + public int Priority { get; set; } = 75; + + /// + /// Gets or sets the claim type that contains tenant ID. + /// + public string TenantClaimType { get; set; } = "tenant_id"; + + /// + /// Gets or sets the header name for tenant ID (alternative to claims). + /// + public string TenantHeaderName { get; set; } = "X-Tenant-ID"; + + /// + /// Gets or sets tenant-specific tool access configuration. + /// + public Dictionary TenantConfigurations { get; set; } = new(); +} + +/// +/// Configuration for a specific tenant. +/// +public class TenantConfiguration +{ + /// + /// Gets or sets the tenant name. + /// + public string Name { get; set; } = ""; + + /// + /// Gets or sets whether this tenant is active. + /// + public bool IsActive { get; set; } = true; + + /// + /// Gets or sets tool patterns allowed for this tenant. + /// + public string[] AllowedTools { get; set; } = { "*" }; + + /// + /// Gets or sets tool patterns explicitly denied for this tenant. + /// + public string[] DeniedTools { get; set; } = Array.Empty(); + + /// + /// Gets or sets custom rate limits for this tenant. + /// + public Dictionary CustomRateLimits { get; set; } = new(); +} + +/// +/// Configuration for business logic filtering. +/// +public class BusinessLogicFilteringOptions +{ + /// + /// Gets or sets whether business logic filtering is enabled. + /// + public bool Enabled { get; set; } = true; + + /// + /// Gets or sets the priority of the business logic filter. + /// + public int Priority { get; set; } = 300; + + /// + /// Gets or sets feature flag configuration. + /// + public FeatureFlagOptions FeatureFlags { get; set; } = new(); + + /// + /// Gets or sets quota management configuration. + /// + public QuotaManagementOptions QuotaManagement { get; set; } = new(); + + /// + /// Gets or sets environment-based restrictions. + /// + public EnvironmentRestrictionOptions EnvironmentRestrictions { get; set; } = new(); +} + +/// +/// Configuration for feature flags. +/// +public class FeatureFlagOptions +{ + /// + /// Gets or sets whether feature flag filtering is enabled. + /// + public bool Enabled { get; set; } = false; + + /// + /// Gets or sets feature flag mappings for tools. + /// + public Dictionary ToolFeatureMapping { get; set; } = new() + { + { "premium_*", "premium_features" }, + { "admin_performance_*", "admin_performance_tools" } + }; + + /// + /// Gets or sets the default state for unknown feature flags. + /// + public bool DefaultFeatureFlagState { get; set; } = false; +} + +/// +/// Configuration for quota management. +/// +public class QuotaManagementOptions +{ + /// + /// Gets or sets whether quota management is enabled. + /// + public bool Enabled { get; set; } = false; + + /// + /// Gets or sets the quota period in days. + /// + public int QuotaPeriodDays { get; set; } = 30; + + /// + /// Gets or sets quota limits per user role. + /// + public Dictionary RoleQuotas { get; set; } = new() + { + { "user", 1000 }, + { "premium", 10000 }, + { "admin", -1 } // -1 means unlimited + }; + + /// + /// Gets or sets quota costs per tool pattern. + /// + public Dictionary ToolQuotaCosts { get; set; } = new() + { + { "premium_performance_benchmark", 10 }, + { "premium_*", 2 }, + { "*", 1 } + }; +} + +/// +/// Configuration for environment-based restrictions. +/// +public class EnvironmentRestrictionOptions +{ + /// + /// Gets or sets whether environment restrictions are enabled. + /// + public bool Enabled { get; set; } = true; + + /// + /// Gets or sets tool patterns restricted in production. + /// + public string[] ProductionRestrictedTools { get; set; } = + { + "admin_force_gc", + "admin_list_processes" + }; + + /// + /// Gets or sets tool patterns only available in development. + /// + public string[] DevelopmentOnlyTools { get; set; } = Array.Empty(); +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Dockerfile b/samples/DynamicToolFiltering/Dockerfile new file mode 100644 index 00000000..d86fc7c9 --- /dev/null +++ b/samples/DynamicToolFiltering/Dockerfile @@ -0,0 +1,66 @@ +# Dynamic Tool Filtering MCP Server - Docker Configuration +# Multi-stage build for optimized production image + +# Build stage +FROM mcr.microsoft.com/dotnet/sdk:9.0 AS build +WORKDIR /src + +# Copy project file and restore dependencies +# This layer is cached unless project file changes +COPY DynamicToolFiltering.csproj . +RUN dotnet restore + +# Copy source code and build +COPY . . +RUN dotnet build -c Release -o /app/build --no-restore + +# Publish stage +FROM build AS publish +RUN dotnet publish -c Release -o /app/publish --no-restore --no-build + +# Runtime stage - minimal ASP.NET Core runtime image +FROM mcr.microsoft.com/dotnet/aspnet:9.0 AS runtime + +# Install dependencies for health checks and debugging +RUN apt-get update && apt-get install -y \ + curl \ + jq \ + && rm -rf /var/lib/apt/lists/* + +# Create non-root user for security +RUN groupadd -r appgroup && useradd -r -g appgroup appuser + +# Set working directory +WORKDIR /app + +# Copy published application +COPY --from=publish /app/publish . + +# Create necessary directories and set permissions +RUN mkdir -p logs data && \ + chown -R appuser:appgroup /app + +# Switch to non-root user +USER appuser + +# Expose port +EXPOSE 8080 + +# Add metadata labels +LABEL org.opencontainers.image.title="Dynamic Tool Filtering MCP Server" \ + org.opencontainers.image.description="Advanced MCP server demonstrating tool filtering and authorization" \ + org.opencontainers.image.source="https://github.com/microsoft/mcp-csharp-sdk" \ + org.opencontainers.image.documentation="https://github.com/microsoft/mcp-csharp-sdk/tree/main/samples/DynamicToolFiltering" + +# Health check - verify the application is responding +HEALTHCHECK --interval=30s --timeout=10s --start-period=15s --retries=3 \ + CMD curl -f http://localhost:8080/health || exit 1 + +# Set environment variables for container +ENV ASPNETCORE_URLS=http://+:8080 \ + ASPNETCORE_ENVIRONMENT=Production \ + DOTNET_RUNNING_IN_CONTAINER=true \ + DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=true + +# Start application +ENTRYPOINT ["dotnet", "DynamicToolFiltering.dll"] \ No newline at end of file diff --git a/samples/DynamicToolFiltering/DynamicToolFiltering.csproj b/samples/DynamicToolFiltering/DynamicToolFiltering.csproj new file mode 100644 index 00000000..5eb656e2 --- /dev/null +++ b/samples/DynamicToolFiltering/DynamicToolFiltering.csproj @@ -0,0 +1,28 @@ + + + + net9.0 + enable + enable + true + DynamicToolFiltering-sample + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/samples/DynamicToolFiltering/INTEGRATION_EXAMPLES.md b/samples/DynamicToolFiltering/INTEGRATION_EXAMPLES.md new file mode 100644 index 00000000..04540985 --- /dev/null +++ b/samples/DynamicToolFiltering/INTEGRATION_EXAMPLES.md @@ -0,0 +1,1011 @@ +# Integration Examples for Dynamic Tool Filtering + +This document provides practical examples for integrating the Dynamic Tool Filtering system with external services and real-world scenarios. + +## Table of Contents + +1. [JWT Integration with Identity Providers](#jwt-integration-with-identity-providers) +2. [Redis Integration for Rate Limiting](#redis-integration-for-rate-limiting) +3. [Database Integration for Quotas](#database-integration-for-quotas) +4. [External Feature Flag Services](#external-feature-flag-services) +5. [Multi-Tenant SaaS Integration](#multi-tenant-saas-integration) +6. [Monitoring and Observability](#monitoring-and-observability) +7. [Custom Filter Development](#custom-filter-development) + +## JWT Integration with Identity Providers + +### Auth0 Integration + +```csharp +// Program.cs - Configure Auth0 JWT authentication +builder.Services.AddAuthentication(JwtBearerDefaults.AuthenticationScheme) + .AddJwtBearer(options => + { + options.Authority = "https://your-tenant.auth0.com/"; + options.Audience = "your-api-identifier"; + options.TokenValidationParameters = new TokenValidationParameters + { + ValidateIssuer = true, + ValidateAudience = true, + ValidateLifetime = true, + ValidateIssuerSigningKey = true, + ClockSkew = TimeSpan.FromMinutes(5) + }; + }); + +// Custom claims transformation for Auth0 +builder.Services.AddSingleton(); + +public class Auth0ClaimsTransformation : IClaimsTransformation +{ + public Task TransformAsync(ClaimsPrincipal principal) + { + var claimsIdentity = (ClaimsIdentity)principal.Identity!; + + // Map Auth0 custom claims to standard claims + var permissions = principal.FindFirst("permissions")?.Value; + if (!string.IsNullOrEmpty(permissions)) + { + var permissionList = JsonSerializer.Deserialize(permissions); + foreach (var permission in permissionList) + { + claimsIdentity.AddClaim(new Claim("scope", permission)); + } + } + + // Map Auth0 roles + var roles = principal.FindFirst("https://myapp.com/roles")?.Value; + if (!string.IsNullOrEmpty(roles)) + { + var roleList = JsonSerializer.Deserialize(roles); + foreach (var role in roleList) + { + claimsIdentity.AddClaim(new Claim(ClaimTypes.Role, role)); + } + } + + return Task.FromResult(principal); + } +} +``` + +### Azure AD B2C Integration + +```csharp +// Program.cs - Configure Azure AD B2C +builder.Services.AddAuthentication(JwtBearerDefaults.AuthenticationScheme) + .AddMicrosoftIdentityWebApi(builder.Configuration.GetSection("AzureAdB2C")); + +// appsettings.json +{ + "AzureAdB2C": { + "Instance": "https://yourtenant.b2clogin.com", + "ClientId": "your-client-id", + "Domain": "yourtenant.onmicrosoft.com", + "SignUpSignInPolicyId": "B2C_1_signupsignin1" + } +} +``` + +## Redis Integration for Rate Limiting + +### Production-Ready Rate Limiting Service + +```csharp +// Services/RedisRateLimitingService.cs +public class RedisRateLimitingService : IRateLimitingService +{ + private readonly IDatabase _database; + private readonly ILogger _logger; + private const string USAGE_KEY_PREFIX = "rate_limit:usage:"; + private const string STATISTICS_KEY_PREFIX = "rate_limit:stats:"; + + public RedisRateLimitingService(IConnectionMultiplexer redis, ILogger logger) + { + _database = redis.GetDatabase(); + _logger = logger; + } + + public async Task GetUsageCountAsync(string userId, string toolName, DateTime windowStart, CancellationToken cancellationToken = default) + { + var key = GetUsageKey(userId, toolName); + var windowEnd = windowStart.AddHours(1); // 1-hour window + + var count = await _database.SortedSetCountAsync(key, + windowStart.Ticks, windowEnd.Ticks); + + return (int)count; + } + + public async Task RecordUsageAsync(string userId, string toolName, DateTime timestamp, CancellationToken cancellationToken = default) + { + var key = GetUsageKey(userId, toolName); + var score = timestamp.Ticks; + + // Add usage record + await _database.SortedSetAddAsync(key, Guid.NewGuid().ToString(), score); + + // Set expiration for cleanup + await _database.KeyExpireAsync(key, TimeSpan.FromDays(1)); + + // Update statistics + await UpdateStatisticsAsync(userId, toolName, timestamp); + + _logger.LogDebug("Recorded usage for {UserId}, {ToolName} at {Timestamp}", + userId, toolName, timestamp); + } + + public async Task CleanupOldRecordsAsync(CancellationToken cancellationToken = default) + { + var cutoffTime = DateTime.UtcNow.AddHours(-24); + var pattern = $"{USAGE_KEY_PREFIX}*"; + + await foreach (var key in _database.Multiplexer.GetServer().KeysAsync(pattern: pattern)) + { + await _database.SortedSetRemoveRangeByScoreAsync(key, + double.NegativeInfinity, cutoffTime.Ticks); + } + } + + public async Task> GetUsageStatisticsAsync(string userId, CancellationToken cancellationToken = default) + { + var statsKey = $"{STATISTICS_KEY_PREFIX}{userId}"; + var hash = await _database.HashGetAllAsync(statsKey); + + return hash.ToDictionary( + h => h.Name.ToString(), + h => (int)h.Value + ); + } + + private async Task UpdateStatisticsAsync(string userId, string toolName, DateTime timestamp) + { + var statsKey = $"{STATISTICS_KEY_PREFIX}{userId}"; + var hourKey = $"{toolName}:{timestamp:yyyy-MM-dd:HH}"; + + await _database.HashIncrementAsync(statsKey, hourKey); + await _database.KeyExpireAsync(statsKey, TimeSpan.FromDays(30)); + } + + private string GetUsageKey(string userId, string toolName) => + $"{USAGE_KEY_PREFIX}{userId}:{toolName}"; +} + +// Program.cs - Register Redis services +builder.Services.AddStackExchangeRedisCache(options => +{ + options.Configuration = builder.Configuration.GetConnectionString("Redis"); +}); + +builder.Services.AddSingleton(provider => + ConnectionMultiplexer.Connect(builder.Configuration.GetConnectionString("Redis"))); + +builder.Services.AddSingleton(); +``` + +## Database Integration for Quotas + +### Entity Framework Quota Service + +```csharp +// Models/QuotaDbContext.cs +public class QuotaDbContext : DbContext +{ + public QuotaDbContext(DbContextOptions options) : base(options) { } + + public DbSet UserQuotas { get; set; } + public DbSet QuotaUsages { get; set; } + public DbSet QuotaResets { get; set; } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity(entity => + { + entity.HasKey(e => e.UserId); + entity.Property(e => e.UserId).HasMaxLength(100); + entity.Property(e => e.Role).HasMaxLength(50); + entity.HasIndex(e => e.Role); + entity.HasIndex(e => e.NextResetDate); + }); + + modelBuilder.Entity(entity => + { + entity.HasKey(e => e.Id); + entity.Property(e => e.UserId).HasMaxLength(100); + entity.Property(e => e.ToolName).HasMaxLength(100); + entity.HasIndex(e => new { e.UserId, e.ToolName }); + entity.HasIndex(e => e.UsageDate); + }); + + modelBuilder.Entity(entity => + { + entity.HasKey(e => e.Id); + entity.Property(e => e.UserId).HasMaxLength(100); + entity.HasIndex(e => e.UserId); + entity.HasIndex(e => e.ResetDate); + }); + } +} + +// Models/QuotaEntities.cs +public class UserQuota +{ + public string UserId { get; set; } = ""; + public string Role { get; set; } = ""; + public int CurrentUsage { get; set; } + public int QuotaLimit { get; set; } + public DateTime NextResetDate { get; set; } + public DateTime LastUsage { get; set; } + public DateTime CreatedAt { get; set; } + public DateTime UpdatedAt { get; set; } +} + +public class QuotaUsage +{ + public int Id { get; set; } + public string UserId { get; set; } = ""; + public string ToolName { get; set; } = ""; + public int Cost { get; set; } + public DateTime UsageDate { get; set; } + public string? RequestId { get; set; } +} + +public class QuotaReset +{ + public int Id { get; set; } + public string UserId { get; set; } = ""; + public DateTime ResetDate { get; set; } + public int PreviousUsage { get; set; } + public string ResetReason { get; set; } = ""; +} + +// Services/DatabaseQuotaService.cs +public class DatabaseQuotaService : IQuotaService +{ + private readonly QuotaDbContext _context; + private readonly QuotaManagementOptions _options; + private readonly ILogger _logger; + + public DatabaseQuotaService( + QuotaDbContext context, + IOptions options, + ILogger logger) + { + _context = context; + _options = options.Value.BusinessLogic.QuotaManagement; + _logger = logger; + } + + public async Task HasAvailableQuotaAsync(string userId, string userRole, string toolName, CancellationToken cancellationToken = default) + { + var userQuota = await GetOrCreateUserQuotaAsync(userId, userRole); + var quotaCost = GetQuotaCost(toolName); + + return userQuota.QuotaLimit == -1 || // Unlimited + userQuota.CurrentUsage + quotaCost <= userQuota.QuotaLimit; + } + + public async Task ConsumeQuotaAsync(string userId, string toolName, int cost, CancellationToken cancellationToken = default) + { + using var transaction = await _context.Database.BeginTransactionAsync(cancellationToken); + + try + { + var userQuota = await _context.UserQuotas + .FirstOrDefaultAsync(q => q.UserId == userId, cancellationToken); + + if (userQuota != null) + { + userQuota.CurrentUsage += cost; + userQuota.LastUsage = DateTime.UtcNow; + userQuota.UpdatedAt = DateTime.UtcNow; + } + + var usage = new QuotaUsage + { + UserId = userId, + ToolName = toolName, + Cost = cost, + UsageDate = DateTime.UtcNow, + RequestId = Guid.NewGuid().ToString() + }; + + _context.QuotaUsages.Add(usage); + await _context.SaveChangesAsync(cancellationToken); + await transaction.CommitAsync(cancellationToken); + + _logger.LogDebug("Consumed {Cost} quota for user {UserId}, tool {ToolName}", + cost, userId, toolName); + } + catch + { + await transaction.RollbackAsync(cancellationToken); + throw; + } + } + + public async Task GetCurrentUsageAsync(string userId, CancellationToken cancellationToken = default) + { + var userQuota = await _context.UserQuotas + .FirstOrDefaultAsync(q => q.UserId == userId, cancellationToken); + + return userQuota?.CurrentUsage ?? 0; + } + + private async Task GetOrCreateUserQuotaAsync(string userId, string userRole) + { + var userQuota = await _context.UserQuotas + .FirstOrDefaultAsync(q => q.UserId == userId); + + if (userQuota == null) + { + var quotaLimit = GetQuotaLimitForRole(userRole); + userQuota = new UserQuota + { + UserId = userId, + Role = userRole, + CurrentUsage = 0, + QuotaLimit = quotaLimit, + NextResetDate = CalculateNextResetDate(DateTime.UtcNow), + CreatedAt = DateTime.UtcNow, + UpdatedAt = DateTime.UtcNow + }; + + _context.UserQuotas.Add(userQuota); + await _context.SaveChangesAsync(); + } + + return userQuota; + } + + private int GetQuotaLimitForRole(string userRole) + { + return _options.RoleQuotas.TryGetValue(userRole, out var limit) ? limit : 1000; + } + + private int GetQuotaCost(string toolName) + { + foreach (var mapping in _options.ToolQuotaCosts) + { + if (IsPatternMatch(mapping.Key, toolName)) + { + return mapping.Value; + } + } + return 1; + } + + private DateTime CalculateNextResetDate(DateTime fromDate) + { + return fromDate.AddDays(_options.QuotaPeriodDays); + } + + private static bool IsPatternMatch(string pattern, string toolName) + { + if (pattern == "*") return true; + if (pattern.EndsWith("*")) + { + var prefix = pattern[..^1]; + return toolName.StartsWith(prefix, StringComparison.OrdinalIgnoreCase); + } + return string.Equals(pattern, toolName, StringComparison.OrdinalIgnoreCase); + } +} + +// Program.cs - Register database services +builder.Services.AddDbContext(options => + options.UseSqlServer(builder.Configuration.GetConnectionString("DefaultConnection"))); + +builder.Services.AddScoped(); +``` + +## External Feature Flag Services + +### LaunchDarkly Integration + +```csharp +// Services/LaunchDarklyFeatureFlagService.cs +public class LaunchDarklyFeatureFlagService : IFeatureFlagService +{ + private readonly LdClient _client; + private readonly ILogger _logger; + + public LaunchDarklyFeatureFlagService(IConfiguration configuration, ILogger logger) + { + var sdkKey = configuration["LaunchDarkly:SdkKey"]; + var config = Configuration.Default(sdkKey); + _client = new LdClient(config); + _logger = logger; + } + + public Task IsEnabledAsync(string flagName, string userId, CancellationToken cancellationToken = default) + { + var user = User.WithKey(userId); + var isEnabled = _client.BoolVariation(flagName, user, false); + + _logger.LogDebug("Feature flag {FlagName} for user {UserId}: {Enabled}", flagName, userId, isEnabled); + + return Task.FromResult(isEnabled); + } + + public async Task> GetAllFlagsAsync(string userId, CancellationToken cancellationToken = default) + { + var user = User.WithKey(userId); + var allFlags = _client.AllFlagsState(user); + + var result = new Dictionary(); + foreach (var flag in allFlags.ToValuesMap()) + { + if (flag.Value is bool boolValue) + { + result[flag.Key] = boolValue; + } + } + + return result; + } + + public Task SetFlagAsync(string flagName, bool enabled, string? userId = null, CancellationToken cancellationToken = default) + { + // LaunchDarkly doesn't support programmatic flag setting from SDK + // This would typically be done through their REST API or dashboard + _logger.LogWarning("Cannot set flag {FlagName} programmatically with LaunchDarkly SDK", flagName); + return Task.CompletedTask; + } + + public Task GetRolloutPercentageAsync(string flagName, CancellationToken cancellationToken = default) + { + // This would require calling LaunchDarkly's REST API to get flag configuration + _logger.LogWarning("Cannot get rollout percentage for flag {FlagName} with current implementation", flagName); + return Task.FromResult(100); + } +} + +// Program.cs +builder.Services.AddSingleton(); +``` + +### Azure App Configuration Integration + +```csharp +// Services/AzureFeatureFlagService.cs +public class AzureFeatureFlagService : IFeatureFlagService +{ + private readonly IFeatureManager _featureManager; + private readonly ILogger _logger; + + public AzureFeatureFlagService(IFeatureManager featureManager, ILogger logger) + { + _featureManager = featureManager; + _logger = logger; + } + + public async Task IsEnabledAsync(string flagName, string userId, CancellationToken cancellationToken = default) + { + var context = new TargetingContext + { + UserId = userId + }; + + var isEnabled = await _featureManager.IsEnabledAsync(flagName, context); + + _logger.LogDebug("Feature flag {FlagName} for user {UserId}: {Enabled}", flagName, userId, isEnabled); + + return isEnabled; + } + + public async Task> GetAllFlagsAsync(string userId, CancellationToken cancellationToken = default) + { + var context = new TargetingContext { UserId = userId }; + var result = new Dictionary(); + + var flagNames = new[] { "premium_features", "admin_performance_tools", "experimental_tools", "beta_features" }; + + foreach (var flagName in flagNames) + { + result[flagName] = await _featureManager.IsEnabledAsync(flagName, context); + } + + return result; + } + + // Other methods would interact with Azure App Configuration REST API +} + +// Program.cs +builder.Configuration.AddAzureAppConfiguration(options => +{ + options.Connect(builder.Configuration.GetConnectionString("AzureAppConfiguration")) + .UseFeatureFlags(); +}); + +builder.Services.AddFeatureManagement(); +builder.Services.AddSingleton(); +``` + +## Multi-Tenant SaaS Integration + +### Advanced Tenant Isolation Filter + +```csharp +// Services/TenantManagementService.cs +public interface ITenantManagementService +{ + Task GetTenantAsync(string tenantId); + Task> GetTenantToolsAsync(string tenantId); + Task> GetTenantRateLimitsAsync(string tenantId); + Task IsTenantActiveAsync(string tenantId); +} + +public class TenantManagementService : ITenantManagementService +{ + private readonly HttpClient _httpClient; + private readonly IMemoryCache _cache; + private readonly ILogger _logger; + + public TenantManagementService(HttpClient httpClient, IMemoryCache cache, ILogger logger) + { + _httpClient = httpClient; + _cache = cache; + _logger = logger; + } + + public async Task GetTenantAsync(string tenantId) + { + var cacheKey = $"tenant:{tenantId}"; + + if (_cache.TryGetValue(cacheKey, out TenantInfo? cachedTenant)) + { + return cachedTenant; + } + + try + { + var response = await _httpClient.GetAsync($"/api/tenants/{tenantId}"); + if (response.IsSuccessStatusCode) + { + var json = await response.Content.ReadAsStringAsync(); + var tenant = JsonSerializer.Deserialize(json); + + _cache.Set(cacheKey, tenant, TimeSpan.FromMinutes(15)); + return tenant; + } + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to fetch tenant {TenantId}", tenantId); + } + + return null; + } + + public async Task> GetTenantToolsAsync(string tenantId) + { + var tenant = await GetTenantAsync(tenantId); + return tenant?.AllowedTools?.ToList() ?? new List(); + } + + public async Task> GetTenantRateLimitsAsync(string tenantId) + { + var tenant = await GetTenantAsync(tenantId); + return tenant?.CustomRateLimits ?? new Dictionary(); + } + + public async Task IsTenantActiveAsync(string tenantId) + { + var tenant = await GetTenantAsync(tenantId); + return tenant?.IsActive ?? false; + } +} + +// Enhanced Tenant Isolation Filter +public class EnhancedTenantIsolationFilter : IToolFilter +{ + private readonly TenantIsolationOptions _options; + private readonly ITenantManagementService _tenantService; + private readonly ILogger _logger; + + public EnhancedTenantIsolationFilter( + IOptions options, + ITenantManagementService tenantService, + ILogger logger) + { + _options = options.Value.TenantIsolation; + _tenantService = tenantService; + _logger = logger; + } + + public int Priority => _options.Priority; + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) return true; + + var tenantId = GetTenantId(context); + if (string.IsNullOrEmpty(tenantId)) return true; + + var allowedTools = await _tenantService.GetTenantToolsAsync(tenantId); + return IsToolAllowed(tool.Name, allowedTools); + } + + public async Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + return AuthorizationResult.Allow("Tenant isolation disabled"); + + var tenantId = GetTenantId(context); + if (string.IsNullOrEmpty(tenantId)) + return AuthorizationResult.Allow("No tenant context"); + + // Check if tenant is active + if (!await _tenantService.IsTenantActiveAsync(tenantId)) + { + var reason = $"Tenant '{tenantId}' is not active"; + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "Tenant", + ("realm", "mcp-api"), + ("tenant_id", tenantId), + ("status", "inactive")); + + return AuthorizationResult.DenyWithChallenge(reason, challenge); + } + + // Check tool access + var allowedTools = await _tenantService.GetTenantToolsAsync(tenantId); + if (!IsToolAllowed(toolName, allowedTools)) + { + var reason = $"Tool '{toolName}' is not allowed for tenant '{tenantId}'"; + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "Tenant", + ("realm", "mcp-api"), + ("tenant_id", tenantId), + ("tool_name", toolName), + ("restriction", "tool_not_allowed")); + + return AuthorizationResult.DenyWithChallenge(reason, challenge); + } + + return AuthorizationResult.Allow($"Tool '{toolName}' allowed for tenant '{tenantId}'"); + } + + private string? GetTenantId(ToolAuthorizationContext context) + { + // Try to get tenant ID from claims first + var tenantClaim = context.User?.FindFirst(_options.TenantClaimType)?.Value; + if (!string.IsNullOrEmpty(tenantClaim)) + return tenantClaim; + + // Fall back to header if available in the context + // This would need to be passed through from the HTTP context + return context.AdditionalData?.TryGetValue("TenantId", out var tenantId) == true + ? tenantId?.ToString() + : null; + } + + private bool IsToolAllowed(string toolName, List allowedTools) + { + return allowedTools.Any(pattern => IsPatternMatch(pattern, toolName)); + } + + private static bool IsPatternMatch(string pattern, string toolName) + { + if (pattern == "*") return true; + if (pattern.EndsWith("*")) + { + var prefix = pattern[..^1]; + return toolName.StartsWith(prefix, StringComparison.OrdinalIgnoreCase); + } + return string.Equals(pattern, toolName, StringComparison.OrdinalIgnoreCase); + } +} + +// Models/TenantInfo.cs +public class TenantInfo +{ + public string Id { get; set; } = ""; + public string Name { get; set; } = ""; + public bool IsActive { get; set; } + public string SubscriptionTier { get; set; } = ""; + public string[] AllowedTools { get; set; } = Array.Empty(); + public string[] DeniedTools { get; set; } = Array.Empty(); + public Dictionary CustomRateLimits { get; set; } = new(); + public DateTime CreatedAt { get; set; } + public DateTime UpdatedAt { get; set; } +} +``` + +## Monitoring and Observability + +### Application Insights Integration + +```csharp +// Services/TelemetryFilterWrapper.cs +public class TelemetryFilterWrapper : IToolFilter +{ + private readonly IToolFilter _innerFilter; + private readonly TelemetryClient _telemetryClient; + private readonly ILogger _logger; + + public TelemetryFilterWrapper(IToolFilter innerFilter, TelemetryClient telemetryClient, ILogger logger) + { + _innerFilter = innerFilter; + _telemetryClient = telemetryClient; + _logger = logger; + } + + public int Priority => _innerFilter.Priority; + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var stopwatch = Stopwatch.StartNew(); + var filterName = _innerFilter.GetType().Name; + + try + { + var result = await _innerFilter.ShouldIncludeToolAsync(tool, context, cancellationToken); + + stopwatch.Stop(); + + _telemetryClient.TrackDependency("Filter", filterName, $"ShouldInclude:{tool.Name}", + DateTime.UtcNow.Subtract(stopwatch.Elapsed), stopwatch.Elapsed, result.ToString()); + + _telemetryClient.TrackMetric($"Filter.{filterName}.ShouldInclude.Duration", stopwatch.ElapsedMilliseconds); + _telemetryClient.TrackMetric($"Filter.{filterName}.ShouldInclude.{(result ? "Allow" : "Deny")}", 1); + + return result; + } + catch (Exception ex) + { + stopwatch.Stop(); + + _telemetryClient.TrackException(ex, new Dictionary + { + ["FilterName"] = filterName, + ["ToolName"] = tool.Name, + ["Operation"] = "ShouldIncludeToolAsync" + }); + + throw; + } + } + + public async Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var stopwatch = Stopwatch.StartNew(); + var filterName = _innerFilter.GetType().Name; + + try + { + var result = await _innerFilter.CanExecuteToolAsync(toolName, context, cancellationToken); + + stopwatch.Stop(); + + _telemetryClient.TrackDependency("Filter", filterName, $"CanExecute:{toolName}", + DateTime.UtcNow.Subtract(stopwatch.Elapsed), stopwatch.Elapsed, result.IsAuthorized.ToString()); + + _telemetryClient.TrackMetric($"Filter.{filterName}.CanExecute.Duration", stopwatch.ElapsedMilliseconds); + _telemetryClient.TrackMetric($"Filter.{filterName}.CanExecute.{(result.IsAuthorized ? "Allow" : "Deny")}", 1); + + if (!result.IsAuthorized) + { + _telemetryClient.TrackEvent("FilterDenied", new Dictionary + { + ["FilterName"] = filterName, + ["ToolName"] = toolName, + ["Reason"] = result.Reason, + ["UserId"] = context.User?.FindFirst(ClaimTypes.NameIdentifier)?.Value ?? "Anonymous" + }); + } + + return result; + } + catch (Exception ex) + { + stopwatch.Stop(); + + _telemetryClient.TrackException(ex, new Dictionary + { + ["FilterName"] = filterName, + ["ToolName"] = toolName, + ["Operation"] = "CanExecuteToolAsync" + }); + + throw; + } + } +} + +// Program.cs - Wrap filters with telemetry +builder.Services.AddApplicationInsightsTelemetry(); + +builder.Services.Decorate(); +``` + +### Prometheus Metrics Integration + +```csharp +// Services/MetricsCollectionService.cs +public class MetricsCollectionService : IHostedService +{ + private readonly IServiceProvider _serviceProvider; + private readonly Counter _filterExecutionCounter; + private readonly Histogram _filterExecutionDuration; + private readonly Gauge _activeFiltersGauge; + + public MetricsCollectionService(IServiceProvider serviceProvider) + { + _serviceProvider = serviceProvider; + + _filterExecutionCounter = Metrics.CreateCounter( + "filter_executions_total", + "Total number of filter executions", + new[] { "filter_name", "operation", "result" }); + + _filterExecutionDuration = Metrics.CreateHistogram( + "filter_execution_duration_seconds", + "Duration of filter executions", + new[] { "filter_name", "operation" }); + + _activeFiltersGauge = Metrics.CreateGauge( + "active_filters_count", + "Number of active filters"); + } + + public Task StartAsync(CancellationToken cancellationToken) + { + // Initialize metrics collection + using var scope = _serviceProvider.CreateScope(); + var filters = scope.ServiceProvider.GetServices(); + _activeFiltersGauge.Set(filters.Count()); + + return Task.CompletedTask; + } + + public Task StopAsync(CancellationToken cancellationToken) => Task.CompletedTask; + + public void RecordFilterExecution(string filterName, string operation, string result, double durationSeconds) + { + _filterExecutionCounter.WithLabels(filterName, operation, result).Inc(); + _filterExecutionDuration.WithLabels(filterName, operation).Observe(durationSeconds); + } +} + +// Program.cs - Add Prometheus +builder.Services.AddSingleton(); +builder.Services.AddHostedService(); + +// In the request pipeline +app.UseMetricServer(); // Expose /metrics endpoint +``` + +## Custom Filter Development + +### Custom Business Logic Filter Example + +```csharp +// Filters/GeographicRestrictionFilter.cs +public class GeographicRestrictionFilter : IToolFilter +{ + private readonly IConfiguration _configuration; + private readonly ILogger _logger; + private readonly Dictionary _toolRegionMapping; + + public GeographicRestrictionFilter(IConfiguration configuration, ILogger logger) + { + _configuration = configuration; + _logger = logger; + + // Load region mappings from configuration + _toolRegionMapping = configuration.GetSection("GeographicRestrictions:ToolRegionMapping") + .Get>() ?? new Dictionary(); + } + + public int Priority => 125; // Between role-based and scope-based + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // Geographic restrictions don't affect tool visibility + return Task.FromResult(true); + } + + public async Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var userRegion = GetUserRegion(context); + var allowedRegions = GetAllowedRegions(toolName); + + if (allowedRegions.Length == 0) + { + // No geographic restrictions for this tool + return AuthorizationResult.Allow("No geographic restrictions"); + } + + if (string.IsNullOrEmpty(userRegion)) + { + var reason = $"Tool '{toolName}' has geographic restrictions but user region could not be determined"; + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "Geographic", + ("realm", "mcp-api"), + ("tool_name", toolName), + ("allowed_regions", string.Join(",", allowedRegions))); + + return AuthorizationResult.DenyWithChallenge(reason, challenge); + } + + if (!allowedRegions.Contains(userRegion, StringComparer.OrdinalIgnoreCase)) + { + var reason = $"Tool '{toolName}' is not available in region '{userRegion}'. Allowed regions: {string.Join(", ", allowedRegions)}"; + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "Geographic", + ("realm", "mcp-api"), + ("tool_name", toolName), + ("user_region", userRegion), + ("allowed_regions", string.Join(",", allowedRegions))); + + _logger.LogWarning("Geographic restriction denied: {ToolName} for region {UserRegion}", toolName, userRegion); + + return AuthorizationResult.DenyWithChallenge(reason, challenge); + } + + return AuthorizationResult.Allow($"Tool '{toolName}' allowed in region '{userRegion}'"); + } + + private string? GetUserRegion(ToolAuthorizationContext context) + { + // Try to get region from claims + var regionClaim = context.User?.FindFirst("region")?.Value + ?? context.User?.FindFirst("geo_region")?.Value; + + if (!string.IsNullOrEmpty(regionClaim)) + return regionClaim; + + // Could also determine region from IP address using a geolocation service + // This would require additional context data to be passed through + + return context.AdditionalData?.TryGetValue("UserRegion", out var region) == true + ? region?.ToString() + : null; + } + + private string[] GetAllowedRegions(string toolName) + { + foreach (var mapping in _toolRegionMapping) + { + if (IsPatternMatch(mapping.Key, toolName)) + { + return mapping.Value; + } + } + + return Array.Empty(); + } + + private static bool IsPatternMatch(string pattern, string toolName) + { + if (pattern == "*") return true; + if (pattern.EndsWith("*")) + { + var prefix = pattern[..^1]; + return toolName.StartsWith(prefix, StringComparison.OrdinalIgnoreCase); + } + return string.Equals(pattern, toolName, StringComparison.OrdinalIgnoreCase); + } +} + +// Configuration example +// appsettings.json +{ + "GeographicRestrictions": { + "Enabled": true, + "ToolRegionMapping": { + "admin_*": ["US", "CA", "EU"], + "premium_financial_*": ["US", "UK", "EU"], + "compliance_*": ["US"] + } + } +} + +// Register the filter +builder.Services.AddSingleton(); +``` + +This comprehensive integration guide shows how the Dynamic Tool Filtering system can be extended and integrated with real-world services and infrastructure. Each example demonstrates production-ready patterns and best practices for building scalable, secure MCP applications. \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Models/FilterResult.cs b/samples/DynamicToolFiltering/Models/FilterResult.cs new file mode 100644 index 00000000..4bd773fa --- /dev/null +++ b/samples/DynamicToolFiltering/Models/FilterResult.cs @@ -0,0 +1,135 @@ +namespace DynamicToolFiltering.Models; + +/// +/// Represents the result of a filter operation. +/// +public class FilterResult +{ + /// + /// Gets or sets whether the filter passed. + /// + public bool Passed { get; set; } + + /// + /// Gets or sets the filter name that produced this result. + /// + public string FilterName { get; set; } = string.Empty; + + /// + /// Gets or sets the priority of the filter that produced this result. + /// + public int Priority { get; set; } + + /// + /// Gets or sets the reason for the filter result. + /// + public string Reason { get; set; } = string.Empty; + + /// + /// Gets or sets additional data from the filter. + /// + public Dictionary Data { get; set; } = new(); + + /// + /// Gets or sets the timestamp when the filter was evaluated. + /// + public DateTime EvaluatedAt { get; set; } = DateTime.UtcNow; + + /// + /// Gets or sets the execution time of the filter in milliseconds. + /// + public double ExecutionTimeMs { get; set; } + + /// + /// Creates a successful filter result. + /// + /// The name of the filter. + /// The reason for success. + /// The filter priority. + /// A successful filter result. + public static FilterResult Success(string filterName, string reason, int priority) + { + return new FilterResult + { + Passed = true, + FilterName = filterName, + Reason = reason, + Priority = priority + }; + } + + /// + /// Creates a failed filter result. + /// + /// The name of the filter. + /// The reason for failure. + /// The filter priority. + /// A failed filter result. + public static FilterResult Failure(string filterName, string reason, int priority) + { + return new FilterResult + { + Passed = false, + FilterName = filterName, + Reason = reason, + Priority = priority + }; + } +} + +/// +/// Represents a collection of filter results from multiple filters. +/// +public class FilterResultCollection +{ + /// + /// Gets or sets the list of individual filter results. + /// + public List Results { get; set; } = new(); + + /// + /// Gets or sets the overall result (all filters must pass). + /// + public bool OverallResult => Results.All(r => r.Passed); + + /// + /// Gets or sets the first failed filter result, if any. + /// + public FilterResult? FirstFailure => Results.FirstOrDefault(r => !r.Passed); + + /// + /// Gets or sets the total execution time for all filters. + /// + public double TotalExecutionTimeMs => Results.Sum(r => r.ExecutionTimeMs); + + /// + /// Gets or sets the number of filters that were evaluated. + /// + public int FilterCount => Results.Count; + + /// + /// Adds a filter result to the collection. + /// + /// The filter result to add. + public void AddResult(FilterResult result) + { + Results.Add(result); + } + + /// + /// Gets a summary of the filter results. + /// + /// A summary string of the filter results. + public string GetSummary() + { + if (OverallResult) + { + return $"All {FilterCount} filters passed in {TotalExecutionTimeMs:F2}ms"; + } + else + { + var firstFailure = FirstFailure!; + return $"Filter '{firstFailure.FilterName}' failed: {firstFailure.Reason}"; + } + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Models/ToolExecutionContext.cs b/samples/DynamicToolFiltering/Models/ToolExecutionContext.cs new file mode 100644 index 00000000..5d690464 --- /dev/null +++ b/samples/DynamicToolFiltering/Models/ToolExecutionContext.cs @@ -0,0 +1,90 @@ +using System.Security.Claims; + +namespace DynamicToolFiltering.Models; + +/// +/// Represents the execution context for a tool call with relevant filtering information. +/// +public class ToolExecutionContext +{ + /// + /// Gets or sets the name of the tool being executed. + /// + public string ToolName { get; set; } = string.Empty; + + /// + /// Gets or sets the user executing the tool. + /// + public UserInfo? User { get; set; } + + /// + /// Gets or sets the claims principal for the current user. + /// + public ClaimsPrincipal? ClaimsPrincipal { get; set; } + + /// + /// Gets or sets the session ID for the current session. + /// + public string? SessionId { get; set; } + + /// + /// Gets or sets the client information. + /// + public string? ClientInfo { get; set; } + + /// + /// Gets or sets the IP address of the client. + /// + public string? ClientIpAddress { get; set; } + + /// + /// Gets or sets the timestamp when the execution was requested. + /// + public DateTime RequestedAt { get; set; } = DateTime.UtcNow; + + /// + /// Gets or sets the tenant context if applicable. + /// + public TenantContext? TenantContext { get; set; } + + /// + /// Gets or sets the execution environment. + /// + public string Environment { get; set; } = "Development"; + + /// + /// Gets or sets additional context data for filters. + /// + public Dictionary AdditionalData { get; set; } = new(); +} + +/// +/// Represents tenant context information. +/// +public class TenantContext +{ + /// + /// Gets or sets the tenant identifier. + /// + public string TenantId { get; set; } = string.Empty; + + /// + /// Gets or sets the tenant name. + /// + public string Name { get; set; } = string.Empty; + + /// + /// Gets or sets whether the tenant is active. + /// + public bool IsActive { get; set; } = true; + + /// + /// Gets or sets the tenant's subscription tier. + /// + public string SubscriptionTier { get; set; } = "Basic"; + + /// + /// Gets or sets tenant-specific settings. + /// + public Dictionary Settings { get; set; } = new(); +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Models/UsageStatistics.cs b/samples/DynamicToolFiltering/Models/UsageStatistics.cs new file mode 100644 index 00000000..35efe025 --- /dev/null +++ b/samples/DynamicToolFiltering/Models/UsageStatistics.cs @@ -0,0 +1,168 @@ +namespace DynamicToolFiltering.Models; + +/// +/// Represents usage statistics for a user or tool. +/// +public class UsageStatistics +{ + /// + /// Gets or sets the user ID. + /// + public string UserId { get; set; } = string.Empty; + + /// + /// Gets or sets the total number of tool executions. + /// + public int TotalExecutions { get; set; } + + /// + /// Gets or sets the number of successful executions. + /// + public int SuccessfulExecutions { get; set; } + + /// + /// Gets or sets the number of failed executions. + /// + public int FailedExecutions { get; set; } + + /// + /// Gets or sets the number of executions blocked by filters. + /// + public int BlockedExecutions { get; set; } + + /// + /// Gets or sets per-tool usage counts. + /// + public Dictionary ToolUsageCounts { get; set; } = new(); + + /// + /// Gets or sets per-filter block counts. + /// + public Dictionary FilterBlockCounts { get; set; } = new(); + + /// + /// Gets or sets the first execution timestamp. + /// + public DateTime? FirstExecutionAt { get; set; } + + /// + /// Gets or sets the last execution timestamp. + /// + public DateTime? LastExecutionAt { get; set; } + + /// + /// Gets or sets the current quota usage. + /// + public int QuotaUsed { get; set; } + + /// + /// Gets or sets the quota limit. + /// + public int QuotaLimit { get; set; } + + /// + /// Gets or sets when the quota period resets. + /// + public DateTime? QuotaResetAt { get; set; } + + /// + /// Gets the success rate as a percentage. + /// + public double SuccessRate => TotalExecutions > 0 ? (double)SuccessfulExecutions / TotalExecutions * 100 : 0; + + /// + /// Gets the block rate as a percentage. + /// + public double BlockRate => TotalExecutions > 0 ? (double)BlockedExecutions / TotalExecutions * 100 : 0; + + /// + /// Gets the remaining quota. + /// + public int RemainingQuota => Math.Max(0, QuotaLimit - QuotaUsed); + + /// + /// Gets whether the quota is unlimited. + /// + public bool IsUnlimitedQuota => QuotaLimit == -1; +} + +/// +/// Represents aggregated usage statistics across multiple users or time periods. +/// +public class AggregatedUsageStatistics +{ + /// + /// Gets or sets the time period for these statistics. + /// + public TimeSpan Period { get; set; } + + /// + /// Gets or sets the start time of the statistics period. + /// + public DateTime PeriodStart { get; set; } + + /// + /// Gets or sets the end time of the statistics period. + /// + public DateTime PeriodEnd { get; set; } + + /// + /// Gets or sets the total number of unique users. + /// + public int UniqueUsers { get; set; } + + /// + /// Gets or sets the total number of tool executions. + /// + public int TotalExecutions { get; set; } + + /// + /// Gets or sets the total number of successful executions. + /// + public int SuccessfulExecutions { get; set; } + + /// + /// Gets or sets the total number of failed executions. + /// + public int FailedExecutions { get; set; } + + /// + /// Gets or sets the total number of blocked executions. + /// + public int BlockedExecutions { get; set; } + + /// + /// Gets or sets the most popular tools by execution count. + /// + public Dictionary PopularTools { get; set; } = new(); + + /// + /// Gets or sets the most active users by execution count. + /// + public Dictionary ActiveUsers { get; set; } = new(); + + /// + /// Gets or sets filter blocking statistics. + /// + public Dictionary FilterBlockStats { get; set; } = new(); + + /// + /// Gets or sets peak usage hours. + /// + public Dictionary HourlyUsage { get; set; } = new(); + + /// + /// Gets the overall success rate as a percentage. + /// + public double OverallSuccessRate => TotalExecutions > 0 ? (double)SuccessfulExecutions / TotalExecutions * 100 : 0; + + /// + /// Gets the overall block rate as a percentage. + /// + public double OverallBlockRate => TotalExecutions > 0 ? (double)BlockedExecutions / TotalExecutions * 100 : 0; + + /// + /// Gets the average executions per user. + /// + public double AverageExecutionsPerUser => UniqueUsers > 0 ? (double)TotalExecutions / UniqueUsers : 0; +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Models/UserInfo.cs b/samples/DynamicToolFiltering/Models/UserInfo.cs new file mode 100644 index 00000000..3d667437 --- /dev/null +++ b/samples/DynamicToolFiltering/Models/UserInfo.cs @@ -0,0 +1,57 @@ +namespace DynamicToolFiltering.Models; + +/// +/// Represents user information for the filtering system. +/// +public class UserInfo +{ + /// + /// Gets or sets the unique user identifier. + /// + public string UserId { get; set; } = string.Empty; + + /// + /// Gets or sets the user's display name. + /// + public string Name { get; set; } = string.Empty; + + /// + /// Gets or sets the user's email address. + /// + public string Email { get; set; } = string.Empty; + + /// + /// Gets or sets the user's primary role. + /// + public string Role { get; set; } = "guest"; + + /// + /// Gets or sets the list of scopes assigned to the user. + /// + public List Scopes { get; set; } = new(); + + /// + /// Gets or sets the tenant ID associated with the user. + /// + public string? TenantId { get; set; } + + /// + /// Gets or sets whether the user is currently active. + /// + public bool IsActive { get; set; } = true; + + /// + /// Gets or sets when the user was created. + /// + public DateTime CreatedAt { get; set; } = DateTime.UtcNow; + + /// + /// Gets or sets when the user last authenticated. + /// + public DateTime? LastLoginAt { get; set; } + + /// + /// Gets or sets custom user properties. + /// + public Dictionary Properties { get; set; } = new(); +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Program.cs b/samples/DynamicToolFiltering/Program.cs new file mode 100644 index 00000000..25a4087d --- /dev/null +++ b/samples/DynamicToolFiltering/Program.cs @@ -0,0 +1,322 @@ +using Microsoft.AspNetCore.Authentication; +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.Extensions.Options; +using Microsoft.IdentityModel.Tokens; +using OpenTelemetry; +using OpenTelemetry.Metrics; +using OpenTelemetry.Trace; +using Serilog; +using System.Security.Claims; +using System.Text; +using System.Text.Encodings.Web; + +using DynamicToolFiltering.Authorization.Filters; +using DynamicToolFiltering.Configuration; +using DynamicToolFiltering.Services; +using DynamicToolFiltering.Tools; +using ModelContextProtocol.Server.Authorization; + +// Configure Serilog for comprehensive logging +Log.Logger = new LoggerConfiguration() + .WriteTo.Console() + .WriteTo.File("logs/dynamic-tool-filtering-.txt", rollingInterval: RollingInterval.Day) + .CreateLogger(); + +var builder = WebApplication.CreateBuilder(args); + +// Replace default logging with Serilog +builder.Host.UseSerilog(); + +// Configure filtering options from configuration +builder.Services.Configure(builder.Configuration.GetSection(FilteringOptions.SectionName)); + +// Register core services +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); + +// Configure multiple authentication schemes +ConfigureAuthentication(builder); + +// Configure authorization and filtering +ConfigureFiltering(builder); + +// Configure MCP server with tools +builder.Services.AddMcpServer() + .WithHttpTransport() + .WithTools() + .WithTools() + .WithTools() + .WithTools(); + +// Add telemetry for monitoring +builder.Services.AddOpenTelemetry() + .WithTracing(b => b + .AddSource("*") + .AddAspNetCoreInstrumentation() + .AddHttpClientInstrumentation()) + .WithMetrics(b => b + .AddMeter("*") + .AddAspNetCoreInstrumentation() + .AddHttpClientInstrumentation()) + .WithLogging() + .UseOtlpExporter(); + +// Add CORS for web clients +builder.Services.AddCors(options => +{ + options.AddDefaultPolicy(policy => + { + policy.AllowAnyOrigin() + .AllowAnyMethod() + .AllowAnyHeader() + .WithExposedHeaders("WWW-Authenticate"); + }); +}); + +var app = builder.Build(); + +// Configure request pipeline +if (app.Environment.IsDevelopment()) +{ + app.UseDeveloperExceptionPage(); +} + +app.UseSerilogRequestLogging(); +app.UseCors(); + +// Authentication must come before authorization +app.UseAuthentication(); +app.UseAuthorization(); + +// Map MCP endpoints +app.MapMcp(); + +// Add health check endpoint +app.MapGet("/health", () => new { + Status = "healthy", + Timestamp = DateTime.UtcNow, + Environment = app.Environment.EnvironmentName, + Version = "1.0.0" +}); + +// Add filter management endpoints (for demo purposes) +app.MapGet("/admin/filters/status", async (IServiceProvider services) => +{ + var toolAuthService = services.GetRequiredService(); + + return new + { + Message = "Filter management endpoints would be implemented here", + Timestamp = DateTime.UtcNow, + FiltersRegistered = "Multiple filters active (see configuration)" + }; +}).RequireAuthorization("AdminPolicy"); + +// Add feature flag management endpoints +app.MapGet("/admin/feature-flags", async (IFeatureFlagService featureFlagService) => +{ + var flags = await featureFlagService.GetAllFlagsAsync("admin"); + return new { FeatureFlags = flags, Timestamp = DateTime.UtcNow }; +}).RequireAuthorization("AdminPolicy"); + +app.MapPost("/admin/feature-flags/{flagName}", async ( + string flagName, + bool enabled, + IFeatureFlagService featureFlagService) => +{ + await featureFlagService.SetFlagAsync(flagName, enabled); + return new { FlagName = flagName, Enabled = enabled, UpdatedAt = DateTime.UtcNow }; +}).RequireAuthorization("AdminPolicy"); + +Log.Information("Starting Dynamic Tool Filtering MCP Server on {Environment}", app.Environment.EnvironmentName); + +app.Run(); + +static void ConfigureAuthentication(WebApplicationBuilder builder) +{ + var authBuilder = builder.Services.AddAuthentication(options => + { + options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme; + options.DefaultChallengeScheme = JwtBearerDefaults.AuthenticationScheme; + }); + + // JWT Bearer authentication + authBuilder.AddJwtBearer(JwtBearerDefaults.AuthenticationScheme, options => + { + var jwtSettings = builder.Configuration.GetSection("Jwt"); + var secretKey = jwtSettings["SecretKey"] ?? "your-256-bit-secret-key-here-make-it-secure"; + var issuer = jwtSettings["Issuer"] ?? "dynamic-tool-filtering"; + var audience = jwtSettings["Audience"] ?? "mcp-api"; + + options.TokenValidationParameters = new TokenValidationParameters + { + ValidateIssuer = true, + ValidateAudience = true, + ValidateLifetime = true, + ValidateIssuerSigningKey = true, + ValidIssuer = issuer, + ValidAudience = audience, + IssuerSigningKey = new SymmetricSecurityKey(Encoding.UTF8.GetBytes(secretKey)), + ClockSkew = TimeSpan.FromMinutes(5) + }; + + options.Events = new JwtBearerEvents + { + OnAuthenticationFailed = context => + { + Log.Warning("JWT authentication failed: {Error}", context.Exception.Message); + return Task.CompletedTask; + }, + OnTokenValidated = context => + { + var userId = context.Principal?.FindFirst(ClaimTypes.NameIdentifier)?.Value; + Log.Debug("JWT token validated for user: {UserId}", userId); + return Task.CompletedTask; + } + }; + }); + + // API Key authentication (custom scheme) + authBuilder.AddScheme( + "ApiKey", options => + { + options.HeaderName = "X-API-Key"; + options.QueryStringKey = "apikey"; + }); + + // Configure authorization policies + builder.Services.AddAuthorization(options => + { + options.AddPolicy("AdminPolicy", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(ClaimTypes.Role, "admin", "super_admin"); + }); + + options.AddPolicy("PremiumPolicy", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(ClaimTypes.Role, "premium", "admin", "super_admin"); + }); + + options.AddPolicy("UserPolicy", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(ClaimTypes.Role, "user", "premium", "admin", "super_admin"); + }); + }); +} + +static void ConfigureFiltering(WebApplicationBuilder builder) +{ + // Register the tool authorization service + builder.Services.AddSingleton(); + + // Register all filter implementations with proper ordering (priority-based) + builder.Services.AddSingleton(); // Priority 50 - highest + builder.Services.AddSingleton(); // Priority 75 + builder.Services.AddSingleton(); // Priority 100 + builder.Services.AddSingleton(); // Priority 150 + builder.Services.AddSingleton(); // Priority 200 + builder.Services.AddSingleton(); // Priority 300 - lowest + + // Configure the tool authorization service with all filters + builder.Services.AddSingleton(serviceProvider => + { + var authService = new ToolAuthorizationService(); + var filters = serviceProvider.GetServices(); + + // Register filters in priority order + foreach (var filter in filters.OrderBy(f => f.Priority)) + { + authService.RegisterFilter(filter); + Log.Information("Registered tool filter: {FilterType} with priority {Priority}", + filter.GetType().Name, filter.Priority); + } + + return authService; + }); +} + +// Custom API Key authentication handler +public class ApiKeyAuthenticationHandler : AuthenticationHandler +{ + private readonly ILogger _logger; + + public ApiKeyAuthenticationHandler( + IOptionsMonitor options, + ILoggerFactory loggerFactory, + UrlEncoder encoder) + : base(options, loggerFactory, encoder) + { + _logger = loggerFactory.CreateLogger(); + } + + protected override Task HandleAuthenticateAsync() + { + // Try to get API key from header + var apiKey = Request.Headers[Options.HeaderName].FirstOrDefault(); + + // If not in header, try query string + if (string.IsNullOrEmpty(apiKey)) + { + apiKey = Request.Query[Options.QueryStringKey].FirstOrDefault(); + } + + if (string.IsNullOrEmpty(apiKey)) + { + return Task.FromResult(AuthenticateResult.NoResult()); + } + + // Validate API key (in production, use secure storage and proper validation) + var validApiKeys = new Dictionary + { + { "demo-guest-key", ("guest-user", "guest", new[] { "basic:tools" }) }, + { "demo-user-key", ("demo-user", "user", new[] { "user:tools", "read:tools", "basic:tools" }) }, + { "demo-premium-key", ("premium-user", "premium", new[] { "premium:tools", "user:tools", "read:tools", "basic:tools" }) }, + { "demo-admin-key", ("admin-user", "admin", new[] { "admin:tools", "premium:tools", "user:tools", "read:tools", "basic:tools" }) } + }; + + if (!validApiKeys.TryGetValue(apiKey, out var keyInfo)) + { + _logger.LogWarning("Invalid API key attempted: {ApiKey}", apiKey[..Math.Min(8, apiKey.Length)] + "..."); + return Task.FromResult(AuthenticateResult.Fail("Invalid API key")); + } + + // Create claims for the authenticated user + var claims = new List + { + new(ClaimTypes.NameIdentifier, keyInfo.UserId), + new(ClaimTypes.Name, keyInfo.UserId), + new(ClaimTypes.Role, keyInfo.Role), + new(ClaimTypes.AuthenticationMethod, "ApiKey") + }; + + // Add scope claims + foreach (var scope in keyInfo.Scopes) + { + claims.Add(new Claim("scope", scope)); + } + + var identity = new ClaimsIdentity(claims, Scheme.Name); + var principal = new ClaimsPrincipal(identity); + var ticket = new AuthenticationTicket(principal, Scheme.Name); + + _logger.LogDebug("API key authentication successful for user: {UserId}, Role: {Role}", keyInfo.UserId, keyInfo.Role); + + return Task.FromResult(AuthenticateResult.Success(ticket)); + } + + protected override Task HandleChallengeAsync(AuthenticationProperties properties) + { + Response.Headers.Add("WWW-Authenticate", $"ApiKey realm=\"mcp-api\", parameter=\"{Options.HeaderName}\""); + return base.HandleChallengeAsync(properties); + } +} + +public class ApiKeyAuthenticationSchemeOptions : AuthenticationSchemeOptions +{ + public string HeaderName { get; set; } = "X-API-Key"; + public string QueryStringKey { get; set; } = "apikey"; +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Properties/launchSettings.json b/samples/DynamicToolFiltering/Properties/launchSettings.json new file mode 100644 index 00000000..2cf9ec42 --- /dev/null +++ b/samples/DynamicToolFiltering/Properties/launchSettings.json @@ -0,0 +1,151 @@ +{ + "profiles": { + "DevelopmentMode": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "http://localhost:8080", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development", + "ASPNETCORE_URLS": "http://localhost:8080", + "Filtering__Enabled": "true", + "Filtering__RoleBased__Enabled": "true", + "Filtering__TimeBased__Enabled": "false", + "Filtering__ScopeBased__Enabled": "true", + "Filtering__RateLimiting__Enabled": "true", + "Filtering__TenantIsolation__Enabled": "false", + "Filtering__BusinessLogic__Enabled": "true", + "Filtering__BusinessLogic__FeatureFlags__Enabled": "true", + "Filtering__BusinessLogic__QuotaManagement__Enabled": "false", + "Filtering__BusinessLogic__EnvironmentRestrictions__Enabled": "true" + } + }, + "StagingMode": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "http://localhost:8080", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Staging", + "ASPNETCORE_URLS": "http://localhost:8080", + "Filtering__Enabled": "true", + "Filtering__RoleBased__Enabled": "true", + "Filtering__TimeBased__Enabled": "true", + "Filtering__ScopeBased__Enabled": "true", + "Filtering__RateLimiting__Enabled": "true", + "Filtering__TenantIsolation__Enabled": "true", + "Filtering__BusinessLogic__Enabled": "true", + "Filtering__BusinessLogic__FeatureFlags__Enabled": "true", + "Filtering__BusinessLogic__QuotaManagement__Enabled": "true", + "Filtering__BusinessLogic__EnvironmentRestrictions__Enabled": "true" + } + }, + "ProductionMode": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "http://localhost:8080", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Production", + "ASPNETCORE_URLS": "http://localhost:8080", + "Filtering__Enabled": "true", + "Filtering__RoleBased__Enabled": "true", + "Filtering__TimeBased__Enabled": "true", + "Filtering__ScopeBased__Enabled": "true", + "Filtering__RateLimiting__Enabled": "true", + "Filtering__TenantIsolation__Enabled": "true", + "Filtering__BusinessLogic__Enabled": "true", + "Filtering__BusinessLogic__FeatureFlags__Enabled": "true", + "Filtering__BusinessLogic__QuotaManagement__Enabled": "true", + "Filtering__BusinessLogic__EnvironmentRestrictions__Enabled": "true" + } + }, + "MinimalFilteringMode": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "http://localhost:8080", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development", + "ASPNETCORE_URLS": "http://localhost:8080", + "Filtering__Enabled": "true", + "Filtering__RoleBased__Enabled": "true", + "Filtering__TimeBased__Enabled": "false", + "Filtering__ScopeBased__Enabled": "false", + "Filtering__RateLimiting__Enabled": "false", + "Filtering__TenantIsolation__Enabled": "false", + "Filtering__BusinessLogic__Enabled": "false" + } + }, + "NoFilteringMode": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "http://localhost:8080", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development", + "ASPNETCORE_URLS": "http://localhost:8080", + "Filtering__Enabled": "false" + } + }, + "TenantDemoMode": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "http://localhost:8080", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development", + "ASPNETCORE_URLS": "http://localhost:8080", + "Filtering__Enabled": "true", + "Filtering__RoleBased__Enabled": "false", + "Filtering__TimeBased__Enabled": "false", + "Filtering__ScopeBased__Enabled": "false", + "Filtering__RateLimiting__Enabled": "false", + "Filtering__TenantIsolation__Enabled": "true", + "Filtering__BusinessLogic__Enabled": "false" + } + }, + "RateLimitingDemoMode": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "http://localhost:8080", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development", + "ASPNETCORE_URLS": "http://localhost:8080", + "Filtering__Enabled": "true", + "Filtering__RoleBased__Enabled": "true", + "Filtering__TimeBased__Enabled": "false", + "Filtering__ScopeBased__Enabled": "false", + "Filtering__RateLimiting__Enabled": "true", + "Filtering__RateLimiting__WindowMinutes": "1", + "Filtering__RateLimiting__RoleLimits__guest": "3", + "Filtering__RateLimiting__RoleLimits__user": "10", + "Filtering__RateLimiting__RoleLimits__premium": "25", + "Filtering__RateLimiting__RoleLimits__admin": "100", + "Filtering__TenantIsolation__Enabled": "false", + "Filtering__BusinessLogic__Enabled": "false" + } + }, + "BusinessHoursDemoMode": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "http://localhost:8080", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development", + "ASPNETCORE_URLS": "http://localhost:8080", + "Filtering__Enabled": "true", + "Filtering__RoleBased__Enabled": "false", + "Filtering__TimeBased__Enabled": "true", + "Filtering__TimeBased__BusinessHours__Enabled": "true", + "Filtering__TimeBased__BusinessHours__StartTime": "09:00", + "Filtering__TimeBased__BusinessHours__EndTime": "17:00", + "Filtering__ScopeBased__Enabled": "false", + "Filtering__RateLimiting__Enabled": "false", + "Filtering__TenantIsolation__Enabled": "false", + "Filtering__BusinessLogic__Enabled": "false" + } + } + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/README.md b/samples/DynamicToolFiltering/README.md new file mode 100644 index 00000000..e70438b7 --- /dev/null +++ b/samples/DynamicToolFiltering/README.md @@ -0,0 +1,670 @@ +# Dynamic Tool Filtering MCP Server Sample + +This comprehensive sample demonstrates advanced tool filtering and authorization capabilities in the MCP (Model Context Protocol) C# SDK. It showcases how to implement sophisticated access control systems with multiple filter types, authentication schemes, and business logic constraints. + +## Overview + +The Dynamic Tool Filtering sample illustrates real-world scenarios where different users need different levels of access to tools based on roles, time constraints, quotas, feature flags, and business rules. It's designed to be educational while demonstrating production-ready patterns. + +## Features + +### 🔐 Multiple Filter Types + +- **Role-Based Filtering**: Hierarchical role system (guest → user → premium → admin → super_admin) +- **Time-Based Filtering**: Business hours restrictions and maintenance windows +- **Scope-Based Filtering**: OAuth2-style scope checking for fine-grained permissions +- **Rate Limiting**: Per-user and per-tool rate limits with sliding/fixed windows +- **Tenant Isolation**: Multi-tenant tool access with tenant-specific configurations +- **Business Logic Filtering**: Feature flags, quota management, and environment restrictions + +### 🛠️ Tool Categories + +The sample includes four categories of tools representing different security levels: + +1. **Public Tools** (`PublicTools.cs`): Available to all users without authentication +2. **User Tools** (`UserTools.cs`): Require basic authentication and user role +3. **Admin Tools** (`AdminTools.cs`): Require administrative privileges +4. **Premium Tools** (`PremiumTools.cs`): Advanced functionality requiring premium access + +### 🔑 Authentication Methods + +- **JWT Bearer Tokens**: Standard OAuth2/OIDC authentication with claims +- **API Key Authentication**: Simple header or query-based authentication +- **Role-based Claims**: Hierarchical role system with inheritance +- **Scope Claims**: Granular permission scopes for different operations + +## Architecture + +### Filter Priority System + +Filters execute in priority order (lowest number = highest priority): + +1. **Rate Limiting** (Priority 50): Enforces usage quotas first +2. **Tenant Isolation** (Priority 75): Multi-tenant access control +3. **Role-Based** (Priority 100): User role verification +4. **Scope-Based** (Priority 150): OAuth2 scope checking +5. **Time-Based** (Priority 200): Business hours and maintenance +6. **Business Logic** (Priority 300): Feature flags and environment rules + +### Filter Flow Architecture + +``` +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│ MCP Client │────│ Authentication │────│ Authorization │ +│ │ │ & Identity │ │ Filters │ +└─────────────────┘ └─────────────────┘ └─────────────────┘ + │ + ▼ +┌───────────────────────────────────────────────────────────────────┐ +│ FILTER CHAIN EXECUTION │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ 1. Rate Limiting Filter (Priority 50) │ │ +│ │ ├─ Check per-user rate limits │ │ +│ │ ├─ Validate time windows │ │ +│ │ └─ Record usage statistics │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ 2. Tenant Isolation Filter (Priority 75) │ │ +│ │ ├─ Validate tenant status │ │ +│ │ ├─ Check tenant tool allowlist │ │ +│ │ └─ Apply tenant-specific rate limits │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ 3. Role-Based Filter (Priority 100) │ │ +│ │ ├─ Extract user roles from claims │ │ +│ │ ├─ Check hierarchical permissions │ │ +│ │ └─ Validate tool access patterns │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ 4. Scope-Based Filter (Priority 150) │ │ +│ │ ├─ Parse OAuth2 scopes │ │ +│ │ ├─ Match required tool scopes │ │ +│ │ └─ Validate scope hierarchy │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ 5. Time-Based Filter (Priority 200) │ │ +│ │ ├─ Check business hours │ │ +│ │ ├─ Validate maintenance windows │ │ +│ │ └─ Apply timezone calculations │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ 6. Business Logic Filter (Priority 300) │ │ +│ │ ├─ Check feature flags │ │ +│ │ ├─ Validate quotas │ │ +│ │ └─ Apply environment restrictions │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +└───────────────────────────────────────────────────────────────────┘ + │ + ▼ + ┌─────────────────┐ + │ Tool Execution │ + │ or Rejection │ + └─────────────────┘ +``` + +### Component Interaction Diagram + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ MCP SERVER │ +│ │ +│ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │ +│ │ Public Tools │ │ User Tools │ │ Premium Tools │ │ +│ │ │ │ │ │ │ │ +│ │ • echo │ │ • get_profile │ │ • secure_random │ │ +│ │ • system_info │ │ • hash_calc │ │ • text_analysis │ │ +│ │ • utc_time │ │ • uuid_gen │ │ • password_gen │ │ +│ └─────────────────┘ └─────────────────┘ └─────────────────┘ │ +│ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ TOOL AUTHORIZATION SERVICE │ │ +│ │ │ │ +│ │ • Filter registration & management │ │ +│ │ • Priority-based execution │ │ +│ │ • Result aggregation & challenge generation │ │ +│ │ • Context enrichment │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +│ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ EXTERNAL SERVICES │ │ +│ │ │ │ +│ │ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ │ │ +│ │ │ Rate Limiting │ │ Feature Flags │ │ Quota Service │ │ │ +│ │ │ Service │ │ Service │ │ │ │ │ +│ │ │ │ │ │ │ │ │ │ +│ │ │ • Usage cache │ │ • Flag state │ │ • Usage track │ │ │ +│ │ │ • Time windows│ │ • A/B testing │ │ • Limits mgmt │ │ │ +│ │ └───────────────┘ └───────────────┘ └───────────────┘ │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────────┘ +``` + +### Project Structure + +``` +DynamicToolFiltering/ +├── Authorization/ +│ └── Filters/ # Filter implementations +│ ├── BusinessLogicFilter.cs +│ ├── RateLimitingToolFilter.cs +│ ├── RoleBasedToolFilter.cs +│ ├── ScopeBasedToolFilter.cs +│ ├── TenantIsolationFilter.cs +│ └── TimeBasedToolFilter.cs +├── Configuration/ # Configuration models +│ └── FilteringOptions.cs +├── Models/ # Data models +│ ├── FilterResult.cs +│ ├── ToolExecutionContext.cs +│ ├── UsageStatistics.cs +│ └── UserInfo.cs +├── Services/ # Supporting services +│ ├── IFeatureFlagService.cs +│ ├── IQuotaService.cs +│ ├── IRateLimitingService.cs +│ ├── InMemoryFeatureFlagService.cs +│ ├── InMemoryQuotaService.cs +│ └── InMemoryRateLimitingService.cs +├── Tools/ # Tool implementations +│ ├── AdminTools.cs +│ ├── PremiumTools.cs +│ ├── PublicTools.cs +│ └── UserTools.cs +├── Properties/ # Launch profiles +│ └── launchSettings.json +├── docs/ # Enhanced documentation +│ ├── ARCHITECTURE.md +│ ├── DEPLOYMENT.md +│ ├── PERFORMANCE.md +│ └── TROUBLESHOOTING.md +├── scripts/ # Automation scripts +│ ├── test-all.sh +│ ├── test-all.ps1 +│ └── setup-dev.sh +├── .vscode/ # VS Code configuration +│ ├── launch.json +│ ├── settings.json +│ └── tasks.json +├── appsettings.*.json # Configuration files +├── Dockerfile # Docker configuration +├── docker-compose.yml # Multi-service setup +├── Program.cs +├── README.md +├── TESTING_GUIDE.md +└── INTEGRATION_EXAMPLES.md +``` + +## Quick Start + +### 1. Prerequisites + +- .NET 9.0 SDK or later +- Git (for cloning the repository) +- curl or Postman (for testing) +- Optional: Docker Desktop (for containerized deployment) +- Optional: Visual Studio Code with C# extension + +### 2. One-Line Setup + +```bash +# Clone, build, and run in development mode +git clone https://github.com/microsoft/mcp-csharp-sdk.git && \ +cd mcp-csharp-sdk/samples/DynamicToolFiltering && \ +dotnet run --launch-profile DevelopmentMode +``` + +### 3. Alternative Setup Methods + +#### Option A: Manual Setup + +```bash +# Navigate to the sample directory +cd samples/DynamicToolFiltering + +# Restore dependencies +dotnet restore + +# Build the project +dotnet build + +# Run with development profile +dotnet run --launch-profile DevelopmentMode +``` + +#### Option B: Docker Setup + +```bash +# Build Docker image +docker build -t dynamic-tool-filtering . + +# Run container +docker run -p 8080:8080 dynamic-tool-filtering +``` + +#### Option C: Development Environment Setup + +```bash +# Run the setup script (creates .vscode config, installs tools) +./scripts/setup-dev.sh + +# Open in VS Code with debugging ready +code . +``` + +### 4. Verify Installation + +```bash +# Check server health +curl http://localhost:8080/health + +# Expected response: +# { +# "Status": "healthy", +# "Timestamp": "2024-01-01T12:00:00.000Z", +# "Environment": "Development", +# "Version": "1.0.0" +# } +``` + +### 5. Quick API Test Suite + +The sample includes predefined API keys for testing different user roles: + +| Role | API Key | Available Tools | Rate Limit | +|------|---------|-----------------|------------| +| Guest | `demo-guest-key` | Public tools only | 20/hour | +| User | `demo-user-key` | Public + User tools | 100/hour | +| Premium | `demo-premium-key` | Public + User + Premium tools | 500/hour | +| Admin | `demo-admin-key` | All tools | 1000/hour | + +#### Test Tool Visibility by Role + +```bash +# Guest user - should see only basic tools +curl -H "X-API-Key: demo-guest-key" \ + http://localhost:8080/mcp/v1/tools + +# User role - should see user-level tools +curl -H "X-API-Key: demo-user-key" \ + http://localhost:8080/mcp/v1/tools + +# Premium user - should see premium tools +curl -H "X-API-Key: demo-premium-key" \ + http://localhost:8080/mcp/v1/tools + +# Admin user - should see all tools +curl -H "X-API-Key: demo-admin-key" \ + http://localhost:8080/mcp/v1/tools +``` + +#### Test Tool Execution + +```bash +# Test 1: Public tool (no authentication required) +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "Content-Type: application/json" \ + -d '{ + "name": "echo", + "arguments": { + "message": "Hello Dynamic Filtering!" + } + }' + +# Test 2: User tool (requires authentication) +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-user-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "get_user_profile", + "arguments": {} + }' + +# Test 3: Premium tool (requires premium role) +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-premium-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "premium_generate_secure_random", + "arguments": { + "byteCount": 32, + "format": "hex" + } + }' + +# Test 4: Admin tool (requires admin role) +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-admin-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "admin_get_system_diagnostics", + "arguments": {} + }' + +# Test 5: Authorization failure (user trying admin tool) +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-user-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "admin_get_system_diagnostics", + "arguments": {} + }' +# Expected: HTTP 401 with authorization error +``` + +### 6. Run Automated Test Suite + +```bash +# Run comprehensive test suite (bash) +./scripts/test-all.sh + +# Or PowerShell version +.\scripts\test-all.ps1 + +# Run specific test categories +./scripts/test-all.sh --category authentication +./scripts/test-all.sh --category rate-limiting +./scripts/test-all.sh --category performance +``` + +## Launch Profiles + +The sample includes multiple launch profiles for different scenarios: + +### Development Profiles + +- **DevelopmentMode**: Basic filtering with relaxed rate limits +- **NoFilteringMode**: All filtering disabled for testing +- **MinimalFilteringMode**: Only role-based filtering enabled + +### Feature-Specific Profiles + +- **TenantDemoMode**: Demonstrates multi-tenant access control +- **RateLimitingDemoMode**: Shows rate limiting with strict limits (1-minute windows) +- **BusinessHoursDemoMode**: Time-based restrictions (9 AM - 5 PM weekdays) + +### Environment Profiles + +- **StagingMode**: All filters enabled with moderate settings +- **ProductionMode**: Strict security with all filters active + +## Configuration + +### Environment Variables + +Override any configuration using environment variables: + +```bash +# Enable/disable specific filters +export Filtering__RoleBased__Enabled=true +export Filtering__RateLimiting__Enabled=true +export Filtering__TimeBased__Enabled=false + +# Customize rate limits +export Filtering__RateLimiting__WindowMinutes=60 +export Filtering__RateLimiting__RoleLimits__user=100 + +# Business hours (UTC times) +export Filtering__TimeBased__BusinessHours__StartTime="09:00" +export Filtering__TimeBased__BusinessHours__EndTime="17:00" +``` + +### JWT Configuration + +For JWT authentication, configure the following: + +```json +{ + "Jwt": { + "SecretKey": "your-256-bit-secret-key", + "Issuer": "your-issuer", + "Audience": "your-audience", + "ExpirationMinutes": 60 + } +} +``` + +## Filter Implementations + +### Role-Based Filter + +Implements hierarchical role checking with pattern matching: + +- Supports glob patterns for tool names (`admin_*`, `premium_*`) +- Hierarchical inheritance (admin inherits user permissions) +- Configurable role mappings + +### Rate Limiting Filter + +Implements quota management: + +- Per-user and per-tool rate limits +- Sliding or fixed time windows +- Role-based default limits with tool-specific overrides +- Automatic cleanup of old usage records + +### Scope-Based Filter + +OAuth2-style scope checking: + +- Space-separated scopes in JWT claims +- Hierarchical scope inheritance +- Wildcard scope matching +- Proper OAuth2 error responses + +### Time-Based Filter + +Business hours and maintenance windows: + +- Configurable business hours per timezone +- Maintenance window blocking +- Tool-specific time restrictions + +### Tenant Isolation Filter + +Multi-tenant access control: + +- Tenant-specific tool allowlists/denylists +- Custom rate limits per tenant +- Tenant activation status checking + +### Business Logic Filter + +Advanced business rules: + +- Feature flag integration +- Quota management with periodic resets +- Environment-specific restrictions (dev/staging/prod) + +## Error Handling + +The sample demonstrates proper HTTP error responses with WWW-Authenticate headers: + +### 401 Unauthorized Responses + +```http +HTTP/1.1 401 Unauthorized +WWW-Authenticate: Bearer realm="mcp-api", scope="admin:tools", error="insufficient_scope" +Content-Type: application/json + +{ + "error": { + "code": -32002, + "message": "Access denied for tool 'admin_tool': Insufficient scope", + "data": { + "ToolName": "admin_tool", + "Reason": "Insufficient scope", + "HttpStatusCode": 401, + "RequiresAuthentication": true + } + } +} +``` + +### Custom Challenge Headers + +Different filter types generate appropriate challenge headers: + +- **Bearer**: OAuth2 Bearer token challenges with scope information +- **Basic**: Basic authentication challenges +- **ApiKey**: Custom API key challenges +- **Role**: Role-based access challenges +- **Tenant**: Tenant-specific challenges + +## Production Considerations + +### Security + +- Use secure JWT secret keys (256-bit minimum) +- Implement proper token storage and rotation +- Use HTTPS in production +- Consider rate limiting at the infrastructure level +- Implement proper audit logging + +### Performance + +- Use distributed caches (Redis) for rate limiting in production +- Implement proper database storage for quotas and usage tracking +- Consider caching authorization decisions +- Monitor filter performance and adjust priorities + +### Scalability + +- Use external feature flag services (LaunchDarkly, Azure App Configuration) +- Implement distributed quota management +- Consider eventual consistency for rate limiting +- Use proper database indexing for usage queries + +## Advanced Usage + +### Custom Filter Implementation + +Create custom filters by implementing `IToolFilter`: + +```csharp +public class CustomBusinessFilter : IToolFilter +{ + public int Priority => 250; // Between time-based and business logic + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken) + { + // Custom visibility logic + return true; + } + + public async Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken) + { + // Custom authorization logic + return AuthorizationResult.Allow("Custom filter passed"); + } +} +``` + +### Dynamic Filter Registration + +Filters can be registered and unregistered at runtime: + +```csharp +var authService = serviceProvider.GetRequiredService(); +var customFilter = new CustomBusinessFilter(); +authService.RegisterFilter(customFilter); +``` + +### Integration with External Services + +The sample demonstrates integration patterns for: + +- Feature flag services +- Rate limiting backends +- Quota management systems +- Tenant management APIs + +## Testing + +The sample includes comprehensive testing scenarios: + +1. **Authentication Testing**: Test different API keys and JWT tokens +2. **Authorization Testing**: Verify role and scope-based access +3. **Rate Limiting Testing**: Exceed limits and verify blocking +4. **Time-Based Testing**: Test business hours restrictions +5. **Feature Flag Testing**: Toggle features and verify access +6. **Error Handling Testing**: Verify proper error responses + +## Troubleshooting + +### Common Issues + +1. **Tools not visible**: Check filter configurations and user roles +2. **Rate limit errors**: Verify rate limiting settings and time windows +3. **Authentication failures**: Check API keys and JWT configuration +4. **Time-based restrictions**: Verify timezone settings and business hours + +### Debugging + +Enable debug logging to see filter execution: + +```json +{ + "Logging": { + "LogLevel": { + "DynamicToolFiltering": "Debug" + } + } +} +``` + +### Health Checks + +Use the health endpoint to verify service status: + +```bash +curl http://localhost:8080/health +``` + +## Learning Objectives + +This sample demonstrates: + +1. **Filter Architecture**: How to design and implement filter chains +2. **Authorization Patterns**: Multiple authentication and authorization strategies +3. **Configuration Management**: Environment-specific configurations +4. **Error Handling**: Proper HTTP error responses and challenges +5. **Performance Considerations**: Efficient filter design and caching +6. **Security Best Practices**: Secure authentication and authorization +7. **Testing Strategies**: Comprehensive testing of authorization systems + +## Next Steps + +- Implement persistent storage for rate limiting and quotas +- Add integration with external identity providers +- Implement audit logging for all authorization decisions +- Add metrics and monitoring for filter performance +- Create admin APIs for dynamic filter management +- Implement filter testing framework + +## Contributing + +This sample is designed to be educational and demonstrative. Feel free to: + +- Extend with additional filter types +- Add integration with real external services +- Improve error handling and edge cases +- Add more comprehensive testing +- Enhance documentation and examples + +## License + +This sample is part of the MCP C# SDK and follows the same license terms. \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Services/IFeatureFlagService.cs b/samples/DynamicToolFiltering/Services/IFeatureFlagService.cs new file mode 100644 index 00000000..b045d29c --- /dev/null +++ b/samples/DynamicToolFiltering/Services/IFeatureFlagService.cs @@ -0,0 +1,41 @@ +namespace DynamicToolFiltering.Services; + +/// +/// Service for managing feature flags. +/// +public interface IFeatureFlagService +{ + /// + /// Checks if a feature flag is enabled for a specific user. + /// + /// The feature flag name. + /// The user ID. + /// Cancellation token. + /// True if the feature is enabled, false otherwise. + Task IsEnabledAsync(string flagName, string userId, CancellationToken cancellationToken = default); + + /// + /// Gets all feature flags and their states for a user. + /// + /// The user ID. + /// Cancellation token. + /// Dictionary of feature flag names and their enabled states. + Task> GetAllFlagsAsync(string userId, CancellationToken cancellationToken = default); + + /// + /// Sets the state of a feature flag (for testing/admin purposes). + /// + /// The feature flag name. + /// Whether the flag should be enabled. + /// Optional user ID for user-specific flags. + /// Cancellation token. + Task SetFlagAsync(string flagName, bool enabled, string? userId = null, CancellationToken cancellationToken = default); + + /// + /// Gets the rollout percentage for a feature flag. + /// + /// The feature flag name. + /// Cancellation token. + /// The rollout percentage (0-100). + Task GetRolloutPercentageAsync(string flagName, CancellationToken cancellationToken = default); +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Services/IQuotaService.cs b/samples/DynamicToolFiltering/Services/IQuotaService.cs new file mode 100644 index 00000000..14ba06a8 --- /dev/null +++ b/samples/DynamicToolFiltering/Services/IQuotaService.cs @@ -0,0 +1,75 @@ +namespace DynamicToolFiltering.Services; + +/// +/// Service for managing user quotas and usage tracking. +/// +public interface IQuotaService +{ + /// + /// Checks if a user has available quota for a specific tool. + /// + /// The user ID. + /// The user role. + /// The tool name. + /// Cancellation token. + /// True if quota is available, false otherwise. + Task HasAvailableQuotaAsync(string userId, string userRole, string toolName, CancellationToken cancellationToken = default); + + /// + /// Consumes quota for a tool usage. + /// + /// The user ID. + /// The tool name. + /// The quota cost to consume. + /// Cancellation token. + Task ConsumeQuotaAsync(string userId, string toolName, int cost, CancellationToken cancellationToken = default); + + /// + /// Gets the current quota usage for a user. + /// + /// The user ID. + /// Cancellation token. + /// The current usage amount. + Task GetCurrentUsageAsync(string userId, CancellationToken cancellationToken = default); + + /// + /// Gets the quota limit for a user based on their role. + /// + /// The user ID. + /// The user role. + /// Cancellation token. + /// The quota limit (-1 for unlimited). + Task GetQuotaLimitAsync(string userId, string userRole, CancellationToken cancellationToken = default); + + /// + /// Gets the remaining quota for a user. + /// + /// The user ID. + /// The user role. + /// Cancellation token. + /// The remaining quota amount. + Task GetRemainingQuotaAsync(string userId, string userRole, CancellationToken cancellationToken = default); + + /// + /// Gets the date when the user's quota will reset. + /// + /// The user ID. + /// Cancellation token. + /// The quota reset date. + Task GetQuotaResetDateAsync(string userId, CancellationToken cancellationToken = default); + + /// + /// Resets quota for a user (for admin purposes or period rollover). + /// + /// The user ID. + /// Cancellation token. + Task ResetQuotaAsync(string userId, CancellationToken cancellationToken = default); + + /// + /// Gets detailed quota usage breakdown by tool for a user. + /// + /// The user ID. + /// Cancellation token. + /// Dictionary of tool names and their usage amounts. + Task> GetUsageBreakdownAsync(string userId, CancellationToken cancellationToken = default); +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Services/IRateLimitingService.cs b/samples/DynamicToolFiltering/Services/IRateLimitingService.cs new file mode 100644 index 00000000..b2a83fa3 --- /dev/null +++ b/samples/DynamicToolFiltering/Services/IRateLimitingService.cs @@ -0,0 +1,40 @@ +namespace DynamicToolFiltering.Services; + +/// +/// Service for managing rate limiting and usage tracking. +/// +public interface IRateLimitingService +{ + /// + /// Gets the current usage count for a user/tool combination within the specified time window. + /// + /// The user ID. + /// The tool name. + /// The start of the time window. + /// Cancellation token. + /// The current usage count. + Task GetUsageCountAsync(string userId, string toolName, DateTime windowStart, CancellationToken cancellationToken = default); + + /// + /// Records a tool usage for rate limiting tracking. + /// + /// The user ID. + /// The tool name. + /// The timestamp of the usage. + /// Cancellation token. + Task RecordUsageAsync(string userId, string toolName, DateTime timestamp, CancellationToken cancellationToken = default); + + /// + /// Cleans up old usage records that are outside the retention window. + /// + /// Cancellation token. + Task CleanupOldRecordsAsync(CancellationToken cancellationToken = default); + + /// + /// Gets usage statistics for a user. + /// + /// The user ID. + /// Cancellation token. + /// Usage statistics. + Task> GetUsageStatisticsAsync(string userId, CancellationToken cancellationToken = default); +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Services/InMemoryFeatureFlagService.cs b/samples/DynamicToolFiltering/Services/InMemoryFeatureFlagService.cs new file mode 100644 index 00000000..a9594e1c --- /dev/null +++ b/samples/DynamicToolFiltering/Services/InMemoryFeatureFlagService.cs @@ -0,0 +1,164 @@ +using System.Collections.Concurrent; +using System.Security.Cryptography; +using System.Text; + +namespace DynamicToolFiltering.Services; + +/// +/// In-memory implementation of feature flag service with percentage rollout support. +/// Note: This is for demonstration purposes. In production, use a dedicated feature flag service. +/// +public class InMemoryFeatureFlagService : IFeatureFlagService +{ + private readonly ConcurrentDictionary _flags = new(); + private readonly ILogger _logger; + + public InMemoryFeatureFlagService(ILogger logger) + { + _logger = logger; + InitializeDefaultFlags(); + } + + public Task IsEnabledAsync(string flagName, string userId, CancellationToken cancellationToken = default) + { + if (!_flags.TryGetValue(flagName, out var flag)) + { + _logger.LogDebug("Feature flag {FlagName} not found, returning false", flagName); + return Task.FromResult(false); + } + + // Check if globally enabled/disabled + if (!flag.Enabled) + { + return Task.FromResult(false); + } + + // Check user-specific overrides + if (flag.UserOverrides.TryGetValue(userId, out var userOverride)) + { + _logger.LogDebug("Feature flag {FlagName} has user override for {UserId}: {Enabled}", flagName, userId, userOverride); + return Task.FromResult(userOverride); + } + + // Check percentage rollout + if (flag.RolloutPercentage < 100) + { + var userHash = GetUserHash(userId, flagName); + var enabled = userHash < flag.RolloutPercentage; + _logger.LogDebug("Feature flag {FlagName} percentage rollout for {UserId}: {Percentage}% -> {Enabled}", + flagName, userId, flag.RolloutPercentage, enabled); + return Task.FromResult(enabled); + } + + return Task.FromResult(true); + } + + public Task> GetAllFlagsAsync(string userId, CancellationToken cancellationToken = default) + { + var result = new Dictionary(); + + foreach (var kvp in _flags) + { + var flagName = kvp.Key; + var isEnabled = IsEnabledAsync(flagName, userId, cancellationToken).Result; + result[flagName] = isEnabled; + } + + return Task.FromResult(result); + } + + public Task SetFlagAsync(string flagName, bool enabled, string? userId = null, CancellationToken cancellationToken = default) + { + _flags.AddOrUpdate(flagName, + new FeatureFlag { Name = flagName, Enabled = enabled }, + (_, existingFlag) => + { + if (userId != null) + { + existingFlag.UserOverrides[userId] = enabled; + _logger.LogInformation("Set user override for feature flag {FlagName}, User: {UserId}, Enabled: {Enabled}", + flagName, userId, enabled); + } + else + { + existingFlag.Enabled = enabled; + _logger.LogInformation("Set global state for feature flag {FlagName}, Enabled: {Enabled}", flagName, enabled); + } + return existingFlag; + }); + + return Task.CompletedTask; + } + + public Task GetRolloutPercentageAsync(string flagName, CancellationToken cancellationToken = default) + { + if (_flags.TryGetValue(flagName, out var flag)) + { + return Task.FromResult(flag.RolloutPercentage); + } + + return Task.FromResult(0); + } + + /// + /// Sets the rollout percentage for a feature flag. + /// + public Task SetRolloutPercentageAsync(string flagName, int percentage, CancellationToken cancellationToken = default) + { + if (percentage < 0 || percentage > 100) + { + throw new ArgumentOutOfRangeException(nameof(percentage), "Percentage must be between 0 and 100"); + } + + _flags.AddOrUpdate(flagName, + new FeatureFlag { Name = flagName, Enabled = true, RolloutPercentage = percentage }, + (_, existingFlag) => + { + existingFlag.RolloutPercentage = percentage; + return existingFlag; + }); + + _logger.LogInformation("Set rollout percentage for feature flag {FlagName}: {Percentage}%", flagName, percentage); + + return Task.CompletedTask; + } + + private void InitializeDefaultFlags() + { + // Initialize some example feature flags + var defaultFlags = new[] + { + new FeatureFlag { Name = "premium_features", Enabled = true, RolloutPercentage = 50 }, + new FeatureFlag { Name = "admin_performance_tools", Enabled = true, RolloutPercentage = 25 }, + new FeatureFlag { Name = "experimental_tools", Enabled = false, RolloutPercentage = 5 }, + new FeatureFlag { Name = "beta_features", Enabled = true, RolloutPercentage = 75 } + }; + + foreach (var flag in defaultFlags) + { + _flags.TryAdd(flag.Name, flag); + } + + _logger.LogInformation("Initialized {Count} default feature flags", defaultFlags.Length); + } + + private static int GetUserHash(string userId, string flagName) + { + // Create a consistent hash for the user/flag combination + // This ensures the same user always gets the same result for a flag + var input = $"{userId}:{flagName}"; + var hash = SHA256.HashData(Encoding.UTF8.GetBytes(input)); + + // Convert first 4 bytes to int and get percentage (0-99) + var hashInt = BitConverter.ToInt32(hash, 0); + return Math.Abs(hashInt) % 100; + } + + private class FeatureFlag + { + public string Name { get; set; } = ""; + public bool Enabled { get; set; } = true; + public int RolloutPercentage { get; set; } = 100; + public ConcurrentDictionary UserOverrides { get; set; } = new(); + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Services/InMemoryQuotaService.cs b/samples/DynamicToolFiltering/Services/InMemoryQuotaService.cs new file mode 100644 index 00000000..852d41dd --- /dev/null +++ b/samples/DynamicToolFiltering/Services/InMemoryQuotaService.cs @@ -0,0 +1,258 @@ +using System.Collections.Concurrent; +using DynamicToolFiltering.Configuration; +using Microsoft.Extensions.Options; + +namespace DynamicToolFiltering.Services; + +/// +/// In-memory implementation of quota service with period-based quotas. +/// Note: This is for demonstration purposes. In production, use a persistent store. +/// +public class InMemoryQuotaService : IQuotaService +{ + private readonly QuotaManagementOptions _options; + private readonly ConcurrentDictionary _userQuotas = new(); + private readonly ILogger _logger; + private readonly Timer _resetTimer; + + public InMemoryQuotaService(IOptions options, ILogger logger) + { + _options = options.Value.BusinessLogic.QuotaManagement; + _logger = logger; + + // Check for quota resets daily + _resetTimer = new Timer(async _ => await ProcessQuotaResetsAsync(), null, TimeSpan.FromHours(1), TimeSpan.FromHours(1)); + } + + public Task HasAvailableQuotaAsync(string userId, string userRole, string toolName, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + return Task.FromResult(true); + } + + var quotaLimit = GetQuotaLimitForRole(userRole); + if (quotaLimit == -1) + { + return Task.FromResult(true); // Unlimited + } + + var userQuota = GetOrCreateUserQuota(userId); + var quotaCost = GetQuotaCost(toolName); + + lock (userQuota) + { + var hasQuota = userQuota.CurrentUsage + quotaCost <= quotaLimit; + return Task.FromResult(hasQuota); + } + } + + public Task ConsumeQuotaAsync(string userId, string toolName, int cost, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + return Task.CompletedTask; + } + + var userQuota = GetOrCreateUserQuota(userId); + + lock (userQuota) + { + userQuota.CurrentUsage += cost; + userQuota.LastUsage = DateTime.UtcNow; + + // Track tool-specific usage + userQuota.ToolUsage.TryGetValue(toolName, out var currentToolUsage); + userQuota.ToolUsage[toolName] = currentToolUsage + cost; + + _logger.LogDebug("Consumed {Cost} quota for user {UserId}, tool {ToolName}. Total usage: {Usage}", + cost, userId, toolName, userQuota.CurrentUsage); + } + + return Task.CompletedTask; + } + + public Task GetCurrentUsageAsync(string userId, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + return Task.FromResult(0); + } + + var userQuota = GetOrCreateUserQuota(userId); + + lock (userQuota) + { + return Task.FromResult(userQuota.CurrentUsage); + } + } + + public Task GetQuotaLimitAsync(string userId, string userRole, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + return Task.FromResult(-1); // Unlimited when disabled + } + + return Task.FromResult(GetQuotaLimitForRole(userRole)); + } + + public Task GetRemainingQuotaAsync(string userId, string userRole, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + return Task.FromResult(-1); // Unlimited when disabled + } + + var quotaLimit = GetQuotaLimitForRole(userRole); + if (quotaLimit == -1) + { + return Task.FromResult(-1); // Unlimited + } + + var userQuota = GetOrCreateUserQuota(userId); + + lock (userQuota) + { + var remaining = Math.Max(0, quotaLimit - userQuota.CurrentUsage); + return Task.FromResult(remaining); + } + } + + public Task GetQuotaResetDateAsync(string userId, CancellationToken cancellationToken = default) + { + var userQuota = GetOrCreateUserQuota(userId); + + lock (userQuota) + { + return Task.FromResult(userQuota.NextResetDate); + } + } + + public Task ResetQuotaAsync(string userId, CancellationToken cancellationToken = default) + { + var userQuota = GetOrCreateUserQuota(userId); + + lock (userQuota) + { + userQuota.CurrentUsage = 0; + userQuota.ToolUsage.Clear(); + userQuota.NextResetDate = CalculateNextResetDate(DateTime.UtcNow); + + _logger.LogInformation("Reset quota for user {UserId}. Next reset: {NextReset}", userId, userQuota.NextResetDate); + } + + return Task.CompletedTask; + } + + public Task> GetUsageBreakdownAsync(string userId, CancellationToken cancellationToken = default) + { + var userQuota = GetOrCreateUserQuota(userId); + + lock (userQuota) + { + return Task.FromResult(new Dictionary(userQuota.ToolUsage)); + } + } + + private UserQuotaInfo GetOrCreateUserQuota(string userId) + { + return _userQuotas.GetOrAdd(userId, _ => new UserQuotaInfo + { + UserId = userId, + CurrentUsage = 0, + NextResetDate = CalculateNextResetDate(DateTime.UtcNow), + LastUsage = DateTime.UtcNow, + ToolUsage = new ConcurrentDictionary() + }); + } + + private int GetQuotaLimitForRole(string userRole) + { + if (_options.RoleQuotas.TryGetValue(userRole, out var limit)) + { + return limit; + } + + // Default to user quota if role not found + return _options.RoleQuotas.TryGetValue("user", out var userLimit) ? userLimit : 1000; + } + + private int GetQuotaCost(string toolName) + { + foreach (var mapping in _options.ToolQuotaCosts) + { + if (IsPatternMatch(mapping.Key, toolName)) + { + return mapping.Value; + } + } + + return 1; // Default cost + } + + private DateTime CalculateNextResetDate(DateTime fromDate) + { + return fromDate.AddDays(_options.QuotaPeriodDays); + } + + private async Task ProcessQuotaResetsAsync() + { + var now = DateTime.UtcNow; + var usersToReset = new List(); + + foreach (var kvp in _userQuotas) + { + var userQuota = kvp.Value; + + lock (userQuota) + { + if (now >= userQuota.NextResetDate) + { + usersToReset.Add(kvp.Key); + } + } + } + + foreach (var userId in usersToReset) + { + await ResetQuotaAsync(userId); + } + + if (usersToReset.Count > 0) + { + _logger.LogInformation("Reset quotas for {Count} users", usersToReset.Count); + } + } + + private static bool IsPatternMatch(string pattern, string toolName) + { + if (pattern == "*") + { + return true; + } + + // Simple glob pattern matching + if (pattern.EndsWith("*")) + { + var prefix = pattern[..^1]; + return toolName.StartsWith(prefix, StringComparison.OrdinalIgnoreCase); + } + + return string.Equals(pattern, toolName, StringComparison.OrdinalIgnoreCase); + } + + public void Dispose() + { + _resetTimer?.Dispose(); + } + + private class UserQuotaInfo + { + public string UserId { get; set; } = ""; + public int CurrentUsage { get; set; } + public DateTime NextResetDate { get; set; } + public DateTime LastUsage { get; set; } + public ConcurrentDictionary ToolUsage { get; set; } = new(); + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Services/InMemoryRateLimitingService.cs b/samples/DynamicToolFiltering/Services/InMemoryRateLimitingService.cs new file mode 100644 index 00000000..ebd3c2d9 --- /dev/null +++ b/samples/DynamicToolFiltering/Services/InMemoryRateLimitingService.cs @@ -0,0 +1,128 @@ +using System.Collections.Concurrent; + +namespace DynamicToolFiltering.Services; + +/// +/// In-memory implementation of rate limiting service. +/// Note: This is for demonstration purposes. In production, use a distributed cache like Redis. +/// +public class InMemoryRateLimitingService : IRateLimitingService +{ + private readonly ConcurrentDictionary> _usageRecords = new(); + private readonly ILogger _logger; + private readonly Timer _cleanupTimer; + + public InMemoryRateLimitingService(ILogger logger) + { + _logger = logger; + + // Run cleanup every 10 minutes + _cleanupTimer = new Timer(async _ => await CleanupOldRecordsAsync(), null, TimeSpan.FromMinutes(10), TimeSpan.FromMinutes(10)); + } + + public Task GetUsageCountAsync(string userId, string toolName, DateTime windowStart, CancellationToken cancellationToken = default) + { + var key = GetKey(userId, toolName); + + if (!_usageRecords.TryGetValue(key, out var records)) + { + return Task.FromResult(0); + } + + lock (records) + { + var count = records.Count(r => r.Timestamp >= windowStart); + return Task.FromResult(count); + } + } + + public Task RecordUsageAsync(string userId, string toolName, DateTime timestamp, CancellationToken cancellationToken = default) + { + var key = GetKey(userId, toolName); + var record = new UsageRecord(timestamp); + + _usageRecords.AddOrUpdate(key, + new List { record }, + (_, existingRecords) => + { + lock (existingRecords) + { + existingRecords.Add(record); + return existingRecords; + } + }); + + _logger.LogDebug("Recorded usage for {UserId}, {ToolName} at {Timestamp}", userId, toolName, timestamp); + + return Task.CompletedTask; + } + + public Task CleanupOldRecordsAsync(CancellationToken cancellationToken = default) + { + var cutoffTime = DateTime.UtcNow.AddHours(-24); // Keep records for 24 hours + var keysToRemove = new List(); + + foreach (var kvp in _usageRecords) + { + var records = kvp.Value; + + lock (records) + { + records.RemoveAll(r => r.Timestamp < cutoffTime); + + if (records.Count == 0) + { + keysToRemove.Add(kvp.Key); + } + } + } + + foreach (var key in keysToRemove) + { + _usageRecords.TryRemove(key, out _); + } + + if (keysToRemove.Count > 0) + { + _logger.LogDebug("Cleaned up {Count} empty usage record collections", keysToRemove.Count); + } + + return Task.CompletedTask; + } + + public Task> GetUsageStatisticsAsync(string userId, CancellationToken cancellationToken = default) + { + var statistics = new Dictionary(); + var userPrefix = $"{userId}:"; + var windowStart = DateTime.UtcNow.AddHours(-1); // Last hour + + foreach (var kvp in _usageRecords) + { + if (kvp.Key.StartsWith(userPrefix)) + { + var toolName = kvp.Key[userPrefix.Length..]; + var records = kvp.Value; + + lock (records) + { + var count = records.Count(r => r.Timestamp >= windowStart); + if (count > 0) + { + statistics[toolName] = count; + } + } + } + } + + return Task.FromResult(statistics); + } + + private static string GetKey(string userId, string toolName) => $"{userId}:{toolName}"; + + public void Dispose() + { + _cleanupTimer?.Dispose(); + } + + private record UsageRecord(DateTime Timestamp); +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/TESTING_GUIDE.md b/samples/DynamicToolFiltering/TESTING_GUIDE.md new file mode 100644 index 00000000..7b20ff1a --- /dev/null +++ b/samples/DynamicToolFiltering/TESTING_GUIDE.md @@ -0,0 +1,597 @@ +# Testing Guide for Dynamic Tool Filtering Sample + +This guide provides comprehensive testing scenarios and examples for the Dynamic Tool Filtering MCP server sample. + +## Quick Start Testing + +### 1. Start the Server + +```bash +cd samples/DynamicToolFiltering +dotnet run --launch-profile DevelopmentMode +``` + +The server will start on `http://localhost:8080` with development-friendly settings. + +### 2. Basic Health Check + +```bash +curl http://localhost:8080/health +``` + +Expected response: +```json +{ + "Status": "healthy", + "Timestamp": "2024-01-01T12:00:00.000Z", + "Environment": "Development", + "Version": "1.0.0" +} +``` + +## API Key Testing + +The sample includes predefined API keys for different user roles: + +| API Key | Role | Scopes | Description | +|---------|------|--------|-------------| +| `demo-guest-key` | guest | basic:tools | Limited access to public tools | +| `demo-user-key` | user | user:tools, read:tools, basic:tools | Standard user access | +| `demo-premium-key` | premium | premium:tools, user:tools, read:tools, basic:tools | Premium features | +| `demo-admin-key` | admin | admin:tools, premium:tools, user:tools, read:tools, basic:tools | Full administrative access | + +## Test Scenarios + +### 1. Tool Visibility Testing + +Test which tools are visible to different user roles: + +```bash +# Guest user - should see only basic tools +curl -H "X-API-Key: demo-guest-key" \ + http://localhost:8080/mcp/v1/tools + +# User role - should see user-level tools +curl -H "X-API-Key: demo-user-key" \ + http://localhost:8080/mcp/v1/tools + +# Premium user - should see premium tools +curl -H "X-API-Key: demo-premium-key" \ + http://localhost:8080/mcp/v1/tools + +# Admin user - should see all tools +curl -H "X-API-Key: demo-admin-key" \ + http://localhost:8080/mcp/v1/tools +``` + +### 2. Tool Execution Testing + +#### Public Tools (No Authentication Required) + +```bash +# Echo tool - available to everyone +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "Content-Type: application/json" \ + -d '{ + "name": "echo", + "arguments": { + "message": "Hello Dynamic Filtering!" + } + }' + +# System info - public tool +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "Content-Type: application/json" \ + -d '{ + "name": "get_system_info", + "arguments": {} + }' + +# UTC time - public tool +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "Content-Type: application/json" \ + -d '{ + "name": "get_utc_time", + "arguments": {} + }' +``` + +#### User Tools (Requires Authentication) + +```bash +# User profile - requires user role or higher +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-user-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "get_user_profile", + "arguments": {} + }' + +# Hash calculation - user tool +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-user-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "calculate_hash", + "arguments": { + "text": "Hello World", + "algorithm": "sha256" + } + }' + +# UUID generation - user tool +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-user-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "generate_uuid", + "arguments": { + "count": 3 + } + }' + +# Email validation - user tool +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-user-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "validate_email", + "arguments": { + "email": "test@example.com" + } + }' +``` + +#### Premium Tools (Requires Premium Role) + +```bash +# Secure random generation - premium tool +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-premium-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "premium_generate_secure_random", + "arguments": { + "byteCount": 32, + "format": "hex" + } + }' + +# Text analysis - premium tool +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-premium-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "premium_analyze_text", + "arguments": { + "text": "This is a sample text for comprehensive analysis. It contains multiple sentences and various words to demonstrate the text analysis capabilities.", + "depth": "comprehensive" + } + }' + +# Password generation - premium tool +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-premium-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "premium_generate_password", + "arguments": { + "length": 16, + "includeUppercase": true, + "includeLowercase": true, + "includeNumbers": true, + "includeSpecial": true, + "excludeAmbiguous": true + } + }' + +# Performance benchmark - premium tool +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-premium-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "premium_performance_benchmark", + "arguments": { + "benchmarkType": "cpu", + "durationSeconds": 3 + } + }' +``` + +#### Admin Tools (Requires Admin Role) + +```bash +# System diagnostics - admin tool +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-admin-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "admin_get_system_diagnostics", + "arguments": {} + }' + +# Force garbage collection - admin tool +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-admin-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "admin_force_gc", + "arguments": { + "generation": -1 + } + }' + +# List processes - admin tool +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-admin-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "admin_list_processes", + "arguments": { + "limit": 10 + } + }' + +# Reload configuration - admin tool +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-admin-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "admin_reload_config", + "arguments": { + "section": "filtering" + } + }' +``` + +### 3. Authorization Failure Testing + +Test that users cannot access tools above their privilege level: + +```bash +# Try to access admin tool with user key (should fail) +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-user-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "admin_get_system_diagnostics", + "arguments": {} + }' + +# Try to access premium tool with guest key (should fail) +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-guest-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "premium_generate_secure_random", + "arguments": { + "byteCount": 32 + } + }' + +# Try to access user tool without authentication (should fail) +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "Content-Type: application/json" \ + -d '{ + "name": "get_user_profile", + "arguments": {} + }' +``` + +Expected response for unauthorized access: +```json +{ + "error": { + "code": -32002, + "message": "Access denied for tool 'admin_get_system_diagnostics': Tool requires role(s): admin or super_admin. User has role(s): user", + "data": { + "ToolName": "admin_get_system_diagnostics", + "Reason": "Tool requires role(s): admin or super_admin. User has role(s): user", + "HttpStatusCode": 401, + "RequiresAuthentication": true + } + } +} +``` + +### 4. Rate Limiting Testing + +Test rate limiting by making multiple rapid requests: + +```bash +# Test rate limiting with guest account (limited to 20 requests in development) +for i in {1..25}; do + echo "Request $i:" + curl -H "X-API-Key: demo-guest-key" \ + -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "Content-Type: application/json" \ + -d '{"name": "echo", "arguments": {"message": "Test '$i'"}}' \ + -w "Status: %{http_code}\n" -s -o /dev/null + sleep 0.1 +done +``` + +After hitting the limit, you should see HTTP 429 responses with a rate limit error. + +### 5. Feature Flag Testing + +Test feature flag functionality: + +```bash +# Check current feature flags (admin only) +curl -H "X-API-Key: demo-admin-key" \ + http://localhost:8080/admin/feature-flags + +# Toggle a feature flag (admin only) +curl -X POST \ + -H "X-API-Key: demo-admin-key" \ + "http://localhost:8080/admin/feature-flags/premium_features?enabled=false" + +# Try to use a premium tool after disabling the feature flag +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-premium-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "premium_generate_secure_random", + "arguments": { + "byteCount": 32 + } + }' + +# Re-enable the feature flag +curl -X POST \ + -H "X-API-Key: demo-admin-key" \ + "http://localhost:8080/admin/feature-flags/premium_features?enabled=true" +``` + +## Launch Profile Testing + +Test different launch profiles to verify filter configurations: + +### 1. No Filtering Mode + +```bash +dotnet run --launch-profile NoFilteringMode +``` + +In this mode, all tools should be accessible regardless of authentication. + +### 2. Rate Limiting Demo Mode + +```bash +dotnet run --launch-profile RateLimitingDemoMode +``` + +This mode has very strict rate limits (1-minute windows with low limits): +- Guest: 3 requests per minute +- User: 10 requests per minute +- Premium: 25 requests per minute +- Admin: 100 requests per minute + +### 3. Business Hours Demo Mode + +```bash +dotnet run --launch-profile BusinessHoursDemoMode +``` + +This mode restricts admin tools to business hours (9 AM - 5 PM weekdays UTC). + +### 4. Tenant Demo Mode + +```bash +dotnet run --launch-profile TenantDemoMode +``` + +This mode enables tenant isolation. Add `X-Tenant-ID` header to test: + +```bash +curl -H "X-API-Key: demo-user-key" \ + -H "X-Tenant-ID: tenant-a" \ + -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "Content-Type: application/json" \ + -d '{"name": "echo", "arguments": {"message": "Tenant A"}}' +``` + +## Error Response Testing + +### 1. Invalid API Key + +```bash +curl -H "X-API-Key: invalid-key" \ + http://localhost:8080/mcp/v1/tools +``` + +### 2. Malformed Request + +```bash +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "Content-Type: application/json" \ + -d '{"invalid": "json"}' +``` + +### 3. Tool Not Found + +```bash +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-user-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "nonexistent_tool", + "arguments": {} + }' +``` + +### 4. Invalid Arguments + +```bash +curl -X POST http://localhost:8080/mcp/v1/tools/call \ + -H "X-API-Key: demo-user-key" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "calculate_hash", + "arguments": { + "text": "test", + "algorithm": "invalid_algorithm" + } + }' +``` + +## Performance Testing + +### 1. Concurrent Request Testing + +Use a tool like Apache Bench to test concurrent requests: + +```bash +# Install apache2-utils if not installed +# Ubuntu/Debian: sudo apt-get install apache2-utils +# macOS: brew install httpd + +# Test with 10 concurrent connections, 100 total requests +ab -n 100 -c 10 -H "X-API-Key: demo-user-key" \ + -p echo_data.json -T application/json \ + http://localhost:8080/mcp/v1/tools/call +``` + +Create `echo_data.json`: +```json +{ + "name": "echo", + "arguments": { + "message": "Performance test" + } +} +``` + +### 2. Memory Usage Testing + +Monitor memory usage during heavy load: + +```bash +# Run server in background +dotnet run --launch-profile DevelopmentMode & +SERVER_PID=$! + +# Monitor memory usage +while true; do + ps -o pid,vsz,rss,comm -p $SERVER_PID + sleep 5 +done + +# Kill server when done +kill $SERVER_PID +``` + +## JWT Token Testing + +For JWT testing, you need to generate valid tokens. Here's a simple example using a script or online JWT generator: + +### JWT Payload Example + +```json +{ + "sub": "user123", + "name": "Test User", + "role": "premium", + "scope": "premium:tools user:tools read:tools basic:tools", + "iat": 1609459200, + "exp": 1609462800, + "iss": "dynamic-tool-filtering-demo", + "aud": "mcp-api-clients" +} +``` + +Use the secret key from configuration: `your-256-bit-secret-key-here-make-it-secure-and-change-in-production` + +Test with JWT: + +```bash +curl -H "Authorization: Bearer YOUR_JWT_TOKEN" \ + http://localhost:8080/mcp/v1/tools +``` + +## Automated Testing Script + +Create a comprehensive test script: + +```bash +#!/bin/bash + +BASE_URL="http://localhost:8080" +GUEST_KEY="demo-guest-key" +USER_KEY="demo-user-key" +PREMIUM_KEY="demo-premium-key" +ADMIN_KEY="demo-admin-key" + +echo "=== Dynamic Tool Filtering Test Suite ===" + +# Test 1: Health Check +echo "Test 1: Health Check" +curl -s "$BASE_URL/health" | jq . +echo "" + +# Test 2: Tool Visibility by Role +echo "Test 2: Tool Visibility" +echo "Guest tools:" +curl -s -H "X-API-Key: $GUEST_KEY" "$BASE_URL/mcp/v1/tools" | jq '.result.tools[].name' +echo "User tools:" +curl -s -H "X-API-Key: $USER_KEY" "$BASE_URL/mcp/v1/tools" | jq '.result.tools[].name' +echo "Premium tools:" +curl -s -H "X-API-Key: $PREMIUM_KEY" "$BASE_URL/mcp/v1/tools" | jq '.result.tools[].name' +echo "Admin tools:" +curl -s -H "X-API-Key: $ADMIN_KEY" "$BASE_URL/mcp/v1/tools" | jq '.result.tools[].name' +echo "" + +# Test 3: Successful Tool Execution +echo "Test 3: Successful Tool Execution" +curl -s -X POST "$BASE_URL/mcp/v1/tools/call" \ + -H "Content-Type: application/json" \ + -d '{"name": "echo", "arguments": {"message": "Test"}}' | jq . +echo "" + +# Test 4: Authorization Failure +echo "Test 4: Authorization Failure" +curl -s -X POST "$BASE_URL/mcp/v1/tools/call" \ + -H "X-API-Key: $GUEST_KEY" \ + -H "Content-Type: application/json" \ + -d '{"name": "admin_get_system_diagnostics", "arguments": {}}' | jq . +echo "" + +echo "=== Test Suite Complete ===" +``` + +## Troubleshooting + +### Common Issues + +1. **Server not starting**: Check that port 8080 is available +2. **Tools not visible**: Verify API key and user role configuration +3. **Rate limit errors**: Wait for rate limit window to reset or use a different launch profile +4. **Feature flag errors**: Check feature flag configuration and current state + +### Debug Logging + +Enable debug logging in `appsettings.Development.json`: + +```json +{ + "Logging": { + "LogLevel": { + "DynamicToolFiltering": "Debug", + "ModelContextProtocol": "Debug" + } + } +} +``` + +### Verbose Filter Logging + +Check the console output for detailed filter execution logs: + +``` +[Debug] Tool inclusion check for echo: User roles [guest], Required roles [guest, user, premium, admin, super_admin], HasAccess: True +[Debug] Tool execution authorized: echo for user anonymous_12345678. Remaining: 19/20 +``` + +This guide provides comprehensive testing coverage for all aspects of the Dynamic Tool Filtering sample, from basic functionality to advanced scenarios and edge cases. \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Tools/AdminTools.cs b/samples/DynamicToolFiltering/Tools/AdminTools.cs new file mode 100644 index 00000000..0e78d500 --- /dev/null +++ b/samples/DynamicToolFiltering/Tools/AdminTools.cs @@ -0,0 +1,210 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace DynamicToolFiltering.Tools; + +/// +/// Administrative tools that require elevated permissions. +/// These tools are only available to users with admin roles. +/// +public class AdminTools +{ + /// + /// Get detailed system diagnostics and performance metrics. + /// + [McpServerTool(Name = "admin_get_system_diagnostics", Description = "Get detailed system diagnostics and performance metrics")] + public static async Task GetSystemDiagnosticsAsync(CancellationToken cancellationToken = default) + { + await Task.Delay(200, cancellationToken); + + var process = System.Diagnostics.Process.GetCurrentProcess(); + var gc = GC.GetTotalMemory(false); + + var diagnostics = new + { + ProcessInfo = new + { + ProcessId = process.Id, + StartTime = process.StartTime.ToString("O"), + TotalProcessorTime = process.TotalProcessorTime.ToString(), + WorkingSet = process.WorkingSet64, + PrivateMemorySize = process.PrivateMemorySize64, + VirtualMemorySize = process.VirtualMemorySize64 + }, + MemoryInfo = new + { + GCTotalMemory = gc, + Gen0Collections = GC.CollectionCount(0), + Gen1Collections = GC.CollectionCount(1), + Gen2Collections = GC.CollectionCount(2) + }, + EnvironmentInfo = new + { + OSVersion = Environment.OSVersion.ToString(), + CLRVersion = Environment.Version.ToString(), + MachineName = Environment.MachineName, + UserName = Environment.UserName, + ProcessorCount = Environment.ProcessorCount, + SystemDirectory = Environment.SystemDirectory, + TickCount = Environment.TickCount64 + }, + Timestamp = DateTime.UtcNow.ToString("O") + }; + + return CallToolResult.FromContent( + TextContent.Create($"System Diagnostics: {System.Text.Json.JsonSerializer.Serialize(diagnostics, new System.Text.Json.JsonSerializerOptions { WriteIndented = true })}")); + } + + /// + /// Force garbage collection (admin operation). + /// + [McpServerTool(Name = "admin_force_gc", Description = "Force garbage collection (administrative operation)")] + public static async Task ForceGarbageCollectionAsync( + [Description("GC generation (0, 1, 2, or -1 for all)")] int generation = -1, + CancellationToken cancellationToken = default) + { + await Task.Delay(50, cancellationToken); + + var beforeMemory = GC.GetTotalMemory(false); + var beforeCollections = new + { + Gen0 = GC.CollectionCount(0), + Gen1 = GC.CollectionCount(1), + Gen2 = GC.CollectionCount(2) + }; + + if (generation == -1) + { + GC.Collect(); + } + else if (generation >= 0 && generation <= 2) + { + GC.Collect(generation); + } + else + { + return CallToolResult.FromError("Invalid generation. Must be 0, 1, 2, or -1 for all generations"); + } + + GC.WaitForPendingFinalizers(); + + var afterMemory = GC.GetTotalMemory(false); + var afterCollections = new + { + Gen0 = GC.CollectionCount(0), + Gen1 = GC.CollectionCount(1), + Gen2 = GC.CollectionCount(2) + }; + + var result = new + { + Generation = generation, + MemoryBefore = beforeMemory, + MemoryAfter = afterMemory, + MemoryReclaimed = beforeMemory - afterMemory, + CollectionsBefore = beforeCollections, + CollectionsAfter = afterCollections, + ExecutedAt = DateTime.UtcNow.ToString("O") + }; + + return CallToolResult.FromContent( + TextContent.Create($"Garbage Collection Result: {System.Text.Json.JsonSerializer.Serialize(result, new System.Text.Json.JsonSerializerOptions { WriteIndented = true })}")); + } + + /// + /// Get list of all running processes (admin operation). + /// + [McpServerTool(Name = "admin_list_processes", Description = "Get list of all running processes")] + public static async Task ListProcessesAsync( + [Description("Maximum number of processes to return")] int limit = 20, + CancellationToken cancellationToken = default) + { + await Task.Delay(300, cancellationToken); + + if (limit < 1 || limit > 100) + { + return CallToolResult.FromError("Limit must be between 1 and 100"); + } + + var processes = System.Diagnostics.Process.GetProcesses() + .Take(limit) + .Select(p => new + { + ProcessId = p.Id, + ProcessName = p.ProcessName, + StartTime = TryGetStartTime(p), + WorkingSet = TryGetWorkingSet(p), + HasExited = TryGetHasExited(p) + }) + .ToArray(); + + var result = new + { + ProcessCount = processes.Length, + Processes = processes, + RetrievedAt = DateTime.UtcNow.ToString("O") + }; + + return CallToolResult.FromContent( + TextContent.Create($"Process List: {System.Text.Json.JsonSerializer.Serialize(result, new System.Text.Json.JsonSerializerOptions { WriteIndented = true })}")); + } + + /// + /// Simulate configuration reload (admin operation). + /// + [McpServerTool(Name = "admin_reload_config", Description = "Simulate configuration reload")] + public static async Task ReloadConfigurationAsync( + [Description("Configuration section to reload")] string section = "all", + CancellationToken cancellationToken = default) + { + await Task.Delay(500, cancellationToken); // Simulate configuration reload time + + var result = new + { + Section = section, + Status = "success", + ReloadedAt = DateTime.UtcNow.ToString("O"), + Message = $"Configuration section '{section}' has been reloaded successfully", + Version = Guid.NewGuid().ToString("N")[..8] // Simulate new config version + }; + + return CallToolResult.FromContent( + TextContent.Create($"Configuration Reload: {System.Text.Json.JsonSerializer.Serialize(result, new System.Text.Json.JsonSerializerOptions { WriteIndented = true })}")); + } + + private static string TryGetStartTime(System.Diagnostics.Process process) + { + try + { + return process.StartTime.ToString("O"); + } + catch + { + return "N/A"; + } + } + + private static long TryGetWorkingSet(System.Diagnostics.Process process) + { + try + { + return process.WorkingSet64; + } + catch + { + return -1; + } + } + + private static bool TryGetHasExited(System.Diagnostics.Process process) + { + try + { + return process.HasExited; + } + catch + { + return true; + } + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Tools/PremiumTools.cs b/samples/DynamicToolFiltering/Tools/PremiumTools.cs new file mode 100644 index 00000000..38aa3a6e --- /dev/null +++ b/samples/DynamicToolFiltering/Tools/PremiumTools.cs @@ -0,0 +1,382 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Security.Cryptography; +using System.Text; + +namespace DynamicToolFiltering.Tools; + +/// +/// Premium tools that require subscription or premium access. +/// These tools provide advanced functionality with higher resource usage. +/// +public class PremiumTools +{ + /// + /// Generate cryptographically secure random bytes. + /// + [McpServerTool(Name = "premium_generate_secure_random", Description = "Generate cryptographically secure random bytes")] + public static async Task GenerateSecureRandomAsync( + [Description("Number of bytes to generate (1-1024)")] int byteCount = 32, + [Description("Output format: hex, base64, or bytes")] string format = "hex", + CancellationToken cancellationToken = default) + { + await Task.Delay(100, cancellationToken); + + if (byteCount < 1 || byteCount > 1024) + { + return CallToolResult.FromError("Byte count must be between 1 and 1024"); + } + + using var rng = RandomNumberGenerator.Create(); + var randomBytes = new byte[byteCount]; + rng.GetBytes(randomBytes); + + var output = format.ToLowerInvariant() switch + { + "hex" => Convert.ToHexString(randomBytes).ToLowerInvariant(), + "base64" => Convert.ToBase64String(randomBytes), + "bytes" => string.Join(",", randomBytes), + _ => Convert.ToHexString(randomBytes).ToLowerInvariant() + }; + + var result = new + { + ByteCount = byteCount, + Format = format, + RandomData = output, + GeneratedAt = DateTime.UtcNow.ToString("O"), + Entropy = randomBytes.Length * 8 // bits of entropy + }; + + return CallToolResult.FromContent( + TextContent.Create($"Secure Random Data: {System.Text.Json.JsonSerializer.Serialize(result, new System.Text.Json.JsonSerializerOptions { WriteIndented = true })}")); + } + + /// + /// Perform advanced text analysis with metrics. + /// + [McpServerTool(Name = "premium_analyze_text", Description = "Perform advanced text analysis with detailed metrics")] + public static async Task AnalyzeTextAsync( + [Description("Text to analyze")] string text, + [Description("Analysis depth: basic, standard, comprehensive")] string depth = "standard", + CancellationToken cancellationToken = default) + { + await Task.Delay(200, cancellationToken); // Simulate analysis time + + if (string.IsNullOrEmpty(text)) + { + return CallToolResult.FromError("Text cannot be empty"); + } + + var words = text.Split(new[] { ' ', '\t', '\n', '\r', '.', ',', ';', ':', '!', '?' }, + StringSplitOptions.RemoveEmptyEntries); + + var sentences = text.Split(new[] { '.', '!', '?' }, + StringSplitOptions.RemoveEmptyEntries) + .Where(s => !string.IsNullOrWhiteSpace(s)) + .ToArray(); + + var basicMetrics = new + { + CharacterCount = text.Length, + WordCount = words.Length, + SentenceCount = sentences.Length, + ParagraphCount = text.Split(new[] { "\n\n", "\r\n\r\n" }, StringSplitOptions.RemoveEmptyEntries).Length, + AverageWordsPerSentence = sentences.Length > 0 ? (double)words.Length / sentences.Length : 0, + AverageCharactersPerWord = words.Length > 0 ? words.Average(w => w.Length) : 0 + }; + + var analysis = new Dictionary + { + ["BasicMetrics"] = basicMetrics + }; + + if (depth is "standard" or "comprehensive") + { + var wordFrequency = words + .GroupBy(w => w.ToLowerInvariant()) + .OrderByDescending(g => g.Count()) + .Take(10) + .ToDictionary(g => g.Key, g => g.Count()); + + var readabilityMetrics = new + { + LongestWord = words.OrderByDescending(w => w.Length).FirstOrDefault()?.Length ?? 0, + ShortestWord = words.OrderBy(w => w.Length).FirstOrDefault()?.Length ?? 0, + UniqueWords = words.Distinct(StringComparer.OrdinalIgnoreCase).Count(), + LexicalDiversity = words.Length > 0 ? (double)words.Distinct(StringComparer.OrdinalIgnoreCase).Count() / words.Length : 0 + }; + + analysis["WordFrequency"] = wordFrequency; + analysis["ReadabilityMetrics"] = readabilityMetrics; + } + + if (depth == "comprehensive") + { + var advancedMetrics = new + { + CapitalLetters = text.Count(char.IsUpper), + LowercaseLetters = text.Count(char.IsLower), + Digits = text.Count(char.IsDigit), + Punctuation = text.Count(char.IsPunctuation), + Whitespace = text.Count(char.IsWhiteSpace), + VowelCount = text.Count(c => "aeiouAEIOU".Contains(c)), + ConsonantCount = text.Count(c => char.IsLetter(c) && !"aeiouAEIOU".Contains(c)) + }; + + analysis["AdvancedMetrics"] = advancedMetrics; + } + + analysis["AnalysisDepth"] = depth; + analysis["AnalyzedAt"] = DateTime.UtcNow.ToString("O"); + + return CallToolResult.FromContent( + TextContent.Create($"Text Analysis: {System.Text.Json.JsonSerializer.Serialize(analysis, new System.Text.Json.JsonSerializerOptions { WriteIndented = true })}")); + } + + /// + /// Generate secure password with customizable complexity. + /// + [McpServerTool(Name = "premium_generate_password", Description = "Generate secure password with customizable complexity")] + public static async Task GeneratePasswordAsync( + [Description("Password length (8-128)")] int length = 16, + [Description("Include uppercase letters")] bool includeUppercase = true, + [Description("Include lowercase letters")] bool includeLowercase = true, + [Description("Include numbers")] bool includeNumbers = true, + [Description("Include special characters")] bool includeSpecial = true, + [Description("Exclude ambiguous characters (0, O, l, I, etc.)")] bool excludeAmbiguous = false, + CancellationToken cancellationToken = default) + { + await Task.Delay(150, cancellationToken); + + if (length < 8 || length > 128) + { + return CallToolResult.FromError("Password length must be between 8 and 128 characters"); + } + + if (!includeUppercase && !includeLowercase && !includeNumbers && !includeSpecial) + { + return CallToolResult.FromError("At least one character type must be enabled"); + } + + var characterSets = new List(); + var guaranteedChars = new List(); + + if (includeUppercase) + { + var upperChars = excludeAmbiguous ? "ABCDEFGHJKLMNPQRSTUVWXYZ" : "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + characterSets.Add(upperChars); + guaranteedChars.Add(upperChars[RandomNumberGenerator.GetInt32(upperChars.Length)]); + } + + if (includeLowercase) + { + var lowerChars = excludeAmbiguous ? "abcdefghjkmnpqrstuvwxyz" : "abcdefghijklmnopqrstuvwxyz"; + characterSets.Add(lowerChars); + guaranteedChars.Add(lowerChars[RandomNumberGenerator.GetInt32(lowerChars.Length)]); + } + + if (includeNumbers) + { + var numberChars = excludeAmbiguous ? "23456789" : "0123456789"; + characterSets.Add(numberChars); + guaranteedChars.Add(numberChars[RandomNumberGenerator.GetInt32(numberChars.Length)]); + } + + if (includeSpecial) + { + var specialChars = excludeAmbiguous ? "!@#$%^&*+-=" : "!@#$%^&*()_+-=[]{}|;:,.<>?"; + characterSets.Add(specialChars); + guaranteedChars.Add(specialChars[RandomNumberGenerator.GetInt32(specialChars.Length)]); + } + + var allChars = string.Join("", characterSets); + var password = new StringBuilder(); + + // Add guaranteed characters first + foreach (var c in guaranteedChars) + { + password.Append(c); + } + + // Fill remaining positions + for (int i = guaranteedChars.Count; i < length; i++) + { + password.Append(allChars[RandomNumberGenerator.GetInt32(allChars.Length)]); + } + + // Shuffle the password + var passwordArray = password.ToString().ToCharArray(); + for (int i = passwordArray.Length - 1; i > 0; i--) + { + int j = RandomNumberGenerator.GetInt32(i + 1); + (passwordArray[i], passwordArray[j]) = (passwordArray[j], passwordArray[i]); + } + + var finalPassword = new string(passwordArray); + + // Calculate password strength + var entropy = Math.Log2(allChars.Length) * length; + var strengthRating = entropy switch + { + < 50 => "Weak", + < 75 => "Moderate", + < 100 => "Strong", + _ => "Very Strong" + }; + + var result = new + { + Password = finalPassword, + Length = length, + Entropy = Math.Round(entropy, 2), + StrengthRating = strengthRating, + CharacterTypes = new + { + Uppercase = includeUppercase, + Lowercase = includeLowercase, + Numbers = includeNumbers, + Special = includeSpecial, + ExcludeAmbiguous = excludeAmbiguous + }, + GeneratedAt = DateTime.UtcNow.ToString("O") + }; + + return CallToolResult.FromContent( + TextContent.Create($"Password Generation: {System.Text.Json.JsonSerializer.Serialize(result, new System.Text.Json.JsonSerializerOptions { WriteIndented = true })}")); + } + + /// + /// Perform benchmark test to measure system performance. + /// + [McpServerTool(Name = "premium_performance_benchmark", Description = "Perform system performance benchmark")] + public static async Task PerformanceBenchmarkAsync( + [Description("Benchmark type: cpu, memory, disk, network")] string benchmarkType = "cpu", + [Description("Test duration in seconds (1-30)")] int durationSeconds = 5, + CancellationToken cancellationToken = default) + { + if (durationSeconds < 1 || durationSeconds > 30) + { + return CallToolResult.FromError("Duration must be between 1 and 30 seconds"); + } + + var startTime = DateTime.UtcNow; + var results = new Dictionary(); + + switch (benchmarkType.ToLowerInvariant()) + { + case "cpu": + results = await BenchmarkCpuAsync(durationSeconds, cancellationToken); + break; + case "memory": + results = await BenchmarkMemoryAsync(durationSeconds, cancellationToken); + break; + case "disk": + results = await BenchmarkDiskAsync(durationSeconds, cancellationToken); + break; + default: + return CallToolResult.FromError($"Unknown benchmark type: {benchmarkType}. Supported types: cpu, memory, disk"); + } + + results["BenchmarkType"] = benchmarkType; + results["DurationSeconds"] = durationSeconds; + results["StartTime"] = startTime.ToString("O"); + results["EndTime"] = DateTime.UtcNow.ToString("O"); + + return CallToolResult.FromContent( + TextContent.Create($"Performance Benchmark: {System.Text.Json.JsonSerializer.Serialize(results, new System.Text.Json.JsonSerializerOptions { WriteIndented = true })}")); + } + + private static async Task> BenchmarkCpuAsync(int durationSeconds, CancellationToken cancellationToken) + { + var operations = 0L; + var endTime = DateTime.UtcNow.AddSeconds(durationSeconds); + + while (DateTime.UtcNow < endTime && !cancellationToken.IsCancellationRequested) + { + // Perform CPU-intensive operation + Math.Sqrt(operations); + operations++; + + if (operations % 10000 == 0) + { + await Task.Yield(); // Allow other tasks to run + } + } + + return new Dictionary + { + ["OperationsPerformed"] = operations, + ["OperationsPerSecond"] = operations / durationSeconds, + ["BenchmarkScore"] = operations / 1000000.0 // Normalize to millions of operations + }; + } + + private static async Task> BenchmarkMemoryAsync(int durationSeconds, CancellationToken cancellationToken) + { + var allocations = 0L; + var totalMemoryAllocated = 0L; + var endTime = DateTime.UtcNow.AddSeconds(durationSeconds); + + while (DateTime.UtcNow < endTime && !cancellationToken.IsCancellationRequested) + { + // Allocate and deallocate memory + var data = new byte[1024]; // 1KB allocation + allocations++; + totalMemoryAllocated += data.Length; + + if (allocations % 1000 == 0) + { + await Task.Yield(); + GC.Collect(0, GCCollectionMode.Optimized, false); + } + } + + return new Dictionary + { + ["AllocationsPerformed"] = allocations, + ["TotalMemoryAllocated"] = totalMemoryAllocated, + ["AllocationsPerSecond"] = allocations / durationSeconds, + ["MemoryThroughputMBps"] = (totalMemoryAllocated / 1024.0 / 1024.0) / durationSeconds + }; + } + + private static async Task> BenchmarkDiskAsync(int durationSeconds, CancellationToken cancellationToken) + { + var operations = 0L; + var tempFile = Path.GetTempFileName(); + var data = new byte[4096]; // 4KB blocks + RandomNumberGenerator.Fill(data); + + try + { + var endTime = DateTime.UtcNow.AddSeconds(durationSeconds); + + using var fileStream = new FileStream(tempFile, FileMode.Create, FileAccess.Write, FileShare.None, 4096, FileOptions.DeleteOnClose); + + while (DateTime.UtcNow < endTime && !cancellationToken.IsCancellationRequested) + { + await fileStream.WriteAsync(data, cancellationToken); + operations++; + + if (operations % 100 == 0) + { + await fileStream.FlushAsync(cancellationToken); + await Task.Yield(); + } + } + } + finally + { + try { File.Delete(tempFile); } catch { /* Ignore cleanup errors */ } + } + + return new Dictionary + { + ["WriteOperations"] = operations, + ["TotalBytesWritten"] = operations * data.Length, + ["OperationsPerSecond"] = operations / durationSeconds, + ["ThroughputMBps"] = (operations * data.Length / 1024.0 / 1024.0) / durationSeconds + }; + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Tools/PublicTools.cs b/samples/DynamicToolFiltering/Tools/PublicTools.cs new file mode 100644 index 00000000..14601d82 --- /dev/null +++ b/samples/DynamicToolFiltering/Tools/PublicTools.cs @@ -0,0 +1,67 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace DynamicToolFiltering.Tools; + +/// +/// Public tools available to all users without authentication. +/// These represent the most basic functionality that doesn't require any authorization. +/// +public class PublicTools +{ + /// + /// Get basic system information - available to all users. + /// + [McpServerTool(Name = "get_system_info", Description = "Get basic system information and API status")] + public static async Task GetSystemInfoAsync(CancellationToken cancellationToken = default) + { + await Task.Delay(50, cancellationToken); // Simulate some work + + var systemInfo = new + { + Status = "online", + Version = "1.0.0", + Timestamp = DateTime.UtcNow.ToString("O"), + Environment = Environment.GetEnvironmentVariable("ASPNETCORE_ENVIRONMENT") ?? "Production", + MachineName = Environment.MachineName, + ProcessorCount = Environment.ProcessorCount + }; + + return CallToolResult.FromContent( + TextContent.Create($"System Information: {System.Text.Json.JsonSerializer.Serialize(systemInfo, new System.Text.Json.JsonSerializerOptions { WriteIndented = true })}")); + } + + /// + /// Simple echo service for testing connectivity. + /// + [McpServerTool(Name = "echo", Description = "Echo back the provided message")] + public static async Task EchoAsync( + [Description("The message to echo back")] string message, + CancellationToken cancellationToken = default) + { + await Task.Delay(10, cancellationToken); + + return CallToolResult.FromContent( + TextContent.Create($"Echo: {message} (timestamp: {DateTime.UtcNow:O})")); + } + + /// + /// Get current UTC time - useful for timezone-independent operations. + /// + [McpServerTool(Name = "get_utc_time", Description = "Get the current UTC timestamp")] + public static async Task GetUtcTimeAsync(CancellationToken cancellationToken = default) + { + await Task.Delay(5, cancellationToken); + + var timeInfo = new + { + UtcTime = DateTime.UtcNow.ToString("O"), + UnixTimestamp = DateTimeOffset.UtcNow.ToUnixTimeSeconds(), + DayOfWeek = DateTime.UtcNow.DayOfWeek.ToString(), + IsWeekend = DateTime.UtcNow.DayOfWeek is DayOfWeek.Saturday or DayOfWeek.Sunday + }; + + return CallToolResult.FromContent( + TextContent.Create($"Time Information: {System.Text.Json.JsonSerializer.Serialize(timeInfo, new System.Text.Json.JsonSerializerOptions { WriteIndented = true })}")); + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/Tools/UserTools.cs b/samples/DynamicToolFiltering/Tools/UserTools.cs new file mode 100644 index 00000000..8f26eb10 --- /dev/null +++ b/samples/DynamicToolFiltering/Tools/UserTools.cs @@ -0,0 +1,138 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace DynamicToolFiltering.Tools; + +/// +/// User-level tools that require basic authentication. +/// These tools are available to authenticated users with basic permissions. +/// +public class UserTools +{ + /// + /// Get current user profile information. + /// + [McpServerTool(Name = "get_user_profile", Description = "Get current user's profile information")] + public static async Task GetUserProfileAsync( + RequestContext context, + CancellationToken cancellationToken = default) + { + await Task.Delay(100, cancellationToken); + + // Extract user information from the request context + var userId = context.Session.ClientInfo?.Name ?? "anonymous"; + + var userProfile = new + { + UserId = userId, + AuthenticatedAt = DateTime.UtcNow.ToString("O"), + SessionId = context.Session.ToString(), + ClientInfo = context.Session.ClientInfo?.Name ?? "Unknown Client", + Permissions = new[] { "user:read", "user:profile" } + }; + + return CallToolResult.FromContent( + TextContent.Create($"User Profile: {System.Text.Json.JsonSerializer.Serialize(userProfile, new System.Text.Json.JsonSerializerOptions { WriteIndented = true })}")); + } + + /// + /// Calculate hash of provided text. + /// + [McpServerTool(Name = "calculate_hash", Description = "Calculate SHA256 hash of provided text")] + public static async Task CalculateHashAsync( + [Description("Text to hash")] string text, + [Description("Hash algorithm (sha256, sha1, md5)")] string algorithm = "sha256", + CancellationToken cancellationToken = default) + { + await Task.Delay(50, cancellationToken); + + using var hashAlgorithm = algorithm.ToLowerInvariant() switch + { + "sha256" => System.Security.Cryptography.SHA256.Create(), + "sha1" => System.Security.Cryptography.SHA1.Create(), + "md5" => System.Security.Cryptography.MD5.Create(), + _ => System.Security.Cryptography.SHA256.Create() + }; + + var bytes = System.Text.Encoding.UTF8.GetBytes(text); + var hashBytes = hashAlgorithm.ComputeHash(bytes); + var hashString = Convert.ToHexString(hashBytes).ToLowerInvariant(); + + var result = new + { + Algorithm = algorithm, + Input = text, + Hash = hashString, + ComputedAt = DateTime.UtcNow.ToString("O") + }; + + return CallToolResult.FromContent( + TextContent.Create($"Hash Result: {System.Text.Json.JsonSerializer.Serialize(result, new System.Text.Json.JsonSerializerOptions { WriteIndented = true })}")); + } + + /// + /// Generate a random UUID. + /// + [McpServerTool(Name = "generate_uuid", Description = "Generate a random UUID")] + public static async Task GenerateUuidAsync( + [Description("Number of UUIDs to generate")] int count = 1, + CancellationToken cancellationToken = default) + { + await Task.Delay(20, cancellationToken); + + if (count < 1 || count > 10) + { + return CallToolResult.FromError("Count must be between 1 and 10"); + } + + var uuids = Enumerable.Range(0, count) + .Select(_ => Guid.NewGuid().ToString()) + .ToArray(); + + var result = new + { + Count = count, + UUIDs = uuids, + GeneratedAt = DateTime.UtcNow.ToString("O") + }; + + return CallToolResult.FromContent( + TextContent.Create($"Generated UUIDs: {System.Text.Json.JsonSerializer.Serialize(result, new System.Text.Json.JsonSerializerOptions { WriteIndented = true })}")); + } + + /// + /// Validate an email address format. + /// + [McpServerTool(Name = "validate_email", Description = "Validate email address format")] + public static async Task ValidateEmailAsync( + [Description("Email address to validate")] string email, + CancellationToken cancellationToken = default) + { + await Task.Delay(30, cancellationToken); + + bool isValid; + string reason = "Valid email format"; + + try + { + var mail = new System.Net.Mail.MailAddress(email); + isValid = mail.Address == email; + } + catch + { + isValid = false; + reason = "Invalid email format"; + } + + var result = new + { + Email = email, + IsValid = isValid, + Reason = reason, + ValidatedAt = DateTime.UtcNow.ToString("O") + }; + + return CallToolResult.FromContent( + TextContent.Create($"Email Validation: {System.Text.Json.JsonSerializer.Serialize(result, new System.Text.Json.JsonSerializerOptions { WriteIndented = true })}")); + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/appsettings.Development.json b/samples/DynamicToolFiltering/appsettings.Development.json new file mode 100644 index 00000000..9d11feb3 --- /dev/null +++ b/samples/DynamicToolFiltering/appsettings.Development.json @@ -0,0 +1,35 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Debug", + "Microsoft.AspNetCore": "Information", + "DynamicToolFiltering": "Debug", + "ModelContextProtocol": "Debug" + } + }, + "Filtering": { + "RateLimiting": { + "WindowMinutes": 5, + "RoleLimits": { + "guest": 20, + "user": 200, + "premium": 1000, + "admin": -1 + } + }, + "TimeBased": { + "Enabled": false + }, + "TenantIsolation": { + "Enabled": false + }, + "BusinessLogic": { + "QuotaManagement": { + "Enabled": false + }, + "EnvironmentRestrictions": { + "Enabled": false + } + } + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/appsettings.Production.json b/samples/DynamicToolFiltering/appsettings.Production.json new file mode 100644 index 00000000..52ef2bad --- /dev/null +++ b/samples/DynamicToolFiltering/appsettings.Production.json @@ -0,0 +1,63 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Warning", + "Microsoft.AspNetCore": "Error", + "DynamicToolFiltering": "Information" + } + }, + "Jwt": { + "SecretKey": "${JWT_SECRET_KEY}", + "Issuer": "${JWT_ISSUER}", + "Audience": "${JWT_AUDIENCE}" + }, + "Filtering": { + "RateLimiting": { + "WindowMinutes": 60, + "RoleLimits": { + "guest": 5, + "user": 50, + "premium": 200, + "admin": 500, + "super_admin": -1 + }, + "ToolLimits": { + "premium_performance_benchmark": 2, + "admin_*": 20 + } + }, + "TimeBased": { + "Enabled": true, + "BusinessHours": { + "Enabled": true, + "StartTime": "08:00", + "EndTime": "18:00", + "BusinessDays": [ "Monday", "Tuesday", "Wednesday", "Thursday", "Friday" ], + "RestrictedTools": [ "admin_*", "premium_performance_benchmark" ] + } + }, + "TenantIsolation": { + "Enabled": true + }, + "BusinessLogic": { + "QuotaManagement": { + "Enabled": true, + "QuotaPeriodDays": 30, + "RoleQuotas": { + "guest": 100, + "user": 1000, + "premium": 5000, + "admin": -1 + } + }, + "EnvironmentRestrictions": { + "Enabled": true, + "ProductionRestrictedTools": [ + "admin_force_gc", + "admin_list_processes", + "premium_performance_benchmark" + ] + } + } + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/appsettings.json b/samples/DynamicToolFiltering/appsettings.json new file mode 100644 index 00000000..42b7a5e1 --- /dev/null +++ b/samples/DynamicToolFiltering/appsettings.json @@ -0,0 +1,147 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft.AspNetCore": "Warning", + "DynamicToolFiltering": "Debug" + } + }, + "AllowedHosts": "*", + "Jwt": { + "SecretKey": "your-256-bit-secret-key-here-make-it-secure-and-change-in-production", + "Issuer": "dynamic-tool-filtering-demo", + "Audience": "mcp-api-clients", + "ExpirationMinutes": 60 + }, + "Filtering": { + "Enabled": true, + "DefaultBehavior": "deny", + "RoleBased": { + "Enabled": true, + "Priority": 100, + "RoleClaimType": "role", + "ToolRoleMapping": { + "admin_*": [ "admin", "super_admin" ], + "premium_*": [ "premium", "admin", "super_admin" ], + "*_user_*": [ "user", "premium", "admin", "super_admin" ], + "get_*": [ "guest", "user", "premium", "admin", "super_admin" ], + "echo": [ "guest", "user", "premium", "admin", "super_admin" ], + "get_utc_time": [ "guest", "user", "premium", "admin", "super_admin" ], + "*": [ "user", "premium", "admin", "super_admin" ] + }, + "UseHierarchicalRoles": true, + "RoleHierarchy": [ "super_admin", "admin", "premium", "user", "guest" ] + }, + "TimeBased": { + "Enabled": false, + "Priority": 200, + "TimeZone": "UTC", + "BusinessHours": { + "Enabled": false, + "StartTime": "09:00", + "EndTime": "17:00", + "BusinessDays": [ "Monday", "Tuesday", "Wednesday", "Thursday", "Friday" ], + "RestrictedTools": [ "admin_*" ] + }, + "MaintenanceWindows": [] + }, + "ScopeBased": { + "Enabled": true, + "Priority": 150, + "ScopeClaimType": "scope", + "ToolScopeMapping": { + "admin_*": [ "admin:tools" ], + "premium_*": [ "premium:tools" ], + "*_user_*": [ "user:tools" ], + "get_*": [ "read:tools" ], + "echo": [ "basic:tools" ], + "get_utc_time": [ "basic:tools" ], + "*": [ "user:tools" ] + } + }, + "RateLimiting": { + "Enabled": true, + "Priority": 50, + "WindowMinutes": 60, + "RoleLimits": { + "guest": 10, + "user": 100, + "premium": 500, + "admin": 1000, + "super_admin": -1 + }, + "ToolLimits": { + "premium_performance_benchmark": 5, + "admin_*": 50 + }, + "UseSlidingWindow": true + }, + "TenantIsolation": { + "Enabled": false, + "Priority": 75, + "TenantClaimType": "tenant_id", + "TenantHeaderName": "X-Tenant-ID", + "TenantConfigurations": { + "tenant-a": { + "Name": "Tenant A", + "IsActive": true, + "AllowedTools": [ "*" ], + "DeniedTools": [ "admin_*" ], + "CustomRateLimits": { + "premium_*": 10 + } + }, + "tenant-b": { + "Name": "Tenant B", + "IsActive": true, + "AllowedTools": [ "get_*", "echo", "*_user_*", "premium_*" ], + "DeniedTools": [ "admin_*", "premium_performance_benchmark" ], + "CustomRateLimits": {} + }, + "enterprise-tenant": { + "Name": "Enterprise Tenant", + "IsActive": true, + "AllowedTools": [ "*" ], + "DeniedTools": [], + "CustomRateLimits": { + "*": 1000 + } + } + } + }, + "BusinessLogic": { + "Enabled": true, + "Priority": 300, + "FeatureFlags": { + "Enabled": true, + "ToolFeatureMapping": { + "premium_*": "premium_features", + "admin_performance_*": "admin_performance_tools" + }, + "DefaultFeatureFlagState": false + }, + "QuotaManagement": { + "Enabled": false, + "QuotaPeriodDays": 30, + "RoleQuotas": { + "user": 1000, + "premium": 10000, + "admin": -1 + }, + "ToolQuotaCosts": { + "premium_performance_benchmark": 10, + "premium_*": 2, + "*": 1 + } + }, + "EnvironmentRestrictions": { + "Enabled": true, + "ProductionRestrictedTools": [ + "admin_force_gc", + "admin_list_processes" + ], + "DevelopmentOnlyTools": [] + } + } + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/clients/Program.cs b/samples/DynamicToolFiltering/clients/Program.cs new file mode 100644 index 00000000..bf1d736a --- /dev/null +++ b/samples/DynamicToolFiltering/clients/Program.cs @@ -0,0 +1,399 @@ +using DynamicToolFiltering.TestClient; +using Microsoft.Extensions.Logging; +using System.CommandLine; + +/// +/// Test client console application for the Dynamic Tool Filtering MCP server. +/// +/// This application demonstrates how to: +/// 1. Connect to and authenticate with an MCP server +/// 2. Discover available tools based on user role +/// 3. Execute tools with proper error handling +/// 4. Run comprehensive test suites +/// 5. Perform load testing and performance analysis +/// +/// USAGE EXAMPLES: +/// +/// Basic health check: +/// dotnet run -- health --url http://localhost:8080 +/// +/// Discover tools with API key: +/// dotnet run -- discover --url http://localhost:8080 --api-key demo-user-key +/// +/// Execute a specific tool: +/// dotnet run -- execute --url http://localhost:8080 --api-key demo-user-key --tool echo --args '{"message":"Hello World"}' +/// +/// Run comprehensive tests: +/// dotnet run -- test --url http://localhost:8080 +/// +/// Performance testing: +/// dotnet run -- perf --url http://localhost:8080 --users 10 --requests 50 +/// +public class Program +{ + public static async Task Main(string[] args) + { + var rootCommand = new RootCommand("Dynamic Tool Filtering MCP Server Test Client"); + + // Common options + var urlOption = new Option("--url", () => "http://localhost:8080", "Server URL"); + var apiKeyOption = new Option("--api-key", "API key for authentication"); + var verboseOption = new Option("--verbose", "Enable verbose logging"); + + // Health check command + var healthCommand = new Command("health", "Check server health"); + healthCommand.AddOption(urlOption); + healthCommand.AddOption(verboseOption); + healthCommand.SetHandler(async (string url, bool verbose) => + { + using var client = CreateTestClient(url, verbose); + await RunHealthCheckAsync(client); + }, urlOption, verboseOption); + + // Discover tools command + var discoverCommand = new Command("discover", "Discover available tools"); + discoverCommand.AddOption(urlOption); + discoverCommand.AddOption(apiKeyOption); + discoverCommand.AddOption(verboseOption); + discoverCommand.SetHandler(async (string url, string? apiKey, bool verbose) => + { + using var client = CreateTestClient(url, verbose); + await RunDiscoveryAsync(client, apiKey); + }, urlOption, apiKeyOption, verboseOption); + + // Execute tool command + var executeCommand = new Command("execute", "Execute a specific tool"); + executeCommand.AddOption(urlOption); + executeCommand.AddOption(apiKeyOption); + executeCommand.AddOption(verboseOption); + var toolOption = new Option("--tool", "Tool name to execute") { IsRequired = true }; + var argsOption = new Option("--args", "Tool arguments as JSON"); + executeCommand.AddOption(toolOption); + executeCommand.AddOption(argsOption); + executeCommand.SetHandler(async (string url, string? apiKey, bool verbose, string tool, string? args) => + { + using var client = CreateTestClient(url, verbose); + await RunToolExecutionAsync(client, apiKey, tool, args); + }, urlOption, apiKeyOption, verboseOption, toolOption, argsOption); + + // Comprehensive test command + var testCommand = new Command("test", "Run comprehensive test suite"); + testCommand.AddOption(urlOption); + testCommand.AddOption(verboseOption); + testCommand.SetHandler(async (string url, bool verbose) => + { + using var client = CreateTestClient(url, verbose); + await RunComprehensiveTestsAsync(client); + }, urlOption, verboseOption); + + // Performance test command + var perfCommand = new Command("perf", "Run performance tests"); + perfCommand.AddOption(urlOption); + perfCommand.AddOption(verboseOption); + var usersOption = new Option("--users", () => 5, "Number of concurrent users"); + var requestsOption = new Option("--requests", () => 20, "Requests per user"); + perfCommand.AddOption(usersOption); + perfCommand.AddOption(requestsOption); + perfCommand.SetHandler(async (string url, bool verbose, int users, int requests) => + { + using var client = CreateTestClient(url, verbose); + await RunPerformanceTestsAsync(client, users, requests); + }, urlOption, verboseOption, usersOption, requestsOption); + + // Demo command - interactive demonstration + var demoCommand = new Command("demo", "Run interactive demonstration"); + demoCommand.AddOption(urlOption); + demoCommand.AddOption(verboseOption); + demoCommand.SetHandler(async (string url, bool verbose) => + { + using var client = CreateTestClient(url, verbose); + await RunInteractiveDemoAsync(client); + }, urlOption, verboseOption); + + // Add all commands to root + rootCommand.AddCommand(healthCommand); + rootCommand.AddCommand(discoverCommand); + rootCommand.AddCommand(executeCommand); + rootCommand.AddCommand(testCommand); + rootCommand.AddCommand(perfCommand); + rootCommand.AddCommand(demoCommand); + + return await rootCommand.InvokeAsync(args); + } + + private static TestClient CreateTestClient(string url, bool verbose) + { + using var loggerFactory = LoggerFactory.Create(builder => + { + builder.AddConsole(); + builder.SetMinimumLevel(verbose ? LogLevel.Debug : LogLevel.Information); + }); + + var logger = loggerFactory.CreateLogger(); + return new TestClient(url, logger); + } + + private static async Task RunHealthCheckAsync(TestClient client) + { + Console.WriteLine("🏥 Checking server health..."); + + try + { + var health = await client.CheckHealthAsync(); + Console.WriteLine($"✅ Server is healthy!"); + Console.WriteLine($" Status: {health.Status}"); + Console.WriteLine($" Environment: {health.Environment}"); + Console.WriteLine($" Version: {health.Version}"); + Console.WriteLine($" Timestamp: {health.Timestamp:yyyy-MM-dd HH:mm:ss}"); + } + catch (Exception ex) + { + Console.WriteLine($"❌ Health check failed: {ex.Message}"); + } + } + + private static async Task RunDiscoveryAsync(TestClient client, string? apiKey) + { + Console.WriteLine("🔍 Discovering available tools..."); + + if (!string.IsNullOrEmpty(apiKey)) + { + client.SetApiKey(apiKey); + Console.WriteLine($" Using API key: {apiKey}"); + } + else + { + Console.WriteLine(" No authentication (testing public access)"); + } + + try + { + var tools = await client.DiscoverToolsAsync(); + + if (tools.Count == 0) + { + Console.WriteLine(" No tools available"); + return; + } + + Console.WriteLine($"✅ Found {tools.Count} available tools:"); + foreach (var tool in tools) + { + Console.WriteLine($" 📋 {tool.Name}"); + if (!string.IsNullOrEmpty(tool.Description)) + { + Console.WriteLine($" {tool.Description}"); + } + } + } + catch (Exception ex) + { + Console.WriteLine($"❌ Tool discovery failed: {ex.Message}"); + } + } + + private static async Task RunToolExecutionAsync(TestClient client, string? apiKey, string toolName, string? argsJson) + { + Console.WriteLine($"⚡ Executing tool: {toolName}"); + + if (!string.IsNullOrEmpty(apiKey)) + { + client.SetApiKey(apiKey); + Console.WriteLine($" Using API key: {apiKey}"); + } + + try + { + object? arguments = null; + if (!string.IsNullOrEmpty(argsJson)) + { + arguments = System.Text.Json.JsonSerializer.Deserialize(argsJson); + Console.WriteLine($" Arguments: {argsJson}"); + } + + var result = await client.ExecuteToolAsync(toolName, arguments); + + if (result.Success) + { + Console.WriteLine("✅ Tool execution successful!"); + Console.WriteLine($" Response: {result.Content}"); + } + else + { + Console.WriteLine($"❌ Tool execution failed (HTTP {result.StatusCode})"); + Console.WriteLine($" Error: {result.ErrorMessage}"); + } + } + catch (Exception ex) + { + Console.WriteLine($"❌ Tool execution failed: {ex.Message}"); + } + } + + private static async Task RunComprehensiveTestsAsync(TestClient client) + { + Console.WriteLine("🧪 Running comprehensive test suite..."); + Console.WriteLine(); + + try + { + var results = await client.RunComprehensiveTestsAsync(); + + Console.WriteLine("📊 Test Results:"); + Console.WriteLine($" Total: {results.TotalCount}"); + Console.WriteLine($" Passed: {results.PassedCount} ✅"); + Console.WriteLine($" Failed: {results.FailedCount} ❌"); + Console.WriteLine(); + + Console.WriteLine("📋 Detailed Results:"); + foreach (var (name, passed) in results.Tests) + { + var icon = passed ? "✅" : "❌"; + Console.WriteLine($" {icon} {name}"); + } + + var successRate = (double)results.PassedCount / results.TotalCount * 100; + Console.WriteLine(); + Console.WriteLine($"🎯 Success Rate: {successRate:F1}%"); + } + catch (Exception ex) + { + Console.WriteLine($"❌ Test suite failed: {ex.Message}"); + } + } + + private static async Task RunPerformanceTestsAsync(TestClient client, int users, int requests) + { + Console.WriteLine($"🚀 Running performance tests..."); + Console.WriteLine($" Users: {users}"); + Console.WriteLine($" Requests per user: {requests}"); + Console.WriteLine($" Total requests: {users * requests}"); + Console.WriteLine(); + + try + { + var results = await client.RunPerformanceTestsAsync(users, requests); + + Console.WriteLine("📈 Performance Results:"); + Console.WriteLine($" Total Requests: {results.TotalRequests}"); + Console.WriteLine($" Successful: {results.UserResults.Sum(r => r.SuccessfulRequests)}"); + Console.WriteLine($" Failed: {results.UserResults.Sum(r => r.FailedRequests)}"); + Console.WriteLine(); + + Console.WriteLine("⏱️ Response Times:"); + Console.WriteLine($" Average: {results.AverageResponseTime:F2} ms"); + Console.WriteLine($" Minimum: {results.MinResponseTime} ms"); + Console.WriteLine($" Maximum: {results.MaxResponseTime} ms"); + Console.WriteLine($" 95th percentile: {results.P95ResponseTime} ms"); + Console.WriteLine(); + + Console.WriteLine($"🔥 Throughput: {results.ThroughputPerSecond:F2} requests/second"); + + // Performance assessment + if (results.AverageResponseTime < 200) + Console.WriteLine("🎉 Excellent performance!"); + else if (results.AverageResponseTime < 500) + Console.WriteLine("👍 Good performance"); + else + Console.WriteLine("⚠️ Performance could be improved"); + } + catch (Exception ex) + { + Console.WriteLine($"❌ Performance test failed: {ex.Message}"); + } + } + + private static async Task RunInteractiveDemoAsync(TestClient client) + { + Console.WriteLine("🎭 Interactive Demo - Dynamic Tool Filtering"); + Console.WriteLine("=========================================="); + Console.WriteLine(); + + // Step 1: Health check + Console.WriteLine("Step 1: Checking server health..."); + await RunHealthCheckAsync(client); + Console.WriteLine(); + + // Step 2: Demonstrate role-based access + Console.WriteLine("Step 2: Demonstrating role-based access control..."); + + var roles = new Dictionary + { + ["Guest"] = "demo-guest-key", + ["User"] = "demo-user-key", + ["Premium"] = "demo-premium-key", + ["Admin"] = "demo-admin-key" + }; + + foreach (var (roleName, apiKey) in roles) + { + Console.WriteLine($"\n 🎭 Testing as {roleName} user:"); + client.SetApiKey(apiKey); + + try + { + var tools = await client.DiscoverToolsAsync(); + Console.WriteLine($" Visible tools: {tools.Count}"); + + if (tools.Count > 0) + { + Console.WriteLine($" Examples: {string.Join(", ", tools.Take(3).Select(t => t.Name))}"); + } + } + catch (Exception ex) + { + Console.WriteLine($" Error: {ex.Message}"); + } + } + + Console.WriteLine(); + + // Step 3: Demonstrate tool execution + Console.WriteLine("Step 3: Demonstrating tool execution..."); + + client.SetApiKey("demo-user-key"); + + var testCases = new[] + { + new { Tool = "echo", Args = new { message = "Hello from demo!" }, Description = "Public tool" }, + new { Tool = "get_user_profile", Args = (object)new { }, Description = "User tool" } + }; + + foreach (var testCase in testCases) + { + Console.WriteLine($"\n ⚡ Executing {testCase.Description}: {testCase.Tool}"); + var result = await client.ExecuteToolAsync(testCase.Tool, testCase.Args); + + if (result.Success) + { + Console.WriteLine(" ✅ Success!"); + } + else + { + Console.WriteLine($" ❌ Failed: {result.ErrorMessage}"); + } + } + + // Step 4: Demonstrate authorization failure + Console.WriteLine("\nStep 4: Demonstrating authorization controls..."); + Console.WriteLine(" 🚫 Attempting to access admin tool with user credentials:"); + + var unauthorizedResult = await client.ExecuteToolAsync("admin_get_system_diagnostics"); + if (!unauthorizedResult.Success) + { + Console.WriteLine(" ✅ Access correctly denied - security working!"); + } + else + { + Console.WriteLine(" ⚠️ Unexpected: Access was granted"); + } + + Console.WriteLine(); + Console.WriteLine("🎉 Demo completed! The Dynamic Tool Filtering system is working correctly."); + Console.WriteLine(); + Console.WriteLine("Next steps:"); + Console.WriteLine("- Run comprehensive tests: dotnet run test"); + Console.WriteLine("- Try performance testing: dotnet run perf"); + Console.WriteLine("- Explore with different API keys: dotnet run discover --api-key demo-admin-key"); + } +} \ No newline at end of file diff --git a/samples/DynamicToolFiltering/clients/README.md b/samples/DynamicToolFiltering/clients/README.md new file mode 100644 index 00000000..a43e0d2f --- /dev/null +++ b/samples/DynamicToolFiltering/clients/README.md @@ -0,0 +1,443 @@ +# Dynamic Tool Filtering Test Client + +A comprehensive test client for the Dynamic Tool Filtering MCP server that demonstrates client-side integration patterns and provides automated testing capabilities. + +## Features + +- **Health Monitoring**: Check server availability and status +- **Tool Discovery**: Explore available tools based on user permissions +- **Authentication Testing**: Test various authentication methods (API key, JWT) +- **Role-based Access Control Validation**: Verify hierarchical permission enforcement +- **Performance Testing**: Load testing with concurrent users +- **Error Handling Demonstration**: Test edge cases and error scenarios +- **Interactive Demo Mode**: Guided demonstration of server capabilities + +## Quick Start + +### Prerequisites + +- .NET 9.0 SDK +- Running Dynamic Tool Filtering MCP server + +### Build and Run + +```bash +# Navigate to the client directory +cd clients + +# Restore dependencies +dotnet restore + +# Build the client +dotnet build + +# Run basic health check +dotnet run -- health + +# Discover tools with user permissions +dotnet run -- discover --api-key demo-user-key + +# Execute a tool +dotnet run -- execute --tool echo --args '{"message":"Hello World"}' --api-key demo-user-key + +# Run comprehensive test suite +dotnet run -- test + +# Run performance tests +dotnet run -- perf --users 5 --requests 20 + +# Interactive demonstration +dotnet run -- demo +``` + +## Commands + +### Health Check + +Check if the server is running and responding correctly: + +```bash +dotnet run -- health --url http://localhost:8080 +``` + +**Expected Output:** +``` +🏥 Checking server health... +✅ Server is healthy! + Status: healthy + Environment: Development + Version: 1.0.0 + Timestamp: 2024-01-01 12:00:00 +``` + +### Tool Discovery + +Discover available tools for different user roles: + +```bash +# Public access (no authentication) +dotnet run -- discover + +# Guest user access +dotnet run -- discover --api-key demo-guest-key + +# User access +dotnet run -- discover --api-key demo-user-key + +# Premium user access +dotnet run -- discover --api-key demo-premium-key + +# Admin access +dotnet run -- discover --api-key demo-admin-key +``` + +**Expected Behavior:** +- Guest users see only public tools +- User role sees public + user tools +- Premium role sees public + user + premium tools +- Admin role sees all tools + +### Tool Execution + +Execute specific tools with different authentication levels: + +```bash +# Execute public tool (no auth required) +dotnet run -- execute --tool echo --args '{"message":"Hello World"}' + +# Execute user tool +dotnet run -- execute --tool get_user_profile --api-key demo-user-key + +# Execute premium tool +dotnet run -- execute --tool premium_generate_secure_random --args '{"byteCount":32}' --api-key demo-premium-key + +# Execute admin tool +dotnet run -- execute --tool admin_get_system_diagnostics --api-key demo-admin-key + +# Test authorization failure +dotnet run -- execute --tool admin_get_system_diagnostics --api-key demo-user-key +``` + +### Comprehensive Testing + +Run a full test suite covering all functionality: + +```bash +dotnet run -- test --url http://localhost:8080 --verbose +``` + +**Test Categories:** +- Health check validation +- Authentication method testing +- Role-based access control verification +- Tool execution success/failure scenarios +- Error handling validation +- Rate limiting enforcement + +**Expected Output:** +``` +🧪 Running comprehensive test suite... + +📊 Test Results: + Total: 15 + Passed: 14 ✅ + Failed: 1 ❌ + +📋 Detailed Results: + ✅ Health Check + ✅ API Key Authentication - Guest + ✅ API Key Authentication - User + ✅ API Key Authentication - Premium + ✅ API Key Authentication - Admin + ✅ Hierarchical Role Access + ✅ Role Restriction Enforcement + ✅ Public Tool Execution + ✅ Authenticated Tool Execution + ✅ Premium Tool Execution + ✅ Invalid Tool Handling + ✅ Invalid Arguments Handling + ✅ Invalid API Key Handling + ✅ Rate Limiting Enforcement + ✅ Rate Limiting Allows Some Requests + +🎯 Success Rate: 93.3% +``` + +### Performance Testing + +Test server performance under load: + +```bash +# Basic performance test +dotnet run -- perf + +# Custom load test +dotnet run -- perf --users 10 --requests 50 + +# Stress test +dotnet run -- perf --users 20 --requests 100 --url http://localhost:8080 +``` + +**Expected Output:** +``` +🚀 Running performance tests... + Users: 10 + Requests per user: 50 + Total requests: 500 + +📈 Performance Results: + Total Requests: 500 + Successful: 485 + Failed: 15 + +⏱️ Response Times: + Average: 45.67 ms + Minimum: 12 ms + Maximum: 234 ms + 95th percentile: 156 ms + +🔥 Throughput: 127.32 requests/second +🎉 Excellent performance! +``` + +### Interactive Demo + +Run a guided demonstration: + +```bash +dotnet run -- demo +``` + +This command provides a step-by-step walkthrough of the server's capabilities, showing: +1. Server health verification +2. Role-based access control demonstration +3. Tool execution examples +4. Authorization enforcement + +## Integration Examples + +### Custom Client Development + +Use the test client as a reference for developing your own MCP clients: + +```csharp +using DynamicToolFiltering.TestClient; + +// Create client instance +using var client = new TestClient("http://localhost:8080"); + +// Authenticate +client.SetApiKey("demo-user-key"); + +// Discover tools +var tools = await client.DiscoverToolsAsync(); +Console.WriteLine($"Available tools: {string.Join(", ", tools.Select(t => t.Name))}"); + +// Execute tool +var result = await client.ExecuteToolAsync("echo", new { message = "Hello from custom client!" }); +if (result.Success) +{ + Console.WriteLine($"Tool response: {result.Content}"); +} +``` + +### Automated Testing Integration + +Integrate with test frameworks: + +```csharp +[Test] +public async Task TestServerHealth() +{ + using var client = new TestClient("http://localhost:8080"); + var health = await client.CheckHealthAsync(); + Assert.AreEqual("healthy", health.Status); +} + +[Test] +public async Task TestRoleBasedAccess() +{ + using var client = new TestClient("http://localhost:8080"); + + // User should see limited tools + client.SetApiKey("demo-user-key"); + var userTools = await client.DiscoverToolsAsync(); + + // Admin should see more tools + client.SetApiKey("demo-admin-key"); + var adminTools = await client.DiscoverToolsAsync(); + + Assert.GreaterOrEqual(adminTools.Count, userTools.Count); +} +``` + +### CI/CD Integration + +Use in continuous integration pipelines: + +```yaml +# .github/workflows/integration-test.yml +name: Integration Tests + +on: [push, pull_request] + +jobs: + integration-test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Setup .NET + uses: actions/setup-dotnet@v3 + with: + dotnet-version: '9.0.x' + + - name: Start MCP Server + run: | + cd samples/DynamicToolFiltering + dotnet run --launch-profile DevelopmentMode & + sleep 30 + + - name: Run Integration Tests + run: | + cd samples/DynamicToolFiltering/clients + dotnet run -- test --url http://localhost:8080 + + - name: Run Performance Tests + run: | + cd samples/DynamicToolFiltering/clients + dotnet run -- perf --users 5 --requests 20 +``` + +## Advanced Usage + +### Custom Authentication + +Extend the client for custom authentication methods: + +```csharp +// JWT token authentication +client.SetBearerToken("your-jwt-token"); + +// Custom header authentication +// (Modify TestClient.cs to add custom auth methods) +``` + +### Performance Monitoring + +Use the client for continuous performance monitoring: + +```csharp +// Monitor response times over time +var results = await client.RunPerformanceTestsAsync(concurrentUsers: 5, requestsPerUser: 20); + +// Alert if performance degrades +if (results.AverageResponseTime > 500) +{ + SendAlert($"Performance degraded: {results.AverageResponseTime}ms average response time"); +} +``` + +### Load Testing + +Scale up for serious load testing: + +```bash +# Heavy load test +dotnet run -- perf --users 50 --requests 100 + +# Sustained load test (run multiple instances) +for i in {1..5}; do + dotnet run -- perf --users 10 --requests 50 & +done +wait +``` + +## Troubleshooting + +### Common Issues + +1. **Connection Refused** + ``` + ❌ Health check failed: Connection refused + ``` + - Ensure the MCP server is running + - Verify the URL and port + - Check firewall settings + +2. **Authentication Failures** + ``` + ❌ Tool discovery failed: Unauthorized + ``` + - Verify API key is correct + - Check server authentication configuration + - Ensure role mappings are configured + +3. **Performance Issues** + ``` + ⚠️ Performance could be improved + ``` + - Check server resource usage + - Verify rate limiting settings + - Consider server scaling + +### Debug Mode + +Enable verbose logging for troubleshooting: + +```bash +dotnet run -- test --verbose +dotnet run -- perf --verbose --users 2 --requests 5 +``` + +### Network Issues + +Test with different server URLs: + +```bash +# Local development +dotnet run -- health --url http://localhost:8080 + +# Docker container +dotnet run -- health --url http://localhost:9080 + +# Remote server +dotnet run -- health --url https://your-server.com +``` + +## Development + +### Building from Source + +```bash +git clone https://github.com/microsoft/mcp-csharp-sdk.git +cd mcp-csharp-sdk/samples/DynamicToolFiltering/clients +dotnet build +``` + +### Adding New Test Scenarios + +Extend the test client by adding methods to `TestClient.cs`: + +```csharp +public async Task TestCustomScenarioAsync() +{ + // Your custom test logic + return true; +} +``` + +Then integrate into the comprehensive test suite. + +### Contributing + +1. Follow the existing code patterns +2. Add comprehensive error handling +3. Include logging for debugging +4. Update documentation for new features +5. Test with various server configurations + +## License + +This test client is part of the MCP C# SDK sample collection and follows the same license terms as the main project. \ No newline at end of file diff --git a/samples/DynamicToolFiltering/clients/TestClient.cs b/samples/DynamicToolFiltering/clients/TestClient.cs new file mode 100644 index 00000000..065a6494 --- /dev/null +++ b/samples/DynamicToolFiltering/clients/TestClient.cs @@ -0,0 +1,545 @@ +using System.Net.Http.Json; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace DynamicToolFiltering.TestClient; + +/// +/// Comprehensive test client for the Dynamic Tool Filtering MCP server. +/// Demonstrates various authentication methods, tool discovery, and execution patterns. +/// +/// USAGE SCENARIOS: +/// 1. Integration testing - Verify server behavior programmatically +/// 2. Load testing - Generate realistic traffic patterns +/// 3. API exploration - Understand server capabilities +/// 4. Client development - Reference implementation for MCP clients +/// +public class TestClient : IDisposable +{ + private readonly HttpClient _httpClient; + private readonly string _baseUrl; + private readonly ILogger _logger; + private readonly JsonSerializerOptions _jsonOptions; + + public TestClient(string baseUrl = "http://localhost:8080", ILogger? logger = null) + { + _baseUrl = baseUrl; + _httpClient = new HttpClient { BaseAddress = new Uri(baseUrl) }; + _logger = logger ?? CreateConsoleLogger(); + + _jsonOptions = new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + WriteIndented = true, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + }; + } + + #region Authentication Methods + + /// + /// Authenticate using API key authentication. + /// This is the primary authentication method demonstrated in the sample. + /// + public void SetApiKey(string apiKey) + { + _httpClient.DefaultRequestHeaders.Remove("X-API-Key"); + _httpClient.DefaultRequestHeaders.Add("X-API-Key", apiKey); + _logger.LogInformation("API key authentication configured"); + } + + /// + /// Authenticate using JWT Bearer token. + /// Demonstrates integration with OAuth2/OIDC providers. + /// + public void SetBearerToken(string token) + { + _httpClient.DefaultRequestHeaders.Authorization = + new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", token); + _logger.LogInformation("Bearer token authentication configured"); + } + + /// + /// Clear all authentication headers. + /// Useful for testing unauthenticated access scenarios. + /// + public void ClearAuthentication() + { + _httpClient.DefaultRequestHeaders.Remove("X-API-Key"); + _httpClient.DefaultRequestHeaders.Authorization = null; + _logger.LogInformation("Authentication cleared"); + } + + #endregion + + #region Server Health and Discovery + + /// + /// Check server health and retrieve basic information. + /// Should always be the first call to verify server availability. + /// + public async Task CheckHealthAsync() + { + _logger.LogInformation("Checking server health..."); + + try + { + var response = await _httpClient.GetAsync("/health"); + response.EnsureSuccessStatusCode(); + + var health = await response.Content.ReadFromJsonAsync(_jsonOptions); + _logger.LogInformation("Server health: {Status} (Environment: {Environment})", + health?.Status, health?.Environment); + + return health ?? throw new InvalidOperationException("Invalid health response"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Health check failed"); + throw; + } + } + + /// + /// Discover available tools for the authenticated user. + /// Tool visibility varies based on user role and filter configuration. + /// + public async Task> DiscoverToolsAsync() + { + _logger.LogInformation("Discovering available tools..."); + + try + { + var response = await _httpClient.GetAsync("/mcp/v1/tools"); + response.EnsureSuccessStatusCode(); + + var toolsResponse = await response.Content.ReadFromJsonAsync(_jsonOptions); + var tools = toolsResponse?.Result?.Tools ?? new List(); + + _logger.LogInformation("Discovered {Count} tools: {ToolNames}", + tools.Count, string.Join(", ", tools.Select(t => t.Name))); + + return tools; + } + catch (Exception ex) + { + _logger.LogError(ex, "Tool discovery failed"); + throw; + } + } + + #endregion + + #region Tool Execution + + /// + /// Execute a tool with the provided arguments. + /// Demonstrates the core MCP tool execution pattern. + /// + public async Task ExecuteToolAsync(string toolName, object? arguments = null) + { + _logger.LogInformation("Executing tool: {ToolName}", toolName); + + try + { + var request = new ToolExecutionRequest + { + Name = toolName, + Arguments = arguments ?? new { } + }; + + var response = await _httpClient.PostAsJsonAsync("/mcp/v1/tools/call", request, _jsonOptions); + + if (!response.IsSuccessStatusCode) + { + var errorContent = await response.Content.ReadAsStringAsync(); + _logger.LogError("Tool execution failed: {StatusCode} - {Content}", + response.StatusCode, errorContent); + + return new ToolExecutionResult + { + Success = false, + StatusCode = (int)response.StatusCode, + ErrorMessage = errorContent + }; + } + + var resultContent = await response.Content.ReadAsStringAsync(); + _logger.LogInformation("Tool execution successful: {ToolName}", toolName); + + return new ToolExecutionResult + { + Success = true, + StatusCode = 200, + Content = resultContent + }; + } + catch (Exception ex) + { + _logger.LogError(ex, "Tool execution failed: {ToolName}", toolName); + return new ToolExecutionResult + { + Success = false, + ErrorMessage = ex.Message + }; + } + } + + #endregion + + #region Test Scenarios + + /// + /// Run a comprehensive test suite covering all major functionality. + /// Useful for integration testing and validation. + /// + public async Task RunComprehensiveTestsAsync() + { + var results = new TestResults(); + + _logger.LogInformation("Starting comprehensive test suite..."); + + // Test 1: Health Check + results.AddTest("Health Check", await TestHealthAsync()); + + // Test 2: Unauthenticated Access + results.AddTest("Unauthenticated Tool Discovery", await TestUnauthenticatedAccessAsync()); + + // Test 3: Authentication Methods + await TestAuthenticationMethodsAsync(results); + + // Test 4: Role-based Access Control + await TestRoleBasedAccessAsync(results); + + // Test 5: Tool Execution + await TestToolExecutionAsync(results); + + // Test 6: Error Handling + await TestErrorHandlingAsync(results); + + // Test 7: Rate Limiting (if enabled) + await TestRateLimitingAsync(results); + + _logger.LogInformation("Test suite completed: {Passed}/{Total} tests passed", + results.PassedCount, results.TotalCount); + + return results; + } + + private async Task TestHealthAsync() + { + try + { + var health = await CheckHealthAsync(); + return health.Status == "healthy"; + } + catch + { + return false; + } + } + + private async Task TestUnauthenticatedAccessAsync() + { + ClearAuthentication(); + try + { + var tools = await DiscoverToolsAsync(); + // Should be able to discover at least public tools + return tools.Count > 0; + } + catch + { + return false; + } + } + + private async Task TestAuthenticationMethodsAsync(TestResults results) + { + var apiKeys = new Dictionary + { + ["Guest"] = "demo-guest-key", + ["User"] = "demo-user-key", + ["Premium"] = "demo-premium-key", + ["Admin"] = "demo-admin-key" + }; + + foreach (var (role, apiKey) in apiKeys) + { + SetApiKey(apiKey); + var success = await TestToolDiscoveryAsync(); + results.AddTest($"API Key Authentication - {role}", success); + } + } + + private async Task TestToolDiscoveryAsync() + { + try + { + var tools = await DiscoverToolsAsync(); + return tools.Count > 0; + } + catch + { + return false; + } + } + + private async Task TestRoleBasedAccessAsync(TestResults results) + { + // Test hierarchical access - admin should see more tools than user + SetApiKey("demo-user-key"); + var userTools = await DiscoverToolsAsync(); + + SetApiKey("demo-admin-key"); + var adminTools = await DiscoverToolsAsync(); + + results.AddTest("Hierarchical Role Access", adminTools.Count >= userTools.Count); + + // Test role restrictions - user should not access admin tools + SetApiKey("demo-user-key"); + var adminToolResult = await ExecuteToolAsync("admin_get_system_diagnostics"); + results.AddTest("Role Restriction Enforcement", !adminToolResult.Success); + } + + private async Task TestToolExecutionAsync(TestResults results) + { + // Test public tool execution + ClearAuthentication(); + var echoResult = await ExecuteToolAsync("echo", new { message = "test" }); + results.AddTest("Public Tool Execution", echoResult.Success); + + // Test authenticated tool execution + SetApiKey("demo-user-key"); + var profileResult = await ExecuteToolAsync("get_user_profile"); + results.AddTest("Authenticated Tool Execution", profileResult.Success); + + // Test premium tool execution + SetApiKey("demo-premium-key"); + var premiumResult = await ExecuteToolAsync("premium_generate_secure_random", new { byteCount = 16 }); + results.AddTest("Premium Tool Execution", premiumResult.Success); + } + + private async Task TestErrorHandlingAsync(TestResults results) + { + // Test invalid tool name + SetApiKey("demo-user-key"); + var invalidToolResult = await ExecuteToolAsync("nonexistent_tool"); + results.AddTest("Invalid Tool Handling", !invalidToolResult.Success); + + // Test invalid arguments + var invalidArgsResult = await ExecuteToolAsync("echo", new { invalid_argument = "test" }); + results.AddTest("Invalid Arguments Handling", !invalidArgsResult.Success); + + // Test invalid API key + SetApiKey("invalid-key"); + var invalidKeyResult = await ExecuteToolAsync("echo", new { message = "test" }); + results.AddTest("Invalid API Key Handling", !invalidKeyResult.Success); + } + + private async Task TestRateLimitingAsync(TestResults results) + { + SetApiKey("demo-guest-key"); + var successCount = 0; + var rateLimitedCount = 0; + + // Make multiple rapid requests to trigger rate limiting + for (int i = 0; i < 25; i++) + { + var result = await ExecuteToolAsync("echo", new { message = $"rate test {i}" }); + if (result.Success) + successCount++; + else if (result.StatusCode == 429) + rateLimitedCount++; + + await Task.Delay(50); // Small delay to avoid overwhelming the server + } + + // Rate limiting should trigger for guest users with heavy usage + results.AddTest("Rate Limiting Enforcement", rateLimitedCount > 0); + results.AddTest("Rate Limiting Allows Some Requests", successCount > 0); + } + + #endregion + + #region Performance Testing + + /// + /// Run performance tests to measure response times and throughput. + /// Useful for load testing and performance benchmarking. + /// + public async Task RunPerformanceTestsAsync(int concurrentUsers = 5, int requestsPerUser = 20) + { + _logger.LogInformation("Starting performance tests: {Users} users, {Requests} requests each", + concurrentUsers, requestsPerUser); + + var tasks = new List>(); + + for (int i = 0; i < concurrentUsers; i++) + { + var userId = i; + tasks.Add(RunUserPerformanceTestAsync(userId, requestsPerUser)); + } + + var userResults = await Task.WhenAll(tasks); + + var overallResults = new PerformanceResults + { + ConcurrentUsers = concurrentUsers, + TotalRequests = concurrentUsers * requestsPerUser, + UserResults = userResults.ToList() + }; + + overallResults.CalculateStatistics(); + + _logger.LogInformation("Performance test completed: {TotalRequests} requests, " + + "avg response time: {AvgResponseTime}ms, " + + "95th percentile: {P95ResponseTime}ms", + overallResults.TotalRequests, + overallResults.AverageResponseTime, + overallResults.P95ResponseTime); + + return overallResults; + } + + private async Task RunUserPerformanceTestAsync(int userId, int requestCount) + { + var client = new HttpClient { BaseAddress = new Uri(_baseUrl) }; + client.DefaultRequestHeaders.Add("X-API-Key", "demo-user-key"); + + var result = new UserPerformanceResult { UserId = userId }; + + for (int i = 0; i < requestCount; i++) + { + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + try + { + var response = await client.PostAsJsonAsync("/mcp/v1/tools/call", + new ToolExecutionRequest + { + Name = "echo", + Arguments = new { message = $"perf test user {userId} request {i}" } + }); + + stopwatch.Stop(); + + result.ResponseTimes.Add(stopwatch.ElapsedMilliseconds); + + if (response.IsSuccessStatusCode) + result.SuccessfulRequests++; + else + result.FailedRequests++; + } + catch + { + stopwatch.Stop(); + result.FailedRequests++; + } + } + + client.Dispose(); + return result; + } + + #endregion + + #region Utility Methods + + private static ILogger CreateConsoleLogger() + { + using var loggerFactory = LoggerFactory.Create(builder => + builder.AddConsole().SetMinimumLevel(LogLevel.Information)); + return loggerFactory.CreateLogger(); + } + + public void Dispose() + { + _httpClient?.Dispose(); + } + + #endregion +} + +#region Data Models + +public record HealthResponse(string Status, string Environment, string Version, DateTime Timestamp); + +public record ToolInfo(string Name, string Description, object? InputSchema = null); + +public record ToolsResponse(ToolsResult Result); +public record ToolsResult(List Tools); + +public record ToolExecutionRequest +{ + public string Name { get; init; } = ""; + public object Arguments { get; init; } = new { }; +} + +public record ToolExecutionResult +{ + public bool Success { get; init; } + public int StatusCode { get; init; } + public string? Content { get; init; } + public string? ErrorMessage { get; init; } +} + +public class TestResults +{ + private readonly List<(string Name, bool Passed)> _tests = new(); + + public void AddTest(string name, bool passed) + { + _tests.Add((name, passed)); + } + + public int TotalCount => _tests.Count; + public int PassedCount => _tests.Count(t => t.Passed); + public int FailedCount => _tests.Count(t => !t.Passed); + + public IReadOnlyList<(string Name, bool Passed)> Tests => _tests.AsReadOnly(); +} + +public record UserPerformanceResult +{ + public int UserId { get; init; } + public List ResponseTimes { get; } = new(); + public int SuccessfulRequests { get; set; } + public int FailedRequests { get; set; } +} + +public class PerformanceResults +{ + public int ConcurrentUsers { get; init; } + public int TotalRequests { get; init; } + public List UserResults { get; init; } = new(); + + public double AverageResponseTime { get; private set; } + public long P95ResponseTime { get; private set; } + public long MinResponseTime { get; private set; } + public long MaxResponseTime { get; private set; } + public double ThroughputPerSecond { get; private set; } + + public void CalculateStatistics() + { + var allResponseTimes = UserResults.SelectMany(r => r.ResponseTimes).ToList(); + + if (allResponseTimes.Count > 0) + { + allResponseTimes.Sort(); + + AverageResponseTime = allResponseTimes.Average(); + MinResponseTime = allResponseTimes.First(); + MaxResponseTime = allResponseTimes.Last(); + + var p95Index = (int)(allResponseTimes.Count * 0.95); + P95ResponseTime = allResponseTimes[Math.Min(p95Index, allResponseTimes.Count - 1)]; + + var totalSuccessful = UserResults.Sum(r => r.SuccessfulRequests); + var totalTime = allResponseTimes.Sum() / 1000.0; // Convert to seconds + ThroughputPerSecond = totalTime > 0 ? totalSuccessful / totalTime : 0; + } + } +} + +#endregion \ No newline at end of file diff --git a/samples/DynamicToolFiltering/clients/TestClient.csproj b/samples/DynamicToolFiltering/clients/TestClient.csproj new file mode 100644 index 00000000..b86df756 --- /dev/null +++ b/samples/DynamicToolFiltering/clients/TestClient.csproj @@ -0,0 +1,38 @@ + + + + Exe + net9.0 + enable + enable + DynamicToolFiltering.TestClient + DynamicToolFiltering.TestClient + + + Dynamic Tool Filtering Test Client + Comprehensive test client for the Dynamic Tool Filtering MCP server + 1.0.0 + 1.0.0 + + + false + + true + + + + + + + + + + + + + + PreserveNewest + + + + \ No newline at end of file diff --git a/samples/DynamicToolFiltering/docker-compose.yml b/samples/DynamicToolFiltering/docker-compose.yml new file mode 100644 index 00000000..e556f1d1 --- /dev/null +++ b/samples/DynamicToolFiltering/docker-compose.yml @@ -0,0 +1,155 @@ +# Dynamic Tool Filtering - Docker Compose Configuration +# Provides multiple deployment scenarios for development and production + +version: '3.8' + +services: + # Main MCP Server + dynamic-tool-filtering: + build: + context: . + dockerfile: Dockerfile + ports: + - "${PORT:-8080}:8080" + environment: + # Core settings + - ASPNETCORE_ENVIRONMENT=${ENVIRONMENT:-Development} + - ASPNETCORE_URLS=http://+:8080 + + # Filter configuration + - Filtering__Enabled=${FILTERING_ENABLED:-true} + - Filtering__RoleBased__Enabled=${ROLE_BASED_ENABLED:-true} + - Filtering__TimeBased__Enabled=${TIME_BASED_ENABLED:-false} + - Filtering__ScopeBased__Enabled=${SCOPE_BASED_ENABLED:-true} + - Filtering__RateLimiting__Enabled=${RATE_LIMITING_ENABLED:-true} + - Filtering__TenantIsolation__Enabled=${TENANT_ISOLATION_ENABLED:-false} + - Filtering__BusinessLogic__Enabled=${BUSINESS_LOGIC_ENABLED:-true} + - Filtering__BusinessLogic__FeatureFlags__Enabled=${FEATURE_FLAGS_ENABLED:-true} + - Filtering__BusinessLogic__QuotaManagement__Enabled=${QUOTA_MANAGEMENT_ENABLED:-false} + + # JWT Configuration + - Jwt__SecretKey=${JWT_SECRET_KEY:-your-256-bit-secret-key-here-make-it-secure-and-change-in-production} + - Jwt__Issuer=${JWT_ISSUER:-dynamic-tool-filtering} + - Jwt__Audience=${JWT_AUDIENCE:-mcp-api} + + # External service connections (for production profile) + - ConnectionStrings__Redis=${REDIS_CONNECTION:-localhost:6379} + - ConnectionStrings__DefaultConnection=${DB_CONNECTION:-} + + volumes: + # Persist logs and data + - ./logs:/app/logs + - ./data:/app/data + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8080/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 15s + depends_on: + - redis + - postgres + profiles: + - development + - production + + # Redis for production-ready rate limiting and caching + redis: + image: redis:7-alpine + ports: + - "${REDIS_PORT:-6379}:6379" + command: redis-server --appendonly yes --maxmemory 256mb --maxmemory-policy allkeys-lru + volumes: + - redis_data:/data + restart: unless-stopped + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 30s + timeout: 10s + retries: 3 + profiles: + - production + + # PostgreSQL for quota management and persistent storage + postgres: + image: postgres:15-alpine + ports: + - "${POSTGRES_PORT:-5432}:5432" + environment: + - POSTGRES_DB=${POSTGRES_DB:-dynamic_tool_filtering} + - POSTGRES_USER=${POSTGRES_USER:-mcpuser} + - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-mcppassword} + - PGDATA=/var/lib/postgresql/data/pgdata + volumes: + - postgres_data:/var/lib/postgresql/data + - ./scripts/init-db.sql:/docker-entrypoint-initdb.d/init-db.sql:ro + restart: unless-stopped + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-mcpuser} -d ${POSTGRES_DB:-dynamic_tool_filtering}"] + interval: 30s + timeout: 10s + retries: 3 + profiles: + - production + + # Monitoring with Prometheus (optional) + prometheus: + image: prom/prometheus:latest + ports: + - "${PROMETHEUS_PORT:-9090}:9090" + volumes: + - ./monitoring/prometheus.yml:/etc/prometheus/prometheus.yml:ro + - prometheus_data:/prometheus + command: + - '--config.file=/etc/prometheus/prometheus.yml' + - '--storage.tsdb.path=/prometheus' + - '--web.console.libraries=/etc/prometheus/console_libraries' + - '--web.console.templates=/etc/prometheus/consoles' + - '--web.enable-lifecycle' + restart: unless-stopped + profiles: + - monitoring + + # Grafana for metrics visualization (optional) + grafana: + image: grafana/grafana:latest + ports: + - "${GRAFANA_PORT:-3000}:3000" + environment: + - GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_PASSWORD:-admin} + volumes: + - grafana_data:/var/lib/grafana + - ./monitoring/grafana/dashboards:/etc/grafana/provisioning/dashboards:ro + - ./monitoring/grafana/datasources:/etc/grafana/provisioning/datasources:ro + restart: unless-stopped + depends_on: + - prometheus + profiles: + - monitoring + + # Nginx reverse proxy (optional) + nginx: + image: nginx:alpine + ports: + - "${NGINX_PORT:-80}:80" + - "${NGINX_SSL_PORT:-443}:443" + volumes: + - ./nginx/nginx.conf:/etc/nginx/nginx.conf:ro + - ./nginx/ssl:/etc/nginx/ssl:ro + depends_on: + - dynamic-tool-filtering + restart: unless-stopped + profiles: + - production + - proxy + +volumes: + redis_data: + postgres_data: + prometheus_data: + grafana_data: + +networks: + default: + name: dynamic-tool-filtering-network \ No newline at end of file diff --git a/samples/DynamicToolFiltering/docs/PERFORMANCE.md b/samples/DynamicToolFiltering/docs/PERFORMANCE.md new file mode 100644 index 00000000..833b3fb6 --- /dev/null +++ b/samples/DynamicToolFiltering/docs/PERFORMANCE.md @@ -0,0 +1,580 @@ +# Performance Testing and Monitoring Guide + +This guide provides comprehensive information on performance testing, monitoring, and optimization for the Dynamic Tool Filtering MCP server. + +## Table of Contents + +1. [Performance Testing](#performance-testing) +2. [Monitoring Setup](#monitoring-setup) +3. [Benchmarking](#benchmarking) +4. [Optimization Guidelines](#optimization-guidelines) +5. [Production Monitoring](#production-monitoring) +6. [Troubleshooting Performance Issues](#troubleshooting-performance-issues) + +## Performance Testing + +### Load Testing Tools + +#### 1. Apache Bench (ab) + +Basic load testing for HTTP endpoints: + +```bash +# Install Apache Bench +# Ubuntu/Debian: sudo apt-get install apache2-utils +# macOS: brew install httpd + +# Test health endpoint - 1000 requests, 10 concurrent +ab -n 1000 -c 10 http://localhost:8080/health + +# Test tool listing with API key +ab -n 500 -c 5 -H "X-API-Key: demo-user-key" http://localhost:8080/mcp/v1/tools + +# Test tool execution with POST data +ab -n 100 -c 2 -p echo_test.json -T application/json -H "X-API-Key: demo-user-key" http://localhost:8080/mcp/v1/tools/call +``` + +Create `echo_test.json` for POST testing: +```json +{ + "name": "echo", + "arguments": { + "message": "Performance test" + } +} +``` + +#### 2. wrk (Recommended for Advanced Testing) + +```bash +# Install wrk +# Ubuntu: sudo apt install wrk +# macOS: brew install wrk + +# Basic load test +wrk -t4 -c50 -d30s http://localhost:8080/health + +# Test with custom script for authenticated requests +wrk -t4 -c20 -d30s -s auth_test.lua http://localhost:8080/mcp/v1/tools +``` + +Create `auth_test.lua`: +```lua +wrk.method = "GET" +wrk.headers["X-API-Key"] = "demo-user-key" +wrk.headers["Content-Type"] = "application/json" +``` + +#### 3. Artillery.js (Advanced Scenarios) + +Install and configure Artillery for complex test scenarios: + +```bash +npm install -g artillery +``` + +Create `performance-test.yml`: +```yaml +config: + target: 'http://localhost:8080' + phases: + - duration: 60 + arrivalRate: 10 + name: "Warm up" + - duration: 120 + arrivalRate: 20 + name: "Sustained load" + - duration: 60 + arrivalRate: 50 + name: "Peak load" + defaults: + headers: + X-API-Key: "demo-user-key" + +scenarios: + - name: "Mixed workload" + weight: 100 + flow: + - get: + url: "/health" + capture: + - json: "$.Status" + as: "health_status" + - get: + url: "/mcp/v1/tools" + - post: + url: "/mcp/v1/tools/call" + json: + name: "echo" + arguments: + message: "Load test {{ $randomString() }}" + - think: 1 +``` + +Run Artillery test: +```bash +artillery run performance-test.yml +``` + +### Custom Performance Test Script + +Create a comprehensive test script: + +```bash +#!/bin/bash +# scripts/performance-test.sh + +BASE_URL="http://localhost:8080" +USER_KEY="demo-user-key" +ADMIN_KEY="demo-admin-key" + +echo "=== Performance Test Suite ===" + +# Test 1: Health endpoint baseline +echo "Testing health endpoint..." +ab -n 1000 -c 10 -q $BASE_URL/health | grep "Requests per second" + +# Test 2: Tool listing performance +echo "Testing tool listing..." +ab -n 500 -c 5 -H "X-API-Key: $USER_KEY" -q $BASE_URL/mcp/v1/tools | grep "Requests per second" + +# Test 3: Tool execution performance +echo "Testing tool execution..." +echo '{"name": "echo", "arguments": {"message": "perf test"}}' > /tmp/echo_test.json +ab -n 200 -c 4 -p /tmp/echo_test.json -T application/json -H "X-API-Key: $USER_KEY" -q $BASE_URL/mcp/v1/tools/call | grep "Requests per second" + +# Test 4: Rate limiting behavior +echo "Testing rate limiting..." +for i in {1..25}; do + curl -s -w "%{http_code}\n" -o /dev/null \ + -H "X-API-Key: demo-guest-key" \ + -X POST \ + -H "Content-Type: application/json" \ + -d '{"name": "echo", "arguments": {"message": "rate limit test"}}' \ + $BASE_URL/mcp/v1/tools/call +done | sort | uniq -c + +# Test 5: Concurrent filter execution +echo "Testing concurrent requests..." +( +for i in {1..5}; do + (ab -n 50 -c 2 -H "X-API-Key: $USER_KEY" -q $BASE_URL/mcp/v1/tools &) +done +wait +) 2>/dev/null + +echo "Performance tests complete" +``` + +## Monitoring Setup + +### Application Performance Monitoring (APM) + +#### 1. OpenTelemetry Configuration + +The sample already includes OpenTelemetry. Enhance it for production: + +```csharp +// Program.cs - Enhanced telemetry configuration +builder.Services.AddOpenTelemetry() + .WithTracing(b => b + .SetSampler(new TraceIdRatioBasedSampler(0.1)) // 10% sampling + .AddSource("DynamicToolFiltering") + .AddAspNetCoreInstrumentation(options => + { + options.RecordException = true; + options.EnrichWithHttpRequest = (activity, request) => + { + activity.SetTag("user.role", request.Headers["X-API-Key"]); + }; + }) + .AddHttpClientInstrumentation() + .AddSqlClientInstrumentation() + .UseOtlpExporter()) + .WithMetrics(b => b + .AddMeter("DynamicToolFiltering") + .AddAspNetCoreInstrumentation() + .AddRuntimeInstrumentation() + .AddProcessInstrumentation() + .UseOtlpExporter()); +``` + +#### 2. Custom Metrics + +Add custom metrics for filter performance: + +```csharp +// Services/MetricsService.cs +public class MetricsService +{ + private readonly Meter _meter; + private readonly Counter _filterExecutions; + private readonly Histogram _filterDuration; + private readonly Counter _authorizationFailures; + + public MetricsService() + { + _meter = new Meter("DynamicToolFiltering"); + _filterExecutions = _meter.CreateCounter("filter_executions_total"); + _filterDuration = _meter.CreateHistogram("filter_duration_seconds"); + _authorizationFailures = _meter.CreateCounter("authorization_failures_total"); + } + + public void RecordFilterExecution(string filterName, string result) + { + _filterExecutions.Add(1, new("filter", filterName), new("result", result)); + } + + public void RecordFilterDuration(string filterName, double durationSeconds) + { + _filterDuration.Record(durationSeconds, new("filter", filterName)); + } + + public void RecordAuthorizationFailure(string reason, string toolName) + { + _authorizationFailures.Add(1, new("reason", reason), new("tool", toolName)); + } +} +``` + +### Prometheus and Grafana Setup + +#### 1. Prometheus Configuration + +Create `monitoring/prometheus.yml`: + +```yaml +global: + scrape_interval: 15s + evaluation_interval: 15s + +rule_files: + - "alert_rules.yml" + +scrape_configs: + - job_name: 'dynamic-tool-filtering' + static_configs: + - targets: ['dynamic-tool-filtering:8080'] + metrics_path: '/metrics' + scrape_interval: 10s + + - job_name: 'redis' + static_configs: + - targets: ['redis:6379'] + + - job_name: 'postgres' + static_configs: + - targets: ['postgres:5432'] + +alerting: + alertmanagers: + - static_configs: + - targets: + - alertmanager:9093 +``` + +#### 2. Grafana Dashboard + +Create `monitoring/grafana/dashboards/mcp-server.json`: + +```json +{ + "dashboard": { + "id": null, + "title": "Dynamic Tool Filtering MCP Server", + "tags": ["mcp", "performance"], + "timezone": "browser", + "panels": [ + { + "title": "Request Rate", + "type": "graph", + "targets": [ + { + "expr": "rate(http_requests_total[5m])", + "legendFormat": "Requests/sec" + } + ] + }, + { + "title": "Response Times", + "type": "graph", + "targets": [ + { + "expr": "histogram_quantile(0.95, rate(http_request_duration_seconds_bucket[5m]))", + "legendFormat": "95th percentile" + }, + { + "expr": "histogram_quantile(0.50, rate(http_request_duration_seconds_bucket[5m]))", + "legendFormat": "50th percentile" + } + ] + }, + { + "title": "Filter Performance", + "type": "graph", + "targets": [ + { + "expr": "rate(filter_executions_total[5m])", + "legendFormat": "{{filter}} - {{result}}" + } + ] + }, + { + "title": "Authorization Failures", + "type": "graph", + "targets": [ + { + "expr": "rate(authorization_failures_total[5m])", + "legendFormat": "{{reason}}" + } + ] + } + ], + "time": { + "from": "now-1h", + "to": "now" + }, + "refresh": "10s" + } +} +``` + +### Performance Benchmarks + +#### Expected Performance Baselines + +| Endpoint | Concurrent Users | Expected RPS | 95th Percentile | +|----------|-----------------|--------------|-----------------| +| `/health` | 50 | 1000+ | < 50ms | +| `/mcp/v1/tools` (auth) | 20 | 500+ | < 100ms | +| `/mcp/v1/tools/call` (simple) | 10 | 200+ | < 200ms | +| `/mcp/v1/tools/call` (complex) | 5 | 50+ | < 500ms | + +#### Resource Usage Targets + +- **Memory**: < 100MB under normal load +- **CPU**: < 50% on single core under load +- **Response Time**: 95th percentile < 200ms +- **Error Rate**: < 0.1% under normal conditions + +## Optimization Guidelines + +### 1. Filter Performance + +```csharp +// Optimize filter execution with caching +public class CachedRoleBasedFilter : IToolFilter +{ + private readonly IMemoryCache _cache; + private readonly RoleBasedToolFilter _innerFilter; + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken) + { + var cacheKey = $"role_filter_{context.User?.Identity?.Name}_{tool.Name}"; + + return await _cache.GetOrCreateAsync(cacheKey, async entry => + { + entry.SlidingExpiration = TimeSpan.FromMinutes(5); + return await _innerFilter.ShouldIncludeToolAsync(tool, context, cancellationToken); + }); + } +} +``` + +### 2. Rate Limiting Optimization + +```csharp +// Use Redis for distributed rate limiting +public class RedisRateLimitingService : IRateLimitingService +{ + private readonly IDatabase _database; + + public async Task IsAllowedAsync(string userId, string toolName) + { + var key = $"rate_limit:{userId}:{toolName}"; + var current = await _database.StringIncrementAsync(key); + + if (current == 1) + { + await _database.KeyExpireAsync(key, TimeSpan.FromHours(1)); + } + + return current <= GetLimitForUser(userId); + } +} +``` + +### 3. Database Query Optimization + +```csharp +// Optimize quota queries with indexing +public class OptimizedQuotaService : IQuotaService +{ + public async Task HasAvailableQuotaAsync(string userId, string toolName) + { + // Use efficient query with proper indexing + var result = await _context.Database.ExecuteSqlRawAsync( + "SELECT usage_count FROM user_quotas WHERE user_id = {0} AND tool_name = {1}", + userId, toolName); + + return result < GetQuotaLimit(userId, toolName); + } +} +``` + +## Production Monitoring + +### Key Metrics to Monitor + +#### 1. Application Metrics +- Request rate and response times +- Filter execution times +- Authorization success/failure rates +- Rate limiting violations +- Feature flag evaluation times + +#### 2. Infrastructure Metrics +- CPU and memory usage +- Disk I/O and network latency +- Database connection pool usage +- Redis hit/miss ratios + +#### 3. Business Metrics +- Tool usage patterns by role +- Peak usage times +- Most popular tools +- Error patterns and trends + +### Alerting Rules + +Create `monitoring/alert_rules.yml`: + +```yaml +groups: + - name: mcp-server-alerts + rules: + - alert: HighResponseTime + expr: histogram_quantile(0.95, rate(http_request_duration_seconds_bucket[5m])) > 0.5 + for: 2m + annotations: + summary: "High response time detected" + + - alert: HighErrorRate + expr: rate(http_requests_total{status=~"5.."}[5m]) > 0.01 + for: 1m + annotations: + summary: "High error rate detected" + + - alert: MemoryUsageHigh + expr: process_resident_memory_bytes / 1024 / 1024 > 500 + for: 5m + annotations: + summary: "High memory usage" +``` + +## Troubleshooting Performance Issues + +### Common Performance Problems + +#### 1. Slow Filter Execution +```bash +# Check filter execution times in logs +grep "Filter execution" logs/dynamic-tool-filtering-*.log | \ + awk '{print $NF}' | sort -n | tail -10 +``` + +#### 2. Memory Leaks +```bash +# Monitor memory usage over time +docker stats dynamic-tool-filtering --format "table {{.MemUsage}}\t{{.CPUPerc}}" +``` + +#### 3. Database Performance +```sql +-- Check slow queries (PostgreSQL) +SELECT query, calls, total_time, mean_time +FROM pg_stat_statements +ORDER BY mean_time DESC +LIMIT 10; +``` + +### Performance Profiling + +#### Using dotnet-trace +```bash +# Install profiling tools +dotnet tool install -g dotnet-trace +dotnet tool install -g dotnet-counters + +# Collect performance trace +dotnet-trace collect --process-id --providers Microsoft-AspNetCore-Server-Kestrel + +# Monitor real-time counters +dotnet-counters monitor --process-id --counters System.Runtime,Microsoft.AspNetCore.Hosting +``` + +#### Using Application Insights +```csharp +// Add detailed telemetry +builder.Services.AddApplicationInsightsTelemetry(options => +{ + options.EnableAdaptiveSampling = true; + options.EnableQuickPulseMetricStream = true; + options.EnablePerformanceCounterCollectionModule = true; +}); +``` + +### Load Testing in CI/CD + +Create `.github/workflows/performance-test.yml`: + +```yaml +name: Performance Tests + +on: + pull_request: + branches: [ main ] + schedule: + - cron: '0 2 * * *' # Daily at 2 AM + +jobs: + performance-test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Setup .NET + uses: actions/setup-dotnet@v3 + with: + dotnet-version: '9.0.x' + + - name: Build application + run: dotnet build -c Release + + - name: Start application + run: | + dotnet run --launch-profile DevelopmentMode & + sleep 30 # Wait for startup + + - name: Install testing tools + run: | + sudo apt-get update + sudo apt-get install -y apache2-utils + + - name: Run performance tests + run: | + # Health endpoint test + ab -n 1000 -c 10 http://localhost:8080/health + + # API performance test + ab -n 500 -c 5 -H "X-API-Key: demo-user-key" \ + http://localhost:8080/mcp/v1/tools + + - name: Verify performance benchmarks + run: | + # Add custom verification logic + ./scripts/verify-performance.sh +``` + +This comprehensive performance guide provides the foundation for monitoring, testing, and optimizing the Dynamic Tool Filtering MCP server in production environments. \ No newline at end of file diff --git a/samples/DynamicToolFiltering/docs/TROUBLESHOOTING.md b/samples/DynamicToolFiltering/docs/TROUBLESHOOTING.md new file mode 100644 index 00000000..ef6e187f --- /dev/null +++ b/samples/DynamicToolFiltering/docs/TROUBLESHOOTING.md @@ -0,0 +1,714 @@ +# Troubleshooting Guide + +This guide helps you diagnose and resolve common issues with the Dynamic Tool Filtering MCP server. + +## Table of Contents + +1. [Quick Diagnostics](#quick-diagnostics) +2. [Common Issues](#common-issues) +3. [Authentication Problems](#authentication-problems) +4. [Authorization Failures](#authorization-failures) +5. [Performance Issues](#performance-issues) +6. [Configuration Problems](#configuration-problems) +7. [Docker Issues](#docker-issues) +8. [Development Environment Issues](#development-environment-issues) +9. [Logging and Debugging](#logging-and-debugging) +10. [Getting Help](#getting-help) + +## Quick Diagnostics + +### Health Check Commands + +```bash +# Basic health check +curl http://localhost:8080/health + +# Expected response: +# { +# "Status": "healthy", +# "Timestamp": "2024-01-01T12:00:00.000Z", +# "Environment": "Development", +# "Version": "1.0.0" +# } + +# Detailed server information +curl -v http://localhost:8080/health + +# Check if server is listening +netstat -tlnp | grep :8080 +# or +ss -tlnp | grep :8080 +``` + +### Quick Test Script + +```bash +#!/bin/bash +# scripts/quick-diagnose.sh + +BASE_URL="http://localhost:8080" + +echo "=== Quick Diagnostics ===" + +# Test 1: Server responding +echo -n "Server health: " +if curl -s -f "$BASE_URL/health" > /dev/null; then + echo "✅ OK" +else + echo "❌ FAILED - Server not responding" + exit 1 +fi + +# Test 2: Authentication working +echo -n "Authentication: " +response=$(curl -s -w "%{http_code}" -H "X-API-Key: demo-user-key" "$BASE_URL/mcp/v1/tools") +if [[ "${response: -3}" == "200" ]]; then + echo "✅ OK" +else + echo "❌ FAILED - HTTP ${response: -3}" +fi + +# Test 3: Tool execution +echo -n "Tool execution: " +response=$(curl -s -w "%{http_code}" -X POST \ + -H "Content-Type: application/json" \ + -d '{"name": "echo", "arguments": {"message": "test"}}' \ + "$BASE_URL/mcp/v1/tools/call") +if [[ "${response: -3}" == "200" ]]; then + echo "✅ OK" +else + echo "❌ FAILED - HTTP ${response: -3}" +fi + +echo "Diagnostics complete" +``` + +## Common Issues + +### 1. Server Won't Start + +**Symptoms:** +- Application exits immediately +- Port binding errors +- Missing dependencies + +**Solutions:** + +```bash +# Check if port is already in use +lsof -i :8080 +# Kill existing process if needed +kill -9 + +# Verify .NET installation +dotnet --version +# Should show 9.0.x or later + +# Check project dependencies +dotnet restore +dotnet build + +# Run with verbose logging +DOTNET_ENVIRONMENT=Development dotnet run --verbosity diagnostic +``` + +**Common Error Messages:** + +``` +Error: Unable to bind to https://localhost:8080 +Solution: Change port in launchSettings.json or kill process using port 8080 + +Error: Could not load file or assembly +Solution: Run 'dotnet restore' and 'dotnet build' + +Error: The configured user limit (128) on the number of inotify instances has been reached +Solution: Increase inotify limit: echo fs.inotify.max_user_instances=524288 | sudo tee -a /etc/sysctl.conf +``` + +### 2. Tools Not Visible + +**Symptoms:** +- Empty tool list +- Some tools missing for specific roles +- All tools showing regardless of role + +**Diagnosis:** + +```bash +# Check tool visibility for different roles +echo "=== Tool Visibility Test ===" + +for role in "demo-guest-key" "demo-user-key" "demo-premium-key" "demo-admin-key"; do + echo "Role: $role" + curl -s -H "X-API-Key: $role" http://localhost:8080/mcp/v1/tools | \ + jq -r '.result.tools[].name' | sort + echo "" +done +``` + +**Solutions:** + +```bash +# Check filter configuration +grep -r "Filtering" appsettings*.json + +# Verify filters are registered +grep -A 20 "ConfigureFiltering" Program.cs + +# Check logs for filter execution +tail -f logs/dynamic-tool-filtering-*.log | grep -i filter +``` + +### 3. All Tools Visible Regardless of Role + +**Cause:** Filtering disabled or role-based filter not working + +**Solutions:** + +```bash +# Check if filtering is enabled +curl -s http://localhost:8080/health | jq . + +# Verify environment variables +printenv | grep Filtering + +# Check role extraction in logs +# Look for: "User roles extracted: [role_name]" +``` + +## Authentication Problems + +### 1. Invalid API Key Errors + +**Symptoms:** +```json +{ + "error": { + "code": -32002, + "message": "Invalid API key" + } +} +``` + +**Solutions:** + +```bash +# Verify API key format +echo "Your API key should be one of:" +echo "- demo-guest-key (guest role)" +echo "- demo-user-key (user role)" +echo "- demo-premium-key (premium role)" +echo "- demo-admin-key (admin role)" + +# Test with correct key +curl -H "X-API-Key: demo-user-key" http://localhost:8080/mcp/v1/tools + +# Check key extraction in code +grep -A 10 "HandleAuthenticateAsync" Program.cs +``` + +### 2. JWT Token Issues + +**Symptoms:** +```json +{ + "error": { + "code": -32002, + "message": "JWT token validation failed" + } +} +``` + +**Solutions:** + +```bash +# Verify JWT configuration +grep -A 10 "JwtBearer" Program.cs + +# Check token format (should be: Authorization: Bearer ) +# Test JWT token generation +# Use online JWT debugger: jwt.io + +# Verify secret key matches +grep "SecretKey" appsettings*.json +``` + +## Authorization Failures + +### 1. User Trying to Access Admin Tools + +**Expected Behavior:** +```json +{ + "error": { + "code": -32002, + "message": "Access denied for tool 'admin_get_system_diagnostics': Tool requires role(s): admin or super_admin. User has role(s): user", + "data": { + "ToolName": "admin_get_system_diagnostics", + "Reason": "Tool requires role(s): admin or super_admin. User has role(s): user", + "HttpStatusCode": 401, + "RequiresAuthentication": true + } + } +} +``` + +**If this doesn't happen:** + +```bash +# Check role-based filter is enabled +grep "RoleBased.*Enabled" appsettings*.json + +# Verify filter priority +grep -A 5 "RoleBasedToolFilter" Program.cs + +# Check role extraction logic +tail -f logs/dynamic-tool-filtering-*.log | grep -i "role" +``` + +### 2. Rate Limiting Not Working + +**Test Rate Limiting:** + +```bash +# Rapid requests test +for i in {1..25}; do + echo "Request $i:" + curl -s -w "HTTP %{http_code}\n" -o /dev/null \ + -H "X-API-Key: demo-guest-key" \ + -X POST \ + -H "Content-Type: application/json" \ + -d '{"name": "echo", "arguments": {"message": "test"}}' \ + http://localhost:8080/mcp/v1/tools/call + sleep 0.1 +done +``` + +**Expected:** Some requests return HTTP 429 after hitting limit + +**If rate limiting not working:** + +```bash +# Check rate limiting configuration +grep -A 10 "RateLimiting" appsettings*.json + +# Verify rate limiting service registration +grep "IRateLimitingService" Program.cs + +# Check rate limiting logs +tail -f logs/dynamic-tool-filtering-*.log | grep -i "rate" +``` + +## Performance Issues + +### 1. Slow Response Times + +**Diagnosis:** + +```bash +# Measure response times +time curl -s http://localhost:8080/mcp/v1/tools > /dev/null + +# Load test +ab -n 100 -c 5 http://localhost:8080/health + +# Check for memory leaks +# Monitor memory usage over time +while true; do + ps aux | grep DynamicToolFiltering | grep -v grep | awk '{print $4 "%", $6/1024 "MB"}' + sleep 5 +done +``` + +**Solutions:** + +```bash +# Enable performance logging +export DOTNET_EnableEventLog=true + +# Use dotnet-trace for profiling +dotnet-trace collect --process-id $(pgrep -f DynamicToolFiltering) + +# Check for inefficient filters +# Look for filters taking > 100ms +tail -f logs/dynamic-tool-filtering-*.log | grep -E "(Filter.*took|duration.*ms)" +``` + +### 2. High Memory Usage + +**Diagnosis:** + +```bash +# Monitor memory +dotnet-counters monitor --process-id $(pgrep -f DynamicToolFiltering) --counters System.Runtime + +# Check for leaks +dotnet-gcdump collect --process-id $(pgrep -f DynamicToolFiltering) +``` + +**Solutions:** + +```bash +# Review caching configuration +grep -r "MemoryCache" . --include="*.cs" + +# Check for proper disposal +grep -r "using\|Dispose" . --include="*.cs" + +# Optimize garbage collection +export DOTNET_gcServer=1 +export DOTNET_gcConcurrent=1 +``` + +## Configuration Problems + +### 1. Environment Variables Not Working + +**Diagnosis:** + +```bash +# List all environment variables +printenv | grep -i filtering + +# Check configuration binding +grep -A 20 "Configure" Program.cs + +# Test configuration reading +dotnet run -- --help +``` + +**Solutions:** + +```bash +# Set environment variables properly +export Filtering__Enabled=true +export Filtering__RoleBased__Enabled=true + +# Use appsettings file instead +# Edit appsettings.Development.json +``` + +### 2. Launch Profile Issues + +**Problem:** Launch profile not found or not working + +**Solutions:** + +```bash +# List available profiles +grep -A 5 "profiles" Properties/launchSettings.json + +# Run specific profile +dotnet run --launch-profile DevelopmentMode + +# Use environment variables directly +ASPNETCORE_ENVIRONMENT=Development dotnet run +``` + +## Docker Issues + +### 1. Container Won't Start + +**Diagnosis:** + +```bash +# Check container logs +docker logs dynamic-tool-filtering + +# Inspect container +docker inspect dynamic-tool-filtering + +# Check port mapping +docker port dynamic-tool-filtering +``` + +**Solutions:** + +```bash +# Rebuild container +docker build --no-cache -t dynamic-tool-filtering . + +# Run with different port +docker run -p 9080:8080 dynamic-tool-filtering + +# Check Dockerfile configuration +grep EXPOSE Dockerfile +``` + +### 2. Health Check Failing + +**Diagnosis:** + +```bash +# Check health status +docker ps --format "table {{.Names}}\t{{.Status}}" + +# Test health endpoint manually +docker exec dynamic-tool-filtering curl -f http://localhost:8080/health +``` + +**Solutions:** + +```bash +# Increase health check timeout +# Edit docker-compose.yml: +# healthcheck: +# timeout: 30s +# start_period: 60s + +# Disable health check temporarily +docker run --no-healthcheck dynamic-tool-filtering +``` + +## Development Environment Issues + +### 1. VS Code Debugging Not Working + +**Solutions:** + +```bash +# Verify C# extension installed +code --list-extensions | grep ms-dotnettools.csharp + +# Install omnisharp +dotnet tool install -g omnisharp + +# Clear omnisharp cache +rm -rf ~/.omnisharp + +# Reload window in VS Code +# Ctrl+Shift+P -> "Developer: Reload Window" +``` + +### 2. IntelliSense Not Working + +**Solutions:** + +```bash +# Restore packages +dotnet restore + +# Clean and rebuild +dotnet clean && dotnet build + +# Check omnisharp logs in VS Code +# View -> Output -> OmniSharp Log +``` + +## Logging and Debugging + +### 1. Enable Debug Logging + +**appsettings.Development.json:** + +```json +{ + "Logging": { + "LogLevel": { + "Default": "Debug", + "DynamicToolFiltering": "Trace", + "Microsoft.AspNetCore": "Warning" + } + } +} +``` + +### 2. Filter-Specific Logging + +```json +{ + "Logging": { + "LogLevel": { + "DynamicToolFiltering.Authorization.Filters": "Debug" + } + } +} +``` + +### 3. Structured Logging Queries + +```bash +# Find authorization failures +grep "Access denied" logs/dynamic-tool-filtering-*.log + +# Find rate limiting events +grep "Rate limit" logs/dynamic-tool-filtering-*.log + +# Find slow requests +grep -E "took [0-9]{3,}" logs/dynamic-tool-filtering-*.log +``` + +### 4. Real-time Log Monitoring + +```bash +# Monitor all logs +tail -f logs/dynamic-tool-filtering-*.log + +# Monitor specific events +tail -f logs/dynamic-tool-filtering-*.log | grep -E "(ERROR|WARNING|Rate|Auth)" + +# Monitor with highlighting +tail -f logs/dynamic-tool-filtering-*.log | grep --color=always -E "(ERROR|WARNING|$)" +``` + +## Advanced Debugging + +### 1. Enable Request/Response Logging + +```csharp +// Add to Program.cs +app.Use(async (context, next) => +{ + var logger = context.RequestServices.GetRequiredService>(); + logger.LogInformation("Request: {Method} {Path}", context.Request.Method, context.Request.Path); + await next(); + logger.LogInformation("Response: {StatusCode}", context.Response.StatusCode); +}); +``` + +### 2. Filter Execution Tracing + +```csharp +// Add to any filter +public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken) +{ + var stopwatch = Stopwatch.StartNew(); + try + { + var result = await DoFilterLogic(tool, context, cancellationToken); + _logger.LogDebug("Filter {FilterName} for tool {ToolName}: {Result} (took {ElapsedMs}ms)", + GetType().Name, tool.Name, result, stopwatch.ElapsedMilliseconds); + return result; + } + catch (Exception ex) + { + _logger.LogError(ex, "Filter {FilterName} failed for tool {ToolName}", GetType().Name, tool.Name); + throw; + } +} +``` + +### 3. Memory Dump Analysis + +```bash +# Create memory dump +dotnet-dump collect --process-id $(pgrep -f DynamicToolFiltering) + +# Analyze with dotnet-dump +dotnet-dump analyze core_20240101_120000 + +# Common commands in analyzer: +# dumpheap -stat +# gcroot
+# dumpheap -mt +``` + +## Getting Help + +### 1. Collect Diagnostic Information + +Create a diagnostic script: + +```bash +#!/bin/bash +# scripts/collect-diagnostics.sh + +echo "=== System Information ===" +uname -a +dotnet --info + +echo "=== Process Information ===" +ps aux | grep -i dotnet + +echo "=== Port Usage ===" +netstat -tlnp | grep 8080 + +echo "=== Configuration ===" +cat appsettings*.json + +echo "=== Recent Logs ===" +tail -50 logs/dynamic-tool-filtering-*.log + +echo "=== Environment Variables ===" +printenv | grep -E "(DOTNET|ASPNETCORE|Filtering)" + +echo "=== Docker Status ===" +docker ps | grep dynamic-tool-filtering || echo "Not running in Docker" +``` + +### 2. Create Issue Report Template + +```markdown +## Issue Description +Brief description of the problem + +## Environment +- OS: +- .NET Version: +- Docker: Yes/No +- Launch Profile: + +## Steps to Reproduce +1. +2. +3. + +## Expected Behavior +What should happen + +## Actual Behavior +What actually happens + +## Logs +``` +[Include relevant log entries] +``` + +## Configuration +```json +[Include relevant configuration] +``` + +## Additional Context +Any other relevant information +``` + +### 3. Support Channels + +- Check existing issues in the repository +- Review documentation and samples +- Use the diagnostic scripts provided +- Include full error messages and logs +- Provide minimal reproduction steps + +## Preventive Measures + +### 1. Regular Health Checks + +```bash +# Add to crontab for production monitoring +*/5 * * * * curl -f http://localhost:8080/health || echo "Health check failed" | mail -s "MCP Server Alert" admin@company.com +``` + +### 2. Log Rotation + +```bash +# Configure logrotate +echo '/app/logs/*.log { + daily + rotate 30 + compress + missingok + notifempty + copytruncate +}' > /etc/logrotate.d/mcp-server +``` + +### 3. Resource Monitoring + +```bash +# Monitor resource usage +watch -n 5 'ps aux | grep DynamicToolFiltering | grep -v grep' +``` + +This troubleshooting guide covers the most common issues you'll encounter with the Dynamic Tool Filtering MCP server. Keep it handy for quick reference during development and deployment. \ No newline at end of file diff --git a/samples/DynamicToolFiltering/monitoring/alert_rules.yml b/samples/DynamicToolFiltering/monitoring/alert_rules.yml new file mode 100644 index 00000000..abe3c12d --- /dev/null +++ b/samples/DynamicToolFiltering/monitoring/alert_rules.yml @@ -0,0 +1,172 @@ +# Prometheus alerting rules for Dynamic Tool Filtering MCP Server + +groups: + - name: mcp-server-performance + rules: + - alert: HighResponseTime + expr: histogram_quantile(0.95, rate(http_request_duration_seconds_bucket[5m])) > 0.5 + for: 2m + labels: + severity: warning + service: dynamic-tool-filtering + annotations: + summary: "High response time detected" + description: "95th percentile response time is {{ $value }}s, which is above the 500ms threshold" + + - alert: VeryHighResponseTime + expr: histogram_quantile(0.95, rate(http_request_duration_seconds_bucket[5m])) > 1.0 + for: 1m + labels: + severity: critical + service: dynamic-tool-filtering + annotations: + summary: "Very high response time detected" + description: "95th percentile response time is {{ $value }}s, which is critically high" + + - alert: HighErrorRate + expr: rate(http_requests_total{status=~"5.."}[5m]) > 0.01 + for: 1m + labels: + severity: warning + service: dynamic-tool-filtering + annotations: + summary: "High error rate detected" + description: "Error rate is {{ $value | humanizePercentage }}, which is above 1%" + + - alert: CriticalErrorRate + expr: rate(http_requests_total{status=~"5.."}[5m]) > 0.05 + for: 30s + labels: + severity: critical + service: dynamic-tool-filtering + annotations: + summary: "Critical error rate detected" + description: "Error rate is {{ $value | humanizePercentage }}, which is critically high" + + - name: mcp-server-resources + rules: + - alert: HighMemoryUsage + expr: process_resident_memory_bytes / 1024 / 1024 > 500 + for: 5m + labels: + severity: warning + service: dynamic-tool-filtering + annotations: + summary: "High memory usage" + description: "Memory usage is {{ $value }}MB, which is above 500MB threshold" + + - alert: CriticalMemoryUsage + expr: process_resident_memory_bytes / 1024 / 1024 > 1000 + for: 2m + labels: + severity: critical + service: dynamic-tool-filtering + annotations: + summary: "Critical memory usage" + description: "Memory usage is {{ $value }}MB, which is critically high" + + - alert: HighCPUUsage + expr: rate(process_cpu_seconds_total[5m]) * 100 > 80 + for: 3m + labels: + severity: warning + service: dynamic-tool-filtering + annotations: + summary: "High CPU usage" + description: "CPU usage is {{ $value }}%, which is above 80%" + + - name: mcp-server-business + rules: + - alert: HighAuthorizationFailureRate + expr: rate(authorization_failures_total[5m]) > 0.1 + for: 2m + labels: + severity: warning + service: dynamic-tool-filtering + annotations: + summary: "High authorization failure rate" + description: "Authorization failure rate is {{ $value }}/sec, indicating potential security issues" + + - alert: RateLimitingTriggered + expr: increase(rate_limit_violations_total[5m]) > 10 + for: 1m + labels: + severity: info + service: dynamic-tool-filtering + annotations: + summary: "Rate limiting frequently triggered" + description: "Rate limiting has been triggered {{ $value }} times in the last 5 minutes" + + - alert: FilterPerformanceDegraded + expr: histogram_quantile(0.95, rate(filter_duration_seconds_bucket[5m])) > 0.1 + for: 3m + labels: + severity: warning + service: dynamic-tool-filtering + annotations: + summary: "Filter performance degraded" + description: "95th percentile filter execution time is {{ $value }}s, which may impact response times" + + - name: mcp-server-availability + rules: + - alert: ServiceDown + expr: up{job="dynamic-tool-filtering"} == 0 + for: 30s + labels: + severity: critical + service: dynamic-tool-filtering + annotations: + summary: "MCP server is down" + description: "The Dynamic Tool Filtering MCP server is not responding" + + - alert: HealthCheckFailing + expr: probe_success{job="dynamic-tool-filtering"} == 0 + for: 1m + labels: + severity: critical + service: dynamic-tool-filtering + annotations: + summary: "Health check failing" + description: "The health check endpoint is not responding successfully" + + - alert: LowRequestVolume + expr: rate(http_requests_total[5m]) < 0.1 + for: 10m + labels: + severity: info + service: dynamic-tool-filtering + annotations: + summary: "Low request volume" + description: "Request rate is {{ $value }}/sec, which is unusually low" + + - name: external-dependencies + rules: + - alert: RedisConnectionFailed + expr: redis_connected_clients == 0 + for: 1m + labels: + severity: warning + service: redis + annotations: + summary: "Redis connection failed" + description: "No clients connected to Redis, rate limiting may be impacted" + + - alert: DatabaseConnectionIssues + expr: pg_up == 0 + for: 1m + labels: + severity: critical + service: postgres + annotations: + summary: "Database connection failed" + description: "PostgreSQL database is not responding, quota management may be impacted" + + - alert: HighDatabaseConnections + expr: pg_stat_activity_count > 80 + for: 5m + labels: + severity: warning + service: postgres + annotations: + summary: "High database connection count" + description: "Database has {{ $value }} active connections, approaching limit" \ No newline at end of file diff --git a/samples/DynamicToolFiltering/monitoring/prometheus.yml b/samples/DynamicToolFiltering/monitoring/prometheus.yml new file mode 100644 index 00000000..c7d3f745 --- /dev/null +++ b/samples/DynamicToolFiltering/monitoring/prometheus.yml @@ -0,0 +1,62 @@ +# Prometheus configuration for Dynamic Tool Filtering MCP Server + +global: + scrape_interval: 15s + evaluation_interval: 15s + external_labels: + monitor: 'dynamic-tool-filtering' + +rule_files: + - "alert_rules.yml" + +# Scrape configurations for monitoring targets +scrape_configs: + # Main MCP server metrics + - job_name: 'dynamic-tool-filtering' + static_configs: + - targets: ['dynamic-tool-filtering:8080'] + metrics_path: '/metrics' + scrape_interval: 10s + scrape_timeout: 5s + honor_labels: true + params: + format: ['prometheus'] + relabel_configs: + - source_labels: [__address__] + target_label: __param_target + - source_labels: [__param_target] + target_label: instance + - target_label: __address__ + replacement: dynamic-tool-filtering:8080 + + # Redis metrics (if using Redis for rate limiting) + - job_name: 'redis' + static_configs: + - targets: ['redis:6379'] + scrape_interval: 30s + + # PostgreSQL metrics (if using PostgreSQL for quota management) + - job_name: 'postgres' + static_configs: + - targets: ['postgres_exporter:9187'] + scrape_interval: 30s + + # Node exporter for system metrics + - job_name: 'node-exporter' + static_configs: + - targets: ['node-exporter:9100'] + scrape_interval: 30s + +# Alerting configuration +alerting: + alertmanagers: + - static_configs: + - targets: + - alertmanager:9093 + +# Remote write configuration (for external monitoring services) +# remote_write: +# - url: "https://your-monitoring-service/api/v1/write" +# basic_auth: +# username: "your-username" +# password: "your-password" \ No newline at end of file diff --git a/samples/DynamicToolFiltering/scripts/setup-dev.sh b/samples/DynamicToolFiltering/scripts/setup-dev.sh new file mode 100755 index 00000000..af997da9 --- /dev/null +++ b/samples/DynamicToolFiltering/scripts/setup-dev.sh @@ -0,0 +1,925 @@ +#!/bin/bash + +# Dynamic Tool Filtering - Development Environment Setup Script +# This script sets up a complete development environment with VS Code configuration, +# Docker setup, and necessary tools for MCP development. + +set -e + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Helper functions +log_info() { + echo -e "${BLUE}[INFO]${NC} $1" +} + +log_success() { + echo -e "${GREEN}[SUCCESS]${NC} $1" +} + +log_error() { + echo -e "${RED}[ERROR]${NC} $1" +} + +log_warning() { + echo -e "${YELLOW}[WARNING]${NC} $1" +} + +check_dependency() { + local cmd="$1" + local name="$2" + local install_hint="$3" + + if command -v "$cmd" &> /dev/null; then + log_success "$name is installed" + return 0 + else + log_warning "$name is not installed. $install_hint" + return 1 + fi +} + +create_vscode_config() { + log_info "Creating VS Code configuration..." + + mkdir -p .vscode + + # Create launch.json for debugging + cat > .vscode/launch.json << 'EOF' +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Launch (Development)", + "type": "coreclr", + "request": "launch", + "preLaunchTask": "build", + "program": "${workspaceFolder}/bin/Debug/net9.0/DynamicToolFiltering.dll", + "args": [], + "cwd": "${workspaceFolder}", + "console": "integratedTerminal", + "stopAtEntry": false, + "env": { + "ASPNETCORE_ENVIRONMENT": "Development", + "ASPNETCORE_URLS": "http://localhost:8080", + "Filtering__Enabled": "true", + "Filtering__RoleBased__Enabled": "true", + "Filtering__TimeBased__Enabled": "false", + "Filtering__ScopeBased__Enabled": "true", + "Filtering__RateLimiting__Enabled": "true", + "Filtering__TenantIsolation__Enabled": "false", + "Filtering__BusinessLogic__Enabled": "true" + }, + "serverReadyAction": { + "action": "openExternally", + "pattern": "\\bNow listening on:\\s+(https?://\\S+)" + } + }, + { + "name": "Launch (No Filtering)", + "type": "coreclr", + "request": "launch", + "preLaunchTask": "build", + "program": "${workspaceFolder}/bin/Debug/net9.0/DynamicToolFiltering.dll", + "args": [], + "cwd": "${workspaceFolder}", + "console": "integratedTerminal", + "stopAtEntry": false, + "env": { + "ASPNETCORE_ENVIRONMENT": "Development", + "ASPNETCORE_URLS": "http://localhost:8080", + "Filtering__Enabled": "false" + } + }, + { + "name": "Launch (Production Mode)", + "type": "coreclr", + "request": "launch", + "preLaunchTask": "build", + "program": "${workspaceFolder}/bin/Debug/net9.0/DynamicToolFiltering.dll", + "args": [], + "cwd": "${workspaceFolder}", + "console": "integratedTerminal", + "stopAtEntry": false, + "env": { + "ASPNETCORE_ENVIRONMENT": "Production", + "ASPNETCORE_URLS": "http://localhost:8080", + "Filtering__Enabled": "true", + "Filtering__RoleBased__Enabled": "true", + "Filtering__TimeBased__Enabled": "true", + "Filtering__ScopeBased__Enabled": "true", + "Filtering__RateLimiting__Enabled": "true", + "Filtering__TenantIsolation__Enabled": "true", + "Filtering__BusinessLogic__Enabled": "true" + } + }, + { + "name": "Attach to Process", + "type": "coreclr", + "request": "attach", + "processId": "${command:pickProcess}" + } + ] +} +EOF + + # Create tasks.json for build tasks + cat > .vscode/tasks.json << 'EOF' +{ + "version": "2.0.0", + "tasks": [ + { + "label": "build", + "command": "dotnet", + "type": "process", + "args": [ + "build", + "${workspaceFolder}/DynamicToolFiltering.csproj", + "/property:GenerateFullPaths=true", + "/consoleloggerparameters:NoSummary" + ], + "problemMatcher": "$msCompile", + "group": { + "kind": "build", + "isDefault": true + } + }, + { + "label": "publish", + "command": "dotnet", + "type": "process", + "args": [ + "publish", + "${workspaceFolder}/DynamicToolFiltering.csproj", + "/property:GenerateFullPaths=true", + "/consoleloggerparameters:NoSummary" + ], + "problemMatcher": "$msCompile" + }, + { + "label": "watch", + "command": "dotnet", + "type": "process", + "args": [ + "watch", + "run", + "${workspaceFolder}/DynamicToolFiltering.csproj" + ], + "problemMatcher": "$msCompile" + }, + { + "label": "test", + "command": "./scripts/test-all.sh", + "type": "shell", + "group": "test", + "presentation": { + "echo": true, + "reveal": "always", + "focus": false, + "panel": "shared", + "showReuseMessage": true, + "clear": false + }, + "problemMatcher": [] + }, + { + "label": "test-authentication", + "command": "./scripts/test-all.sh", + "type": "shell", + "args": ["authentication"], + "group": "test", + "presentation": { + "echo": true, + "reveal": "always", + "focus": false, + "panel": "shared" + } + }, + { + "label": "test-authorization", + "command": "./scripts/test-all.sh", + "type": "shell", + "args": ["authorization"], + "group": "test", + "presentation": { + "echo": true, + "reveal": "always", + "focus": false, + "panel": "shared" + } + }, + { + "label": "docker-build", + "command": "docker", + "type": "shell", + "args": [ + "build", + "-t", + "dynamic-tool-filtering", + "." + ], + "group": "build", + "presentation": { + "echo": true, + "reveal": "always", + "focus": false, + "panel": "shared" + } + }, + { + "label": "docker-run", + "command": "docker", + "type": "shell", + "args": [ + "run", + "-p", + "8080:8080", + "--rm", + "dynamic-tool-filtering" + ], + "group": "build", + "presentation": { + "echo": true, + "reveal": "always", + "focus": false, + "panel": "shared" + }, + "dependsOn": "docker-build" + } + ] +} +EOF + + # Create settings.json for VS Code settings + cat > .vscode/settings.json << 'EOF' +{ + "dotnet.defaultSolution": "DynamicToolFiltering.csproj", + "omnisharp.enableEditorConfigSupport": true, + "omnisharp.enableImportCompletion": true, + "omnisharp.enableRoslynAnalyzers": true, + "files.exclude": { + "**/bin": true, + "**/obj": true, + "**/.vs": true + }, + "files.watcherExclude": { + "**/bin/**": true, + "**/obj/**": true, + "**/.vs/**": true + }, + "csharp.semanticHighlighting.enabled": true, + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll": "explicit" + }, + "json.schemas": [ + { + "fileMatch": ["appsettings*.json"], + "schema": { + "type": "object", + "properties": { + "Filtering": { + "type": "object", + "description": "Tool filtering configuration", + "properties": { + "Enabled": {"type": "boolean"}, + "RoleBased": {"type": "object"}, + "TimeBased": {"type": "object"}, + "ScopeBased": {"type": "object"}, + "RateLimiting": {"type": "object"}, + "TenantIsolation": {"type": "object"}, + "BusinessLogic": {"type": "object"} + } + } + } + } + } + ], + "rest-client.environmentVariables": { + "local": { + "baseUrl": "http://localhost:8080", + "guestKey": "demo-guest-key", + "userKey": "demo-user-key", + "premiumKey": "demo-premium-key", + "adminKey": "demo-admin-key" + } + } +} +EOF + + # Create extensions.json for recommended VS Code extensions + cat > .vscode/extensions.json << 'EOF' +{ + "recommendations": [ + "ms-dotnettools.csharp", + "ms-dotnettools.vscode-dotnet-runtime", + "humao.rest-client", + "ms-vscode.vscode-json", + "redhat.vscode-yaml", + "ms-azuretools.vscode-docker", + "github.copilot", + "streetsidesoftware.code-spell-checker", + "esbenp.prettier-vscode", + "bradlc.vscode-tailwindcss" + ] +} +EOF + + log_success "VS Code configuration created" +} + +create_docker_config() { + log_info "Creating Docker configuration..." + + # Create Dockerfile + cat > Dockerfile << 'EOF' +# Build stage +FROM mcr.microsoft.com/dotnet/sdk:9.0 AS build +WORKDIR /src + +# Copy project file and restore dependencies +COPY DynamicToolFiltering.csproj . +RUN dotnet restore + +# Copy source code and build +COPY . . +RUN dotnet build -c Release -o /app/build + +# Publish stage +FROM build AS publish +RUN dotnet publish -c Release -o /app/publish --no-restore + +# Runtime stage +FROM mcr.microsoft.com/dotnet/aspnet:9.0 AS runtime +WORKDIR /app + +# Install curl for health checks +RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/* + +# Copy published application +COPY --from=publish /app/publish . + +# Create logs directory +RUN mkdir -p logs + +# Expose port +EXPOSE 8080 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8080/health || exit 1 + +# Start application +ENTRYPOINT ["dotnet", "DynamicToolFiltering.dll"] +EOF + + # Create docker-compose.yml + cat > docker-compose.yml << 'EOF' +version: '3.8' + +services: + dynamic-tool-filtering: + build: . + ports: + - "8080:8080" + environment: + - ASPNETCORE_ENVIRONMENT=Development + - ASPNETCORE_URLS=http://+:8080 + - Filtering__Enabled=true + - Filtering__RoleBased__Enabled=true + - Filtering__TimeBased__Enabled=false + - Filtering__ScopeBased__Enabled=true + - Filtering__RateLimiting__Enabled=true + - Filtering__TenantIsolation__Enabled=false + - Filtering__BusinessLogic__Enabled=true + volumes: + - ./logs:/app/logs + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8080/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 10s + + # Optional: Redis for production-ready rate limiting + redis: + image: redis:7-alpine + ports: + - "6379:6379" + restart: unless-stopped + profiles: + - production + + # Optional: PostgreSQL for production-ready quota management + postgres: + image: postgres:15-alpine + ports: + - "5432:5432" + environment: + - POSTGRES_DB=dynamic_tool_filtering + - POSTGRES_USER=mcpuser + - POSTGRES_PASSWORD=mcppassword + volumes: + - postgres_data:/var/lib/postgresql/data + restart: unless-stopped + profiles: + - production + +volumes: + postgres_data: +EOF + + # Create .dockerignore + cat > .dockerignore << 'EOF' +# Build artifacts +bin/ +obj/ +*.dll +*.exe +*.pdb + +# Development files +.vs/ +.vscode/ +*.user +*.suo + +# Logs +logs/ +*.log + +# OS files +.DS_Store +Thumbs.db + +# Git +.git/ +.gitignore + +# Documentation (include in image if needed) +# *.md + +# Test results +TestResults/ +coverage/ + +# Node modules (if any) +node_modules/ + +# Temporary files +*.tmp +*.temp +EOF + + log_success "Docker configuration created" +} + +create_rest_client_config() { + log_info "Creating REST Client configuration for testing..." + + mkdir -p tests + + cat > tests/api-tests.http << 'EOF' +### Dynamic Tool Filtering API Tests +### Use with VS Code REST Client extension + +@baseUrl = http://localhost:8080 +@guestKey = demo-guest-key +@userKey = demo-user-key +@premiumKey = demo-premium-key +@adminKey = demo-admin-key + +### Health Check +GET {{baseUrl}}/health + +### List Tools - Guest User +GET {{baseUrl}}/mcp/v1/tools +X-API-Key: {{guestKey}} + +### List Tools - User +GET {{baseUrl}}/mcp/v1/tools +X-API-Key: {{userKey}} + +### List Tools - Premium User +GET {{baseUrl}}/mcp/v1/tools +X-API-Key: {{premiumKey}} + +### List Tools - Admin User +GET {{baseUrl}}/mcp/v1/tools +X-API-Key: {{adminKey}} + +### Execute Public Tool (Echo) +POST {{baseUrl}}/mcp/v1/tools/call +Content-Type: application/json + +{ + "name": "echo", + "arguments": { + "message": "Hello from REST Client!" + } +} + +### Execute User Tool - Get Profile +POST {{baseUrl}}/mcp/v1/tools/call +Content-Type: application/json +X-API-Key: {{userKey}} + +{ + "name": "get_user_profile", + "arguments": {} +} + +### Execute Premium Tool - Generate Secure Random +POST {{baseUrl}}/mcp/v1/tools/call +Content-Type: application/json +X-API-Key: {{premiumKey}} + +{ + "name": "premium_generate_secure_random", + "arguments": { + "byteCount": 32, + "format": "hex" + } +} + +### Execute Admin Tool - System Diagnostics +POST {{baseUrl}}/mcp/v1/tools/call +Content-Type: application/json +X-API-Key: {{adminKey}} + +{ + "name": "admin_get_system_diagnostics", + "arguments": {} +} + +### Test Authorization Failure - User trying Admin Tool +POST {{baseUrl}}/mcp/v1/tools/call +Content-Type: application/json +X-API-Key: {{userKey}} + +{ + "name": "admin_get_system_diagnostics", + "arguments": {} +} + +### Test Invalid API Key +GET {{baseUrl}}/mcp/v1/tools +X-API-Key: invalid-key-123 + +### Get Feature Flags (Admin only) +GET {{baseUrl}}/admin/feature-flags +X-API-Key: {{adminKey}} + +### Set Feature Flag (Admin only) +POST {{baseUrl}}/admin/feature-flags/premium_features?enabled=true +X-API-Key: {{adminKey}} + +### Test Rate Limiting (rapid requests) +POST {{baseUrl}}/mcp/v1/tools/call +Content-Type: application/json +X-API-Key: {{guestKey}} + +{ + "name": "echo", + "arguments": { + "message": "Rate limit test 1" + } +} + +### (Copy the above request multiple times to test rate limiting) +EOF + + log_success "REST Client configuration created" +} + +create_git_config() { + log_info "Creating Git configuration..." + + # Create .gitignore if it doesn't exist + if [ ! -f .gitignore ]; then + cat > .gitignore << 'EOF' +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +build/ +bld/ +[Bb]in/ +[Oo]bj/ +[Oo]ut/ +msbuild.log +msbuild.err +msbuild.wrn + +# Visual Studio +.vs/ +*.user +*.suo +*.userosscache +*.sln.docstates +*.vspx +*.sap + +# Logs +logs/ +*.log + +# Runtime data +pids +*.pid +*.seed +*.pid.lock + +# Coverage directory used by tools like istanbul +coverage + +# nyc test coverage +.nyc_output + +# Dependency directories +node_modules/ + +# Optional npm cache directory +.npm + +# Optional eslint cache +.eslintcache + +# Microbundle cache +.rpt2_cache/ +.rts2_cache_cjs/ +.rts2_cache_es/ +.rts2_cache_umd/ + +# Optional REPL history +.node_repl_history + +# Output of 'npm pack' +*.tgz + +# Yarn Integrity file +.yarn-integrity + +# dotenv environment variables file +.env +.env.test +.env.local + +# parcel-bundler cache (https://parceljs.org/) +.cache +.parcel-cache + +# next.js build output +.next + +# nuxt.js build output +.nuxt + +# gatsby files +.cache/ +public + +# vuepress build output +.vuepress/dist + +# Serverless directories +.serverless/ + +# FuseBox cache +.fusebox/ + +# DynamoDB Local files +.dynamodb/ + +# TernJS port file +.tern-port + +# IDE files +.idea/ +*.swp +*.swo +*~ + +# OS generated files +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db +EOF + log_success "Created .gitignore file" + else + log_info ".gitignore already exists" + fi +} + +install_dev_tools() { + log_info "Installing development tools..." + + # Install .NET tools if not already installed + if check_dependency "dotnet" ".NET SDK"; then + log_info "Installing useful .NET tools..." + + # Install dotnet-format for code formatting + dotnet tool install -g dotnet-format 2>/dev/null || log_info "dotnet-format already installed" + + # Install dotnet-outdated for checking outdated packages + dotnet tool install -g dotnet-outdated-tool 2>/dev/null || log_info "dotnet-outdated already installed" + + # Install dotnet-ef for Entity Framework migrations (if needed) + dotnet tool install -g dotnet-ef 2>/dev/null || log_info "dotnet-ef already installed" + + log_success "Development tools installed" + fi +} + +create_documentation() { + log_info "Creating additional documentation..." + + mkdir -p docs + + # Create DEVELOPMENT.md + cat > docs/DEVELOPMENT.md << 'EOF' +# Development Guide + +## Quick Start + +1. **Setup Development Environment** + ```bash + ./scripts/setup-dev.sh + ``` + +2. **Open in VS Code** + ```bash + code . + ``` + +3. **Start Debugging** + - Press F5 or use "Launch (Development)" configuration + - The server will start with development settings + +## Available Configurations + +### VS Code Launch Profiles +- **Launch (Development)**: Standard development mode with basic filtering +- **Launch (No Filtering)**: All filtering disabled for testing +- **Launch (Production Mode)**: All filters enabled with strict settings +- **Attach to Process**: Attach debugger to running process + +### VS Code Tasks +- **build**: Build the project +- **test**: Run all tests +- **test-authentication**: Run authentication tests only +- **test-authorization**: Run authorization tests only +- **docker-build**: Build Docker image +- **docker-run**: Run in Docker container + +## Testing + +### Manual Testing with REST Client +Use the `tests/api-tests.http` file with VS Code REST Client extension for interactive testing. + +### Automated Testing +```bash +# Run all tests +./scripts/test-all.sh + +# Run specific test category +./scripts/test-all.sh authentication +./scripts/test-all.sh authorization +./scripts/test-all.sh rate-limiting +``` + +### PowerShell Testing +```powershell +# Run all tests +.\scripts\test-all.ps1 + +# Run specific category +.\scripts\test-all.ps1 -Category authentication +``` + +## Docker Development + +### Build and Run +```bash +# Build image +docker build -t dynamic-tool-filtering . + +# Run container +docker run -p 8080:8080 dynamic-tool-filtering + +# Use docker-compose +docker-compose up --build +``` + +### Production-like Environment +```bash +# Start with Redis and PostgreSQL +docker-compose --profile production up +``` + +## Debugging Tips + +1. **Filter Execution**: Set breakpoints in filter classes to see execution flow +2. **Authentication**: Check claims in the `ToolAuthorizationContext` +3. **Rate Limiting**: Monitor the `IRateLimitingService` implementation +4. **Feature Flags**: Debug the `IFeatureFlagService` to see flag evaluations + +## Code Formatting + +```bash +# Format code +dotnet format + +# Check for outdated packages +dotnet outdated +``` + +## Adding New Filters + +1. Create a new class implementing `IToolFilter` +2. Set appropriate priority value +3. Register in `Program.cs` +4. Add configuration options to `FilteringOptions.cs` +5. Write tests for the new filter + +## Performance Profiling + +Use dotnet-trace for performance analysis: +```bash +dotnet-trace collect --process-id --providers Microsoft-AspNetCore-Server-Kestrel +``` +EOF + + log_success "Development documentation created" +} + +main() { + echo "========================================" + echo "Dynamic Tool Filtering - Dev Setup" + echo "========================================" + echo "Setting up development environment..." + echo "" + + # Check if we're in the right directory + if [ ! -f "Program.cs" ]; then + log_error "This script must be run from the DynamicToolFiltering project directory" + log_info "Expected to find Program.cs in current directory" + exit 1 + fi + + # Check dependencies + log_info "Checking dependencies..." + check_dependency "dotnet" ".NET SDK" "Install from https://dotnet.microsoft.com/" + check_dependency "git" "Git" "Install from https://git-scm.com/" + check_dependency "curl" "curl" "Install with package manager" + check_dependency "docker" "Docker" "Install from https://www.docker.com/" || log_warning "Docker is optional but recommended" + check_dependency "code" "VS Code" "Install from https://code.visualstudio.com/" || log_warning "VS Code is optional but recommended" + + echo "" + + # Create configurations + create_vscode_config + create_docker_config + create_rest_client_config + create_git_config + install_dev_tools + create_documentation + + # Make scripts executable + chmod +x scripts/*.sh 2>/dev/null || true + + echo "" + echo "========================================" + echo "Setup Complete!" + echo "========================================" + echo "" + echo "Next steps:" + echo "1. Open VS Code: code ." + echo "2. Install recommended extensions (VS Code will prompt)" + echo "3. Press F5 to start debugging" + echo "4. Run tests: ./scripts/test-all.sh" + echo "5. Try API testing with tests/api-tests.http" + echo "" + echo "Documentation:" + echo "- Development Guide: docs/DEVELOPMENT.md" + echo "- VS Code Tasks: Ctrl+Shift+P -> 'Tasks: Run Task'" + echo "- REST Client: Open tests/api-tests.http in VS Code" + echo "" + log_success "Development environment ready!" +} + +# Execute main function +main "$@" \ No newline at end of file diff --git a/samples/DynamicToolFiltering/scripts/test-all.ps1 b/samples/DynamicToolFiltering/scripts/test-all.ps1 new file mode 100644 index 00000000..08cf52b7 --- /dev/null +++ b/samples/DynamicToolFiltering/scripts/test-all.ps1 @@ -0,0 +1,483 @@ +#!/usr/bin/env pwsh + +<# +.SYNOPSIS + Dynamic Tool Filtering - Comprehensive Test Suite (PowerShell) + +.DESCRIPTION + This script tests all aspects of the MCP server including authentication, + authorization, rate limiting, feature flags, and error handling. + +.PARAMETER Category + Test category to run. Options: health, authentication, authorization, visibility, + rate-limiting, feature-flags, error-handling, performance, all + +.PARAMETER BaseUrl + Base URL of the server to test (default: http://localhost:8080) + +.EXAMPLE + .\test-all.ps1 + Runs all test categories + +.EXAMPLE + .\test-all.ps1 -Category authentication + Runs only authentication tests + +.EXAMPLE + .\test-all.ps1 -BaseUrl "http://localhost:9000" + Tests a server on a different port +#> + +param( + [Parameter(Position=0)] + [ValidateSet("health", "authentication", "authorization", "visibility", "rate-limiting", "feature-flags", "error-handling", "performance", "all")] + [string]$Category = "all", + + [Parameter()] + [string]$BaseUrl = "http://localhost:8080" +) + +# Configuration +$script:BaseUrl = $BaseUrl +$script:GuestKey = "demo-guest-key" +$script:UserKey = "demo-user-key" +$script:PremiumKey = "demo-premium-key" +$script:AdminKey = "demo-admin-key" +$script:InvalidKey = "invalid-test-key" + +# Test counters +$script:TotalTests = 0 +$script:PassedTests = 0 +$script:FailedTests = 0 + +# Helper functions +function Write-InfoLog { + param([string]$Message) + Write-Host "[INFO] $Message" -ForegroundColor Blue +} + +function Write-SuccessLog { + param([string]$Message) + Write-Host "[PASS] $Message" -ForegroundColor Green + $script:PassedTests++ +} + +function Write-ErrorLog { + param([string]$Message) + Write-Host "[FAIL] $Message" -ForegroundColor Red + $script:FailedTests++ +} + +function Write-WarningLog { + param([string]$Message) + Write-Host "[WARN] $Message" -ForegroundColor Yellow +} + +function Invoke-TestRequest { + param( + [string]$TestName, + [int]$ExpectedStatus, + [hashtable]$Headers = @{}, + [string]$Method = "GET", + [string]$Uri, + [string]$Body = $null, + [string]$ContentType = "application/json" + ) + + $script:TotalTests++ + Write-InfoLog "Running test: $TestName" + + try { + $requestParams = @{ + Uri = $Uri + Method = $Method + Headers = $Headers + UseBasicParsing = $true + ErrorAction = 'Stop' + } + + if ($Body) { + $requestParams.Body = $Body + $requestParams.Headers['Content-Type'] = $ContentType + } + + $response = Invoke-WebRequest @requestParams + $actualStatus = $response.StatusCode + + if ($actualStatus -eq $ExpectedStatus) { + Write-SuccessLog "$TestName (HTTP $actualStatus)" + if ($response.Content) { + $content = $response.Content | ConvertFrom-Json -ErrorAction SilentlyContinue + if ($content) { + Write-Host " Response: $($content | ConvertTo-Json -Compress | Select-Object -First 1)" -ForegroundColor DarkGray + } else { + Write-Host " Response: $($response.Content.Substring(0, [Math]::Min(100, $response.Content.Length)))" -ForegroundColor DarkGray + } + } + } else { + Write-ErrorLog "$TestName (Expected HTTP $ExpectedStatus, got HTTP $actualStatus)" + if ($response.Content) { + Write-Host " Response: $($response.Content.Substring(0, [Math]::Min(200, $response.Content.Length)))" -ForegroundColor DarkGray + } + } + } + catch { + $actualStatus = 0 + if ($_.Exception.Response) { + $actualStatus = [int]$_.Exception.Response.StatusCode + } + + if ($actualStatus -eq $ExpectedStatus) { + Write-SuccessLog "$TestName (HTTP $actualStatus)" + if ($_.Exception.Response) { + $reader = New-Object System.IO.StreamReader($_.Exception.Response.GetResponseStream()) + $responseBody = $reader.ReadToEnd() + Write-Host " Response: $($responseBody.Substring(0, [Math]::Min(200, $responseBody.Length)))" -ForegroundColor DarkGray + } + } else { + Write-ErrorLog "$TestName (Expected HTTP $ExpectedStatus, got HTTP $actualStatus)" + Write-Host " Error: $($_.Exception.Message)" -ForegroundColor DarkGray + } + } + + Write-Host "" +} + +function Wait-ForServer { + Write-InfoLog "Waiting for server to be ready at $script:BaseUrl..." + $retries = 30 + $count = 0 + + while ($count -lt $retries) { + try { + $response = Invoke-WebRequest -Uri "$script:BaseUrl/health" -UseBasicParsing -TimeoutSec 5 -ErrorAction Stop + if ($response.StatusCode -eq 200) { + Write-SuccessLog "Server is ready!" + return $true + } + } + catch { + # Server not ready yet + } + + $count++ + if ($count -eq $retries) { + Write-ErrorLog "Server did not start within expected time" + return $false + } + + Start-Sleep -Seconds 2 + } + + return $false +} + +function Test-HealthCheck { + Write-InfoLog "=== Health Check Tests ===" + Invoke-TestRequest -TestName "Health endpoint availability" -ExpectedStatus 200 -Uri "$script:BaseUrl/health" +} + +function Test-Authentication { + Write-InfoLog "=== Authentication Tests ===" + + # Test valid API keys + Invoke-TestRequest -TestName "Guest API key authentication" -ExpectedStatus 200 ` + -Headers @{"X-API-Key" = $script:GuestKey} -Uri "$script:BaseUrl/mcp/v1/tools" + + Invoke-TestRequest -TestName "User API key authentication" -ExpectedStatus 200 ` + -Headers @{"X-API-Key" = $script:UserKey} -Uri "$script:BaseUrl/mcp/v1/tools" + + Invoke-TestRequest -TestName "Premium API key authentication" -ExpectedStatus 200 ` + -Headers @{"X-API-Key" = $script:PremiumKey} -Uri "$script:BaseUrl/mcp/v1/tools" + + Invoke-TestRequest -TestName "Admin API key authentication" -ExpectedStatus 200 ` + -Headers @{"X-API-Key" = $script:AdminKey} -Uri "$script:BaseUrl/mcp/v1/tools" + + # Test invalid API key + Invoke-TestRequest -TestName "Invalid API key rejection" -ExpectedStatus 401 ` + -Headers @{"X-API-Key" = $script:InvalidKey} -Uri "$script:BaseUrl/mcp/v1/tools" + + # Test missing API key for protected endpoint + Invoke-TestRequest -TestName "Missing API key for admin endpoint" -ExpectedStatus 401 ` + -Uri "$script:BaseUrl/admin/filters/status" +} + +function Test-Authorization { + Write-InfoLog "=== Authorization Tests ===" + + # Test tool execution with proper authorization + Invoke-TestRequest -TestName "Public tool execution (no auth)" -ExpectedStatus 200 ` + -Method POST -Uri "$script:BaseUrl/mcp/v1/tools/call" ` + -Body '{"name": "echo", "arguments": {"message": "test"}}' + + Invoke-TestRequest -TestName "User tool execution (user key)" -ExpectedStatus 200 ` + -Method POST -Headers @{"X-API-Key" = $script:UserKey} ` + -Uri "$script:BaseUrl/mcp/v1/tools/call" ` + -Body '{"name": "get_user_profile", "arguments": {}}' + + Invoke-TestRequest -TestName "Premium tool execution (premium key)" -ExpectedStatus 200 ` + -Method POST -Headers @{"X-API-Key" = $script:PremiumKey} ` + -Uri "$script:BaseUrl/mcp/v1/tools/call" ` + -Body '{"name": "premium_generate_secure_random", "arguments": {"byteCount": 16}}' + + Invoke-TestRequest -TestName "Admin tool execution (admin key)" -ExpectedStatus 200 ` + -Method POST -Headers @{"X-API-Key" = $script:AdminKey} ` + -Uri "$script:BaseUrl/mcp/v1/tools/call" ` + -Body '{"name": "admin_get_system_diagnostics", "arguments": {}}' + + # Test authorization failures + Invoke-TestRequest -TestName "User trying admin tool (should fail)" -ExpectedStatus 401 ` + -Method POST -Headers @{"X-API-Key" = $script:UserKey} ` + -Uri "$script:BaseUrl/mcp/v1/tools/call" ` + -Body '{"name": "admin_get_system_diagnostics", "arguments": {}}' + + Invoke-TestRequest -TestName "Guest trying premium tool (should fail)" -ExpectedStatus 401 ` + -Method POST -Headers @{"X-API-Key" = $script:GuestKey} ` + -Uri "$script:BaseUrl/mcp/v1/tools/call" ` + -Body '{"name": "premium_generate_secure_random", "arguments": {"byteCount": 16}}' + + Invoke-TestRequest -TestName "Guest trying user tool (should fail)" -ExpectedStatus 401 ` + -Method POST -Headers @{"X-API-Key" = $script:GuestKey} ` + -Uri "$script:BaseUrl/mcp/v1/tools/call" ` + -Body '{"name": "get_user_profile", "arguments": {}}' +} + +function Test-ToolVisibility { + Write-InfoLog "=== Tool Visibility Tests ===" + + try { + # Get tool lists for different roles and count them + $guestResponse = Invoke-WebRequest -Uri "$script:BaseUrl/mcp/v1/tools" -Headers @{"X-API-Key" = $script:GuestKey} -UseBasicParsing + $userResponse = Invoke-WebRequest -Uri "$script:BaseUrl/mcp/v1/tools" -Headers @{"X-API-Key" = $script:UserKey} -UseBasicParsing + $premiumResponse = Invoke-WebRequest -Uri "$script:BaseUrl/mcp/v1/tools" -Headers @{"X-API-Key" = $script:PremiumKey} -UseBasicParsing + $adminResponse = Invoke-WebRequest -Uri "$script:BaseUrl/mcp/v1/tools" -Headers @{"X-API-Key" = $script:AdminKey} -UseBasicParsing + + $guestTools = ($guestResponse.Content | ConvertFrom-Json).result.tools.Count + $userTools = ($userResponse.Content | ConvertFrom-Json).result.tools.Count + $premiumTools = ($premiumResponse.Content | ConvertFrom-Json).result.tools.Count + $adminTools = ($adminResponse.Content | ConvertFrom-Json).result.tools.Count + + $script:TotalTests += 4 + + # Verify hierarchical access (each higher role should see more or equal tools) + if ($guestTools -gt 0) { + Write-SuccessLog "Guest can see tools ($guestTools tools visible)" + } else { + Write-ErrorLog "Guest cannot see any tools" + } + + if ($userTools -ge $guestTools) { + Write-SuccessLog "User sees >= guest tools ($userTools >= $guestTools)" + } else { + Write-ErrorLog "User sees fewer tools than guest ($userTools < $guestTools)" + } + + if ($premiumTools -ge $userTools) { + Write-SuccessLog "Premium sees >= user tools ($premiumTools >= $userTools)" + } else { + Write-ErrorLog "Premium sees fewer tools than user ($premiumTools < $userTools)" + } + + if ($adminTools -ge $premiumTools) { + Write-SuccessLog "Admin sees >= premium tools ($adminTools >= $premiumTools)" + } else { + Write-ErrorLog "Admin sees fewer tools than premium ($adminTools < $premiumTools)" + } + } + catch { + Write-ErrorLog "Failed to test tool visibility: $($_.Exception.Message)" + $script:TotalTests += 4 + $script:FailedTests += 4 + } +} + +function Test-RateLimiting { + Write-InfoLog "=== Rate Limiting Tests ===" + + # Test rapid requests to trigger rate limiting + $successCount = 0 + $rateLimitedCount = 0 + + Write-InfoLog "Making 10 rapid requests with guest key to test rate limiting..." + + for ($i = 1; $i -le 10; $i++) { + try { + $response = Invoke-WebRequest -Uri "$script:BaseUrl/mcp/v1/tools/call" ` + -Method POST -Headers @{"X-API-Key" = $script:GuestKey} ` + -Body '{"name": "echo", "arguments": {"message": "rate limit test"}}' ` + -ContentType "application/json" -UseBasicParsing -ErrorAction Stop + + if ($response.StatusCode -eq 200) { + $successCount++ + } + } + catch { + if ($_.Exception.Response -and [int]$_.Exception.Response.StatusCode -eq 429) { + $rateLimitedCount++ + } + } + + Start-Sleep -Milliseconds 100 + } + + $script:TotalTests++ + if ($successCount -gt 0) { + Write-SuccessLog "Rate limiting test: $successCount successful, $rateLimitedCount rate-limited" + } else { + Write-ErrorLog "Rate limiting test: All requests failed" + } +} + +function Test-FeatureFlags { + Write-InfoLog "=== Feature Flag Tests ===" + + # Test feature flag endpoint (admin only) + Invoke-TestRequest -TestName "Get feature flags (admin)" -ExpectedStatus 200 ` + -Headers @{"X-API-Key" = $script:AdminKey} -Uri "$script:BaseUrl/admin/feature-flags" + + Invoke-TestRequest -TestName "Get feature flags (non-admin should fail)" -ExpectedStatus 401 ` + -Headers @{"X-API-Key" = $script:UserKey} -Uri "$script:BaseUrl/admin/feature-flags" + + # Test setting feature flags + Invoke-TestRequest -TestName "Set feature flag (admin)" -ExpectedStatus 200 ` + -Method POST -Headers @{"X-API-Key" = $script:AdminKey} ` + -Uri "$script:BaseUrl/admin/feature-flags/premium_features?enabled=true" +} + +function Test-ErrorHandling { + Write-InfoLog "=== Error Handling Tests ===" + + # Test malformed requests + Invoke-TestRequest -TestName "Malformed JSON request" -ExpectedStatus 400 ` + -Method POST -Headers @{"X-API-Key" = $script:UserKey} ` + -Uri "$script:BaseUrl/mcp/v1/tools/call" ` + -Body '{"invalid": json}' + + # Test nonexistent tool + Invoke-TestRequest -TestName "Nonexistent tool execution" -ExpectedStatus 400 ` + -Method POST -Headers @{"X-API-Key" = $script:UserKey} ` + -Uri "$script:BaseUrl/mcp/v1/tools/call" ` + -Body '{"name": "nonexistent_tool", "arguments": {}}' + + # Test invalid arguments + Invoke-TestRequest -TestName "Invalid tool arguments" -ExpectedStatus 400 ` + -Method POST -Headers @{"X-API-Key" = $script:UserKey} ` + -Uri "$script:BaseUrl/mcp/v1/tools/call" ` + -Body '{"name": "calculate_hash", "arguments": {"text": "test", "algorithm": "invalid"}}' + + # Test nonexistent endpoint + Invoke-TestRequest -TestName "Nonexistent endpoint" -ExpectedStatus 404 ` + -Uri "$script:BaseUrl/nonexistent/endpoint" +} + +function Test-Performance { + Write-InfoLog "=== Performance Tests ===" + + # Test concurrent requests + Write-InfoLog "Testing concurrent requests..." + $startTime = Get-Date + + # Launch 5 concurrent requests + $jobs = @() + for ($i = 1; $i -le 5; $i++) { + $job = Start-Job -ScriptBlock { + param($baseUrl, $userKey) + try { + Invoke-WebRequest -Uri "$baseUrl/mcp/v1/tools/call" ` + -Method POST -Headers @{"X-API-Key" = $userKey} ` + -Body '{"name": "echo", "arguments": {"message": "concurrent test"}}' ` + -ContentType "application/json" -UseBasicParsing -ErrorAction Stop + return $true + } + catch { + return $false + } + } -ArgumentList $script:BaseUrl, $script:UserKey + + $jobs += $job + } + + # Wait for all requests to complete + $jobs | Wait-Job | Out-Null + $endTime = Get-Date + $duration = ($endTime - $startTime).TotalSeconds + + # Clean up jobs + $jobs | Remove-Job + + $script:TotalTests++ + if ($duration -le 10) { + Write-SuccessLog "Concurrent requests completed in $([Math]::Round($duration, 2))s (acceptable)" + } else { + Write-ErrorLog "Concurrent requests took $([Math]::Round($duration, 2))s (too slow)" + } +} + +function Invoke-SpecificCategory { + param([string]$TestCategory) + + switch ($TestCategory) { + "health" { Test-HealthCheck } + "authentication" { Test-Authentication } + "authorization" { Test-Authorization } + "visibility" { Test-ToolVisibility } + "rate-limiting" { Test-RateLimiting } + "feature-flags" { Test-FeatureFlags } + "error-handling" { Test-ErrorHandling } + "performance" { Test-Performance } + "all" { + Test-HealthCheck + Test-Authentication + Test-Authorization + Test-ToolVisibility + Test-RateLimiting + Test-FeatureFlags + Test-ErrorHandling + Test-Performance + } + default { + Write-ErrorLog "Unknown test category: $TestCategory" + Write-Host "Available categories: health, authentication, authorization, visibility, rate-limiting, feature-flags, error-handling, performance, all" + exit 1 + } + } +} + +# Main execution +function Main { + Write-Host "==================================" -ForegroundColor Cyan + Write-Host "Dynamic Tool Filtering Test Suite" -ForegroundColor Cyan + Write-Host "==================================" -ForegroundColor Cyan + Write-Host "Server: $script:BaseUrl" + Write-Host "Category: $Category" + Write-Host "Time: $(Get-Date)" + Write-Host "" + + # Wait for server to be ready + if (-not (Wait-ForServer)) { + exit 1 + } + + # Run tests + Invoke-SpecificCategory -TestCategory $Category + + # Results summary + Write-Host "" + Write-Host "==================================" -ForegroundColor Cyan + Write-Host "Test Results Summary" -ForegroundColor Cyan + Write-Host "==================================" -ForegroundColor Cyan + Write-Host "Total Tests: $script:TotalTests" + Write-Host "Passed: $script:PassedTests" -ForegroundColor Green + Write-Host "Failed: $script:FailedTests" -ForegroundColor Red + + if ($script:FailedTests -eq 0) { + Write-Host "`n✅ All tests passed!" -ForegroundColor Green + exit 0 + } else { + Write-Host "`n❌ Some tests failed!" -ForegroundColor Red + exit 1 + } +} + +# Execute main function +Main \ No newline at end of file diff --git a/samples/DynamicToolFiltering/scripts/test-all.sh b/samples/DynamicToolFiltering/scripts/test-all.sh new file mode 100755 index 00000000..caaa96f2 --- /dev/null +++ b/samples/DynamicToolFiltering/scripts/test-all.sh @@ -0,0 +1,474 @@ +#!/bin/bash + +# Dynamic Tool Filtering - Comprehensive Test Suite +# This script tests all aspects of the MCP server including authentication, +# authorization, rate limiting, feature flags, and error handling. + +set -e + +# Configuration +BASE_URL="${BASE_URL:-http://localhost:8080}" +GUEST_KEY="demo-guest-key" +USER_KEY="demo-user-key" +PREMIUM_KEY="demo-premium-key" +ADMIN_KEY="demo-admin-key" +INVALID_KEY="invalid-test-key" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Test counters +TOTAL_TESTS=0 +PASSED_TESTS=0 +FAILED_TESTS=0 + +# Test categories +CATEGORY="${1:-all}" + +# Helper functions +log_info() { + echo -e "${BLUE}[INFO]${NC} $1" +} + +log_success() { + echo -e "${GREEN}[PASS]${NC} $1" + ((PASSED_TESTS++)) +} + +log_error() { + echo -e "${RED}[FAIL]${NC} $1" + ((FAILED_TESTS++)) +} + +log_warning() { + echo -e "${YELLOW}[WARN]${NC} $1" +} + +run_test() { + local test_name="$1" + local expected_status="$2" + local curl_args="${@:3}" + + ((TOTAL_TESTS++)) + log_info "Running test: $test_name" + + # Execute curl command and capture both status and response + local response + local status + response=$(curl -s -w "HTTPSTATUS:%{http_code}" $curl_args 2>/dev/null || echo "HTTPSTATUS:000") + status=$(echo "$response" | grep -o "HTTPSTATUS:[0-9]*" | cut -d: -f2) + local body=$(echo "$response" | sed -E 's/HTTPSTATUS:[0-9]*$//') + + if [[ "$status" == "$expected_status" ]]; then + log_success "$test_name (HTTP $status)" + if [[ -n "$body" && "$body" != "null" ]]; then + echo " Response: $(echo "$body" | jq -r . 2>/dev/null || echo "$body" | head -c 100)..." + fi + else + log_error "$test_name (Expected HTTP $expected_status, got HTTP $status)" + if [[ -n "$body" ]]; then + echo " Response: $(echo "$body" | jq -r . 2>/dev/null || echo "$body" | head -c 200)..." + fi + fi + + echo "" +} + +wait_for_server() { + log_info "Waiting for server to be ready at $BASE_URL..." + local retries=30 + local count=0 + + while [ $count -lt $retries ]; do + if curl -s "$BASE_URL/health" > /dev/null 2>&1; then + log_success "Server is ready!" + return 0 + fi + + ((count++)) + if [ $count -eq $retries ]; then + log_error "Server did not start within expected time" + return 1 + fi + + sleep 2 + done +} + +test_health_check() { + log_info "=== Health Check Tests ===" + run_test "Health endpoint availability" "200" "$BASE_URL/health" +} + +test_authentication() { + log_info "=== Authentication Tests ===" + + # Test valid API keys + run_test "Guest API key authentication" "200" \ + -H "X-API-Key: $GUEST_KEY" "$BASE_URL/mcp/v1/tools" + + run_test "User API key authentication" "200" \ + -H "X-API-Key: $USER_KEY" "$BASE_URL/mcp/v1/tools" + + run_test "Premium API key authentication" "200" \ + -H "X-API-Key: $PREMIUM_KEY" "$BASE_URL/mcp/v1/tools" + + run_test "Admin API key authentication" "200" \ + -H "X-API-Key: $ADMIN_KEY" "$BASE_URL/mcp/v1/tools" + + # Test invalid API key + run_test "Invalid API key rejection" "401" \ + -H "X-API-Key: $INVALID_KEY" "$BASE_URL/mcp/v1/tools" + + # Test missing API key for protected endpoint + run_test "Missing API key for admin endpoint" "401" \ + "$BASE_URL/admin/filters/status" +} + +test_authorization() { + log_info "=== Authorization Tests ===" + + # Test tool execution with proper authorization + run_test "Public tool execution (no auth)" "200" \ + -X POST \ + -H "Content-Type: application/json" \ + -d '{"name": "echo", "arguments": {"message": "test"}}' \ + "$BASE_URL/mcp/v1/tools/call" + + run_test "User tool execution (user key)" "200" \ + -X POST \ + -H "X-API-Key: $USER_KEY" \ + -H "Content-Type: application/json" \ + -d '{"name": "get_user_profile", "arguments": {}}' \ + "$BASE_URL/mcp/v1/tools/call" + + run_test "Premium tool execution (premium key)" "200" \ + -X POST \ + -H "X-API-Key: $PREMIUM_KEY" \ + -H "Content-Type: application/json" \ + -d '{"name": "premium_generate_secure_random", "arguments": {"byteCount": 16}}' \ + "$BASE_URL/mcp/v1/tools/call" + + run_test "Admin tool execution (admin key)" "200" \ + -X POST \ + -H "X-API-Key: $ADMIN_KEY" \ + -H "Content-Type: application/json" \ + -d '{"name": "admin_get_system_diagnostics", "arguments": {}}' \ + "$BASE_URL/mcp/v1/tools/call" + + # Test authorization failures + run_test "User trying admin tool (should fail)" "401" \ + -X POST \ + -H "X-API-Key: $USER_KEY" \ + -H "Content-Type: application/json" \ + -d '{"name": "admin_get_system_diagnostics", "arguments": {}}' \ + "$BASE_URL/mcp/v1/tools/call" + + run_test "Guest trying premium tool (should fail)" "401" \ + -X POST \ + -H "X-API-Key: $GUEST_KEY" \ + -H "Content-Type: application/json" \ + -d '{"name": "premium_generate_secure_random", "arguments": {"byteCount": 16}}' \ + "$BASE_URL/mcp/v1/tools/call" + + run_test "Guest trying user tool (should fail)" "401" \ + -X POST \ + -H "X-API-Key: $GUEST_KEY" \ + -H "Content-Type: application/json" \ + -d '{"name": "get_user_profile", "arguments": {}}' \ + "$BASE_URL/mcp/v1/tools/call" +} + +test_tool_visibility() { + log_info "=== Tool Visibility Tests ===" + + # Get tool lists for different roles and count them + local guest_tools=$(curl -s -H "X-API-Key: $GUEST_KEY" "$BASE_URL/mcp/v1/tools" | jq -r '.result.tools | length' 2>/dev/null || echo "0") + local user_tools=$(curl -s -H "X-API-Key: $USER_KEY" "$BASE_URL/mcp/v1/tools" | jq -r '.result.tools | length' 2>/dev/null || echo "0") + local premium_tools=$(curl -s -H "X-API-Key: $PREMIUM_KEY" "$BASE_URL/mcp/v1/tools" | jq -r '.result.tools | length' 2>/dev/null || echo "0") + local admin_tools=$(curl -s -H "X-API-Key: $ADMIN_KEY" "$BASE_URL/mcp/v1/tools" | jq -r '.result.tools | length' 2>/dev/null || echo "0") + + ((TOTAL_TESTS += 4)) + + # Verify hierarchical access (each higher role should see more or equal tools) + if [[ "$guest_tools" -gt 0 ]]; then + log_success "Guest can see tools ($guest_tools tools visible)" + ((PASSED_TESTS++)) + else + log_error "Guest cannot see any tools" + ((FAILED_TESTS++)) + fi + + if [[ "$user_tools" -ge "$guest_tools" ]]; then + log_success "User sees >= guest tools ($user_tools >= $guest_tools)" + ((PASSED_TESTS++)) + else + log_error "User sees fewer tools than guest ($user_tools < $guest_tools)" + ((FAILED_TESTS++)) + fi + + if [[ "$premium_tools" -ge "$user_tools" ]]; then + log_success "Premium sees >= user tools ($premium_tools >= $user_tools)" + ((PASSED_TESTS++)) + else + log_error "Premium sees fewer tools than user ($premium_tools < $user_tools)" + ((FAILED_TESTS++)) + fi + + if [[ "$admin_tools" -ge "$premium_tools" ]]; then + log_success "Admin sees >= premium tools ($admin_tools >= $premium_tools)" + ((PASSED_TESTS++)) + else + log_error "Admin sees fewer tools than premium ($admin_tools < $premium_tools)" + ((FAILED_TESTS++)) + fi +} + +test_rate_limiting() { + log_info "=== Rate Limiting Tests ===" + + # Test rapid requests to trigger rate limiting + local success_count=0 + local rate_limited_count=0 + + log_info "Making 10 rapid requests with guest key to test rate limiting..." + + for i in {1..10}; do + local status=$(curl -s -w "%{http_code}" -o /dev/null \ + -H "X-API-Key: $GUEST_KEY" \ + -X POST \ + -H "Content-Type: application/json" \ + -d '{"name": "echo", "arguments": {"message": "rate limit test"}}' \ + "$BASE_URL/mcp/v1/tools/call") + + if [[ "$status" == "200" ]]; then + ((success_count++)) + elif [[ "$status" == "429" ]]; then + ((rate_limited_count++)) + fi + + sleep 0.1 + done + + ((TOTAL_TESTS++)) + if [[ "$success_count" -gt 0 ]]; then + log_success "Rate limiting test: $success_count successful, $rate_limited_count rate-limited" + ((PASSED_TESTS++)) + else + log_error "Rate limiting test: All requests failed" + ((FAILED_TESTS++)) + fi +} + +test_feature_flags() { + log_info "=== Feature Flag Tests ===" + + # Test feature flag endpoint (admin only) + run_test "Get feature flags (admin)" "200" \ + -H "X-API-Key: $ADMIN_KEY" \ + "$BASE_URL/admin/feature-flags" + + run_test "Get feature flags (non-admin should fail)" "401" \ + -H "X-API-Key: $USER_KEY" \ + "$BASE_URL/admin/feature-flags" + + # Test setting feature flags + run_test "Set feature flag (admin)" "200" \ + -X POST \ + -H "X-API-Key: $ADMIN_KEY" \ + "$BASE_URL/admin/feature-flags/premium_features?enabled=true" +} + +test_error_handling() { + log_info "=== Error Handling Tests ===" + + # Test malformed requests + run_test "Malformed JSON request" "400" \ + -X POST \ + -H "X-API-Key: $USER_KEY" \ + -H "Content-Type: application/json" \ + -d '{"invalid": json}' \ + "$BASE_URL/mcp/v1/tools/call" + + # Test nonexistent tool + run_test "Nonexistent tool execution" "400" \ + -X POST \ + -H "X-API-Key: $USER_KEY" \ + -H "Content-Type: application/json" \ + -d '{"name": "nonexistent_tool", "arguments": {}}' \ + "$BASE_URL/mcp/v1/tools/call" + + # Test invalid arguments + run_test "Invalid tool arguments" "400" \ + -X POST \ + -H "X-API-Key: $USER_KEY" \ + -H "Content-Type: application/json" \ + -d '{"name": "calculate_hash", "arguments": {"text": "test", "algorithm": "invalid"}}' \ + "$BASE_URL/mcp/v1/tools/call" + + # Test nonexistent endpoint + run_test "Nonexistent endpoint" "404" \ + "$BASE_URL/nonexistent/endpoint" +} + +test_performance() { + log_info "=== Performance Tests ===" + + # Test concurrent requests + log_info "Testing concurrent requests..." + local start_time=$(date +%s) + local pids=() + + # Launch 5 concurrent requests + for i in {1..5}; do + (curl -s -o /dev/null \ + -H "X-API-Key: $USER_KEY" \ + -X POST \ + -H "Content-Type: application/json" \ + -d '{"name": "echo", "arguments": {"message": "concurrent test"}}' \ + "$BASE_URL/mcp/v1/tools/call") & + pids+=($!) + done + + # Wait for all requests to complete + for pid in "${pids[@]}"; do + wait $pid + done + + local end_time=$(date +%s) + local duration=$((end_time - start_time)) + + ((TOTAL_TESTS++)) + if [[ $duration -le 10 ]]; then + log_success "Concurrent requests completed in ${duration}s (acceptable)" + ((PASSED_TESTS++)) + else + log_error "Concurrent requests took ${duration}s (too slow)" + ((FAILED_TESTS++)) + fi +} + +run_specific_category() { + case $CATEGORY in + "health") + test_health_check + ;; + "authentication") + test_authentication + ;; + "authorization") + test_authorization + ;; + "visibility") + test_tool_visibility + ;; + "rate-limiting") + test_rate_limiting + ;; + "feature-flags") + test_feature_flags + ;; + "error-handling") + test_error_handling + ;; + "performance") + test_performance + ;; + "all") + test_health_check + test_authentication + test_authorization + test_tool_visibility + test_rate_limiting + test_feature_flags + test_error_handling + test_performance + ;; + *) + log_error "Unknown test category: $CATEGORY" + echo "Available categories: health, authentication, authorization, visibility, rate-limiting, feature-flags, error-handling, performance, all" + exit 1 + ;; + esac +} + +show_help() { + echo "Dynamic Tool Filtering - Test Suite" + echo "" + echo "Usage: $0 [CATEGORY]" + echo "" + echo "Categories:" + echo " health - Test health endpoint" + echo " authentication - Test API key authentication" + echo " authorization - Test role-based authorization" + echo " visibility - Test tool visibility by role" + echo " rate-limiting - Test rate limiting functionality" + echo " feature-flags - Test feature flag management" + echo " error-handling - Test error responses" + echo " performance - Test performance and concurrency" + echo " all - Run all tests (default)" + echo "" + echo "Environment Variables:" + echo " BASE_URL - Server URL (default: http://localhost:8080)" + echo "" + echo "Examples:" + echo " $0 # Run all tests" + echo " $0 authentication # Run only authentication tests" + echo " BASE_URL=http://localhost:9000 $0 # Test different server" +} + +# Main execution +main() { + if [[ "$1" == "--help" || "$1" == "-h" ]]; then + show_help + exit 0 + fi + + echo "==================================" + echo "Dynamic Tool Filtering Test Suite" + echo "==================================" + echo "Server: $BASE_URL" + echo "Category: $CATEGORY" + echo "Time: $(date)" + echo "" + + # Check dependencies + if ! command -v curl &> /dev/null; then + log_error "curl is required but not installed" + exit 1 + fi + + if ! command -v jq &> /dev/null; then + log_warning "jq is not installed - JSON parsing will be limited" + fi + + # Wait for server to be ready + wait_for_server || exit 1 + + # Run tests + run_specific_category + + # Results summary + echo "" + echo "==================================" + echo "Test Results Summary" + echo "==================================" + echo "Total Tests: $TOTAL_TESTS" + echo -e "Passed: ${GREEN}$PASSED_TESTS${NC}" + echo -e "Failed: ${RED}$FAILED_TESTS${NC}" + + if [[ $FAILED_TESTS -eq 0 ]]; then + echo -e "\n${GREEN}✅ All tests passed!${NC}" + exit 0 + else + echo -e "\n${RED}❌ Some tests failed!${NC}" + exit 1 + fi +} + +# Run main function with all arguments +main "$@" \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationHttpExtensions.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationHttpExtensions.cs new file mode 100644 index 00000000..f7aeaef2 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationHttpExtensions.cs @@ -0,0 +1,178 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.Net.Http.Headers; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server.Authorization; +using System.Text.Json; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Extension methods for handling authorization challenges in HTTP responses. +/// +public static class AuthorizationHttpExtensions +{ + /// + /// Writes an authorization challenge response to the HTTP context. + /// + /// The HTTP context to write the response to. + /// The authorization exception containing challenge details. + /// Token to monitor for cancellation requests. + /// A task representing the asynchronous operation. + public static async Task WriteAuthorizationChallengeAsync( + this HttpContext context, + AuthorizationHttpException authException, + CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(context); + ArgumentNullException.ThrowIfNull(authException); + + // Set the HTTP status code + context.Response.StatusCode = authException.HttpStatusCode; + + // Add WWW-Authenticate header if provided + if (!string.IsNullOrEmpty(authException.WwwAuthenticateHeaderValue)) + { + context.Response.Headers.WWWAuthenticate = authException.WwwAuthenticateHeaderValue; + } + + // Create JSON-RPC error response + var jsonRpcError = new JsonRpcError + { + Error = new JsonRpcErrorDetail + { + Code = (int)authException.ErrorCode, + Message = authException.Message, + Data = new + { + ToolName = authException.ToolName, + Reason = authException.Reason, + HttpStatusCode = authException.HttpStatusCode, + RequiresAuthentication = !string.IsNullOrEmpty(authException.WwwAuthenticateHeaderValue) + } + } + }; + + // Set content type and write the JSON response + context.Response.ContentType = "application/json"; + await JsonSerializer.SerializeAsync( + context.Response.Body, + jsonRpcError, + McpJsonUtilities.JsonContext.Default.JsonRpcError, + cancellationToken); + } + + /// + /// Writes a generic authorization error response to the HTTP context. + /// + /// The HTTP context to write the response to. + /// The name of the tool that was denied access. + /// The reason for the authorization failure. + /// The HTTP status code to return (default: 403 Forbidden). + /// Optional WWW-Authenticate header value. + /// Token to monitor for cancellation requests. + /// A task representing the asynchronous operation. + public static async Task WriteAuthorizationErrorAsync( + this HttpContext context, + string toolName, + string reason, + int statusCode = StatusCodes.Status403Forbidden, + string? wwwAuthenticateValue = null, + CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(context); + ArgumentException.ThrowIfNullOrEmpty(toolName); + ArgumentException.ThrowIfNullOrEmpty(reason); + + var authException = new AuthorizationHttpException(toolName, reason, wwwAuthenticateValue, statusCode); + await context.WriteAuthorizationChallengeAsync(authException, cancellationToken); + } + + /// + /// Creates a Bearer token challenge response for OAuth2 authentication. + /// + /// The HTTP context to write the response to. + /// The name of the tool that was denied access. + /// The reason for the authorization failure. + /// Optional realm parameter for the WWW-Authenticate header. + /// Optional scope parameter for the WWW-Authenticate header. + /// Optional error parameter for the WWW-Authenticate header (e.g., "insufficient_scope"). + /// Optional error_description parameter for the WWW-Authenticate header. + /// Token to monitor for cancellation requests. + /// A task representing the asynchronous operation. + public static async Task WriteBearerChallengeAsync( + this HttpContext context, + string toolName, + string reason, + string? realm = null, + string? scope = null, + string? error = null, + string? errorDescription = null, + CancellationToken cancellationToken = default) + { + var authException = AuthorizationHttpException.CreateBearerChallenge( + toolName, reason, realm, scope, error, errorDescription); + await context.WriteAuthorizationChallengeAsync(authException, cancellationToken); + } + + /// + /// Creates a Basic authentication challenge response. + /// + /// The HTTP context to write the response to. + /// The name of the tool that was denied access. + /// The reason for the authorization failure. + /// The realm parameter for the WWW-Authenticate header. + /// Token to monitor for cancellation requests. + /// A task representing the asynchronous operation. + public static async Task WriteBasicChallengeAsync( + this HttpContext context, + string toolName, + string reason, + string? realm = null, + CancellationToken cancellationToken = default) + { + var authException = AuthorizationHttpException.CreateBasicChallenge(toolName, reason, realm); + await context.WriteAuthorizationChallengeAsync(authException, cancellationToken); + } + + /// + /// Determines if an exception should result in an HTTP authorization challenge. + /// + /// The exception to check. + /// True if the exception should result in an authorization challenge, false otherwise. + public static bool ShouldChallengeAuthorization(this Exception exception) + { + return exception is AuthorizationHttpException || + (exception is McpException mcpEx && mcpEx.ErrorCode == McpErrorCode.InvalidParams && + mcpEx.Message.Contains("Access denied", StringComparison.OrdinalIgnoreCase)); + } + + /// + /// Tries to extract tool name from an authorization-related exception message. + /// + /// The exception to extract the tool name from. + /// The tool name if found, otherwise null. + public static string? TryExtractToolName(this Exception exception) + { + if (exception is AuthorizationHttpException authEx) + { + return authEx.ToolName; + } + + if (exception?.Message is string message) + { + // Try to extract tool name from messages like "Access denied for tool 'toolName'" + var startIndex = message.IndexOf("tool '", StringComparison.OrdinalIgnoreCase); + if (startIndex >= 0) + { + startIndex += 6; // Length of "tool '" + var endIndex = message.IndexOf('\'', startIndex); + if (endIndex > startIndex) + { + return message.Substring(startIndex, endIndex - startIndex); + } + } + } + + return null; + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 6dac1c3e..02ca7b82 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -8,6 +8,7 @@ using ModelContextProtocol.AspNetCore.Stateless; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using ModelContextProtocol.Server.Authorization; using System.Collections.Concurrent; using System.Diagnostics; using System.IO.Pipelines; @@ -69,6 +70,23 @@ await WriteJsonRpcErrorAsync(context, context.Response.StatusCode = StatusCodes.Status202Accepted; } } + catch (AuthorizationHttpException authEx) + { + // Handle authorization exceptions with proper HTTP challenge responses + await context.WriteAuthorizationChallengeAsync(authEx, context.RequestAborted); + return; + } + catch (Exception ex) when (ex.ShouldChallengeAuthorization()) + { + // Handle other authorization-related exceptions + var toolName = ex.TryExtractToolName() ?? "unknown"; + await context.WriteAuthorizationErrorAsync( + toolName, + ex.Message, + StatusCodes.Status403Forbidden, + cancellationToken: context.RequestAborted); + return; + } finally { // Stateless sessions are 1:1 with HTTP requests and are outlived by the MCP session tracked by the Mcp-Session-Id. diff --git a/src/ModelContextProtocol.Core/Server/Authorization/AllowAllToolFilter.cs b/src/ModelContextProtocol.Core/Server/Authorization/AllowAllToolFilter.cs new file mode 100644 index 00000000..6e382220 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/Authorization/AllowAllToolFilter.cs @@ -0,0 +1,40 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server.Authorization; + +/// +/// A tool filter that allows access to all tools without any restrictions. +/// +/// +/// This filter is useful for development environments or scenarios where +/// no access control is required. It always returns authorization success +/// for any tool access request. +/// +public sealed class AllowAllToolFilter : IToolFilter +{ + /// + /// Initializes a new instance of the class. + /// + /// The priority for this filter. Default is (lowest priority). + public AllowAllToolFilter(int priority = int.MaxValue) + { + Priority = priority; + } + + /// + public int Priority { get; } + + /// + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // Always include all tools + return Task.FromResult(true); + } + + /// + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // Always allow execution + return Task.FromResult(AuthorizationResult.Allow("All tools allowed")); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/Authorization/AuthorizationChallenge.cs b/src/ModelContextProtocol.Core/Server/Authorization/AuthorizationChallenge.cs new file mode 100644 index 00000000..9a9eddbf --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/Authorization/AuthorizationChallenge.cs @@ -0,0 +1,135 @@ +namespace ModelContextProtocol.Server.Authorization; + +/// +/// Represents an HTTP authorization challenge that should be sent to the client. +/// +/// +/// This class provides structured information for HTTP authorization challenges, +/// including the WWW-Authenticate header value and HTTP status code to return. +/// +public sealed class AuthorizationChallenge +{ + /// + /// Initializes a new instance of the class. + /// + /// The WWW-Authenticate header value. + /// The HTTP status code to return (default: 401 Unauthorized). + public AuthorizationChallenge(string wwwAuthenticateValue, int httpStatusCode = 401) + { + WwwAuthenticateValue = wwwAuthenticateValue ?? throw new ArgumentNullException(nameof(wwwAuthenticateValue)); + HttpStatusCode = httpStatusCode; + } + + /// + /// Gets the WWW-Authenticate header value to include in the HTTP response. + /// + public string WwwAuthenticateValue { get; } + + /// + /// Gets the HTTP status code to return in the response. + /// + /// + /// The HTTP status code, typically 401 (Unauthorized) or 403 (Forbidden). + /// + public int HttpStatusCode { get; } + + /// + /// Creates an for OAuth2 Bearer token authentication. + /// + /// Optional realm parameter for the WWW-Authenticate header. + /// Optional scope parameter for the WWW-Authenticate header. + /// Optional error parameter for the WWW-Authenticate header (e.g., "insufficient_scope"). + /// Optional error_description parameter for the WWW-Authenticate header. + /// An configured for OAuth2 Bearer token authentication. + public static AuthorizationChallenge CreateBearerChallenge( + string? realm = null, + string? scope = null, + string? error = null, + string? errorDescription = null) + { + var challengeParts = new List(); + + if (!string.IsNullOrEmpty(realm)) + challengeParts.Add($"realm=\"{realm}\""); + + if (!string.IsNullOrEmpty(scope)) + challengeParts.Add($"scope=\"{scope}\""); + + if (!string.IsNullOrEmpty(error)) + challengeParts.Add($"error=\"{error}\""); + + if (!string.IsNullOrEmpty(errorDescription)) + challengeParts.Add($"error_description=\"{errorDescription}\""); + + var wwwAuthenticate = challengeParts.Count > 0 + ? $"Bearer {string.Join(", ", challengeParts)}" + : "Bearer"; + + return new AuthorizationChallenge(wwwAuthenticate); + } + + /// + /// Creates an for Basic authentication. + /// + /// The realm parameter for the WWW-Authenticate header. + /// An configured for Basic authentication. + public static AuthorizationChallenge CreateBasicChallenge(string? realm = null) + { + var wwwAuthenticate = !string.IsNullOrEmpty(realm) + ? $"Basic realm=\"{realm}\"" + : "Basic"; + + return new AuthorizationChallenge(wwwAuthenticate); + } + + /// + /// Creates an for a custom authentication scheme. + /// + /// The authentication scheme name (e.g., "Custom", "ApiKey"). + /// Optional parameters for the WWW-Authenticate header. + /// An configured for the custom authentication scheme. + public static AuthorizationChallenge CreateCustomChallenge( + string scheme, + params (string name, string value)[] parameters) + { + var challengeParts = parameters?.Select(p => $"{p.name}=\"{p.value}\"").ToList() ?? new List(); + var wwwAuthenticate = challengeParts.Count > 0 + ? $"{scheme} {string.Join(", ", challengeParts)}" + : scheme; + + return new AuthorizationChallenge(wwwAuthenticate); + } + + /// + /// Creates an for OAuth2 insufficient scope error. + /// + /// The scope required to access the resource. + /// Optional realm parameter for the WWW-Authenticate header. + /// An configured for insufficient scope error. + public static AuthorizationChallenge CreateInsufficientScopeChallenge( + string requiredScope, + string? realm = null) + { + return CreateBearerChallenge( + realm: realm, + scope: requiredScope, + error: "insufficient_scope", + errorDescription: $"The request requires higher privileges than provided by the access token. Required scope: {requiredScope}"); + } + + /// + /// Creates an for OAuth2 invalid token error. + /// + /// Optional realm parameter for the WWW-Authenticate header. + /// Optional custom error description. + /// An configured for invalid token error. + public static AuthorizationChallenge CreateInvalidTokenChallenge( + string? realm = null, + string? errorDescription = null) + { + return CreateBearerChallenge( + realm: realm, + error: "invalid_token", + errorDescription: errorDescription ?? "The access token provided is expired, revoked, malformed, or invalid for other reasons"); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/Authorization/AuthorizationHttpException.cs b/src/ModelContextProtocol.Core/Server/Authorization/AuthorizationHttpException.cs new file mode 100644 index 00000000..4a8a015a --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/Authorization/AuthorizationHttpException.cs @@ -0,0 +1,161 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server.Authorization; + +/// +/// Represents an authorization exception that requires an HTTP challenge response. +/// +/// +/// This exception is used to indicate that tool authorization has failed and +/// the client should receive a proper HTTP challenge response with WWW-Authenticate headers. +/// +public sealed class AuthorizationHttpException : McpException +{ + /// + /// Initializes a new instance of the class. + /// + /// The name of the tool that was denied access. + /// The reason for the authorization failure. + /// The WWW-Authenticate header value to include in the HTTP response. + /// The HTTP status code to return (default: 401 Unauthorized). + public AuthorizationHttpException( + string toolName, + string reason, + string? wwwAuthenticateHeaderValue = null, + int httpStatusCode = 401) + : base($"Access denied for tool '{toolName}': {reason}", McpErrorCode.InvalidParams) + { + ToolName = toolName ?? throw new ArgumentNullException(nameof(toolName)); + Reason = reason ?? throw new ArgumentNullException(nameof(reason)); + WwwAuthenticateHeaderValue = wwwAuthenticateHeaderValue; + HttpStatusCode = httpStatusCode; + } + + /// + /// Initializes a new instance of the class. + /// + /// The name of the tool that was denied access. + /// The reason for the authorization failure. + /// The inner exception that caused the authorization failure. + /// The WWW-Authenticate header value to include in the HTTP response. + /// The HTTP status code to return (default: 401 Unauthorized). + public AuthorizationHttpException( + string toolName, + string reason, + Exception innerException, + string? wwwAuthenticateHeaderValue = null, + int httpStatusCode = 401) + : base($"Access denied for tool '{toolName}': {reason}", McpErrorCode.InvalidParams, innerException) + { + ToolName = toolName ?? throw new ArgumentNullException(nameof(toolName)); + Reason = reason ?? throw new ArgumentNullException(nameof(reason)); + WwwAuthenticateHeaderValue = wwwAuthenticateHeaderValue; + HttpStatusCode = httpStatusCode; + } + + /// + /// Gets the name of the tool that was denied access. + /// + public string ToolName { get; } + + /// + /// Gets the reason for the authorization failure. + /// + public string Reason { get; } + + /// + /// Gets the WWW-Authenticate header value to include in the HTTP response. + /// + /// + /// If this is null, no WWW-Authenticate header will be added to the response. + /// + public string? WwwAuthenticateHeaderValue { get; } + + /// + /// Gets the HTTP status code to return in the response. + /// + /// + /// The HTTP status code, typically 401 (Unauthorized) or 403 (Forbidden). + /// + public int HttpStatusCode { get; } + + /// + /// Creates an for OAuth2 Bearer token authentication. + /// + /// The name of the tool that was denied access. + /// The reason for the authorization failure. + /// Optional realm parameter for the WWW-Authenticate header. + /// Optional scope parameter for the WWW-Authenticate header. + /// Optional error parameter for the WWW-Authenticate header (e.g., "insufficient_scope"). + /// Optional error_description parameter for the WWW-Authenticate header. + /// An configured for OAuth2 Bearer token authentication. + public static AuthorizationHttpException CreateBearerChallenge( + string toolName, + string reason, + string? realm = null, + string? scope = null, + string? error = null, + string? errorDescription = null) + { + var challengeParts = new List(); + + if (!string.IsNullOrEmpty(realm)) + challengeParts.Add($"realm=\"{realm}\""); + + if (!string.IsNullOrEmpty(scope)) + challengeParts.Add($"scope=\"{scope}\""); + + if (!string.IsNullOrEmpty(error)) + challengeParts.Add($"error=\"{error}\""); + + if (!string.IsNullOrEmpty(errorDescription)) + challengeParts.Add($"error_description=\"{errorDescription}\""); + + var wwwAuthenticate = challengeParts.Count > 0 + ? $"Bearer {string.Join(", ", challengeParts)}" + : "Bearer"; + + return new AuthorizationHttpException(toolName, reason, wwwAuthenticate); + } + + /// + /// Creates an for Basic authentication. + /// + /// The name of the tool that was denied access. + /// The reason for the authorization failure. + /// The realm parameter for the WWW-Authenticate header. + /// An configured for Basic authentication. + public static AuthorizationHttpException CreateBasicChallenge( + string toolName, + string reason, + string? realm = null) + { + var wwwAuthenticate = !string.IsNullOrEmpty(realm) + ? $"Basic realm=\"{realm}\"" + : "Basic"; + + return new AuthorizationHttpException(toolName, reason, wwwAuthenticate); + } + + /// + /// Creates an for a custom authentication scheme. + /// + /// The name of the tool that was denied access. + /// The reason for the authorization failure. + /// The authentication scheme name (e.g., "Custom", "ApiKey"). + /// Optional parameters for the WWW-Authenticate header. + /// An configured for the custom authentication scheme. + public static AuthorizationHttpException CreateCustomChallenge( + string toolName, + string reason, + string scheme, + params (string name, string value)[] parameters) + { + var challengeParts = parameters?.Select(p => $"{p.name}=\"{p.value}\"").ToList() ?? new List(); + var wwwAuthenticate = challengeParts.Count > 0 + ? $"{scheme} {string.Join(", ", challengeParts)}" + : scheme; + + return new AuthorizationHttpException(toolName, reason, wwwAuthenticate); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/Authorization/AuthorizationResult.cs b/src/ModelContextProtocol.Core/Server/Authorization/AuthorizationResult.cs new file mode 100644 index 00000000..81dd1947 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/Authorization/AuthorizationResult.cs @@ -0,0 +1,164 @@ +namespace ModelContextProtocol.Server.Authorization; + +/// +/// Represents the result of an authorization check for tool operations. +/// +/// +/// This class encapsulates the outcome of authorization decisions, providing +/// both the boolean result and optional additional context for failed authorizations. +/// +public sealed class AuthorizationResult +{ + /// + /// Initializes a new instance of the class. + /// + /// Indicates whether the operation is authorized. + /// Optional reason for the authorization decision, particularly useful for denied operations. + /// Optional additional data that may be relevant to the authorization decision. + public AuthorizationResult(bool isAuthorized, string? reason = null, object? additionalData = null) + { + IsAuthorized = isAuthorized; + Reason = reason; + AdditionalData = additionalData; + } + + /// + /// Gets a value indicating whether the operation is authorized. + /// + /// + /// if the operation is authorized; otherwise, . + /// + public bool IsAuthorized { get; } + + /// + /// Gets the reason for the authorization decision. + /// + /// + /// A string describing the reason for the authorization result, or + /// if no specific reason was provided. + /// + /// + /// This property is particularly useful for denied operations where clients or + /// administrators need to understand why access was refused. + /// + public string? Reason { get; } + + /// + /// Gets additional data associated with the authorization decision. + /// + /// + /// An object containing additional context data, or + /// if no additional data is available. + /// + /// + /// This property can be used to pass implementation-specific data that may be + /// useful for logging, auditing, or other authorization-related processes. + /// + public object? AdditionalData { get; } + + /// + /// Creates an representing a successful authorization. + /// + /// Optional reason for the successful authorization. + /// Optional additional data related to the authorization. + /// An with set to . + public static AuthorizationResult Allow(string? reason = null, object? additionalData = null) + => new(true, reason, additionalData); + + /// + /// Creates an representing a denied authorization. + /// + /// Reason for the authorization denial. + /// Optional additional data related to the authorization failure. + /// An with set to . + public static AuthorizationResult Deny(string reason, object? additionalData = null) + => new(false, reason, additionalData); + + /// + /// Creates an representing a denied authorization with a default reason. + /// + /// Optional additional data related to the authorization failure. + /// An with set to . + public static AuthorizationResult Deny(object? additionalData = null) + => new(false, "Access denied", additionalData); + + /// + /// Creates an representing a denied authorization with an HTTP challenge. + /// + /// Reason for the authorization denial. + /// The authorization challenge to include in the HTTP response. + /// An with set to . + public static AuthorizationResult DenyWithChallenge(string reason, AuthorizationChallenge challenge) + => new(false, reason, challenge); + + /// + /// Creates an representing a denied authorization with a Bearer token challenge. + /// + /// Reason for the authorization denial. + /// Optional realm parameter for the WWW-Authenticate header. + /// Optional scope parameter for the WWW-Authenticate header. + /// Optional error parameter for the WWW-Authenticate header (e.g., "insufficient_scope"). + /// Optional error_description parameter for the WWW-Authenticate header. + /// An with set to . + public static AuthorizationResult DenyWithBearerChallenge( + string reason, + string? realm = null, + string? scope = null, + string? error = null, + string? errorDescription = null) + { + var challenge = AuthorizationChallenge.CreateBearerChallenge(realm, scope, error, errorDescription); + return new(false, reason, challenge); + } + + /// + /// Creates an representing a denied authorization with a Basic authentication challenge. + /// + /// Reason for the authorization denial. + /// The realm parameter for the WWW-Authenticate header. + /// An with set to . + public static AuthorizationResult DenyWithBasicChallenge(string reason, string? realm = null) + { + var challenge = AuthorizationChallenge.CreateBasicChallenge(realm); + return new(false, reason, challenge); + } + + /// + /// Creates an representing a denied authorization due to insufficient scope. + /// + /// The scope required to access the resource. + /// Optional realm parameter for the WWW-Authenticate header. + /// An with set to . + public static AuthorizationResult DenyInsufficientScope(string requiredScope, string? realm = null) + { + var challenge = AuthorizationChallenge.CreateInsufficientScopeChallenge(requiredScope, realm); + return new(false, $"Insufficient scope. Required scope: {requiredScope}", challenge); + } + + /// + /// Creates an representing a denied authorization due to invalid token. + /// + /// Optional realm parameter for the WWW-Authenticate header. + /// Optional custom error description. + /// An with set to . + public static AuthorizationResult DenyInvalidToken(string? realm = null, string? errorDescription = null) + { + var challenge = AuthorizationChallenge.CreateInvalidTokenChallenge(realm, errorDescription); + return new(false, "Invalid or expired token", challenge); + } + + /// + /// Returns a string representation of the authorization result. + /// + /// A string describing the authorization result. + public override string ToString() + { + return IsAuthorized switch + { + true when !string.IsNullOrEmpty(Reason) => $"Authorized: {Reason}", + true => "Authorized", + false when !string.IsNullOrEmpty(Reason) => $"Denied: {Reason}", + false => "Denied" + }; + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/Authorization/DenyAllToolFilter.cs b/src/ModelContextProtocol.Core/Server/Authorization/DenyAllToolFilter.cs new file mode 100644 index 00000000..61e15ee6 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/Authorization/DenyAllToolFilter.cs @@ -0,0 +1,44 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server.Authorization; + +/// +/// A tool filter that denies access to all tools. +/// +/// +/// This filter is useful for lockdown scenarios or as a safety mechanism +/// to prevent any tool execution. It always returns authorization failure +/// for any tool access request. +/// +public sealed class DenyAllToolFilter : IToolFilter +{ + private readonly string _reason; + + /// + /// Initializes a new instance of the class. + /// + /// The priority for this filter. Default is 0 (highest priority). + /// The reason for denying access. Default is "All tools denied". + public DenyAllToolFilter(int priority = 0, string reason = "All tools denied") + { + Priority = priority; + _reason = reason ?? "All tools denied"; + } + + /// + public int Priority { get; } + + /// + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // Never include any tools + return Task.FromResult(false); + } + + /// + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // Always deny execution + return Task.FromResult(AuthorizationResult.Deny(_reason)); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/Authorization/IToolAuthorizationService.cs b/src/ModelContextProtocol.Core/Server/Authorization/IToolAuthorizationService.cs new file mode 100644 index 00000000..d5a489e1 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/Authorization/IToolAuthorizationService.cs @@ -0,0 +1,87 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server.Authorization; + +/// +/// Defines the contract for tool authorization services that manage access control for MCP tools. +/// +/// +/// The tool authorization service acts as the central orchestrator for tool filtering, +/// coordinating multiple tool filters and providing a unified interface for authorization decisions. +/// +public interface IToolAuthorizationService +{ + /// + /// Filters a collection of tools based on the current authorization context. + /// + /// The collection of tools to filter. + /// The authorization context for the filtering operation. + /// Token to monitor for cancellation requests. + /// + /// A task that represents the asynchronous operation. The task result contains + /// a filtered collection of tools that the current user is authorized to see. + /// + /// + /// Thrown when or is . + /// + /// + /// This method applies all registered tool filters to determine which tools + /// should be visible to the requesting client. Tools are included only if + /// all filters allow them. + /// + Task> FilterToolsAsync(IEnumerable tools, ToolAuthorizationContext context, CancellationToken cancellationToken = default); + + /// + /// Determines whether a specific tool can be executed by the current user. + /// + /// The name of the tool to authorize for execution. + /// The authorization context for the execution check. + /// Token to monitor for cancellation requests. + /// + /// A task that represents the asynchronous operation. The task result contains + /// an indicating whether the tool execution is authorized. + /// + /// + /// Thrown when is or empty. + /// + /// + /// Thrown when is . + /// + /// + /// This method evaluates all registered tool filters to determine whether + /// the specified tool can be executed. If any filter denies access, the + /// authorization fails. + /// + Task AuthorizeToolExecutionAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default); + + /// + /// Registers a tool filter with the authorization service. + /// + /// The tool filter to register. + /// + /// Thrown when is . + /// + /// + /// Registered filters are executed in priority order when making authorization decisions. + /// Multiple filters with the same priority may execute in any order. + /// + void RegisterFilter(IToolFilter filter); + + /// + /// Unregisters a tool filter from the authorization service. + /// + /// The tool filter to unregister. + /// + /// Thrown when is . + /// + /// + /// If the specified filter is not currently registered, this method has no effect. + /// + void UnregisterFilter(IToolFilter filter); + + /// + /// Gets all currently registered tool filters. + /// + /// A read-only collection of registered tool filters. + IReadOnlyCollection GetRegisteredFilters(); +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/Authorization/IToolFilter.cs b/src/ModelContextProtocol.Core/Server/Authorization/IToolFilter.cs new file mode 100644 index 00000000..05603245 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/Authorization/IToolFilter.cs @@ -0,0 +1,71 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server.Authorization; + +/// +/// Defines the contract for implementing tool filtering logic in MCP servers. +/// +/// +/// Tool filters allow MCP servers to control which tools are visible and accessible +/// to specific clients or user contexts. This enables fine-grained access control +/// and authorization for tool operations. +/// +public interface IToolFilter +{ + /// + /// Determines whether a specific tool should be included in the list of available tools. + /// + /// The tool to evaluate for inclusion. + /// The authorization context containing user and session information. + /// Token to monitor for cancellation requests. + /// + /// A task that represents the asynchronous operation. The task result contains + /// if the tool should be included; otherwise, . + /// + /// + /// Thrown when or is . + /// + /// + /// This method is called during tool listing operations to determine which tools + /// should be visible to the requesting client. Implementations should perform + /// authorization checks based on the provided context. + /// + Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default); + + /// + /// Determines whether a specific tool can be executed by the current user context. + /// + /// The name of the tool to authorize for execution. + /// The authorization context containing user and session information. + /// Token to monitor for cancellation requests. + /// + /// A task that represents the asynchronous operation. The task result contains + /// an indicating whether the operation is authorized. + /// + /// + /// Thrown when is or empty. + /// + /// + /// Thrown when is . + /// + /// + /// This method is called before tool execution to ensure the user has permission + /// to invoke the specified tool. Implementations should perform comprehensive + /// authorization checks and return detailed failure reasons when access is denied. + /// + Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default); + + /// + /// Gets the priority order for this filter when multiple filters are registered. + /// + /// + /// An integer representing the execution priority. Lower values indicate higher priority. + /// Filters with the same priority may execute in any order. + /// + /// + /// When multiple tool filters are registered, they are executed in priority order. + /// Higher priority filters (lower numeric values) are evaluated first. If any filter + /// denies access, the operation is rejected regardless of lower priority filter results. + /// + int Priority { get; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/Authorization/README.md b/src/ModelContextProtocol.Core/Server/Authorization/README.md new file mode 100644 index 00000000..8ed96334 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/Authorization/README.md @@ -0,0 +1,323 @@ +# MCP Authorization System + +This directory contains the authorization system for the Model Context Protocol (MCP) C# SDK. The authorization system provides fine-grained access control for MCP tools with proper HTTP challenge responses. + +## Key Components + +### Core Interfaces and Classes + +- **`IToolAuthorizationService`** - Main service for orchestrating tool authorization +- **`IToolFilter`** - Interface for implementing custom authorization filters +- **`AuthorizationResult`** - Represents the result of an authorization check +- **`AuthorizationChallenge`** - Represents HTTP authorization challenges (WWW-Authenticate headers) +- **`AuthorizationHttpException`** - Exception for authorization failures that require HTTP challenges +- **`ToolAuthorizationContext`** - Context information for authorization decisions + +### Built-in Filters + +- **`AllowAllToolFilter`** - Allows access to all tools (default behavior) +- **`DenyAllToolFilter`** - Denies access to all tools +- **`ToolNamePatternFilter`** - Filters tools based on name patterns +- **`RoleBasedToolFilterBuilder`** - Builder for role-based authorization + +## Usage Examples + +### Basic Setup + +```csharp +// Configure the authorization service +services.AddSingleton(sp => +{ + var authService = new ToolAuthorizationService(); + + // Add your custom filters + authService.RegisterFilter(new MyCustomToolFilter()); + + return authService; +}); + +// Register your MCP server with tools +services.AddMcpServer(options => +{ + options.WithTools(); +}); +``` + +### OAuth2 Bearer Token Authorization + +```csharp +public class OAuth2ToolFilter : IToolFilter +{ + public int Priority => 100; + + public Task AuthorizeAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // Extract and validate Bearer token + var token = ExtractBearerToken(context); + + if (string.IsNullOrEmpty(token)) + { + // No token provided - request authentication + return Task.FromResult(AuthorizationResult.DenyInvalidToken("my-api")); + } + + var claims = ValidateToken(token); + if (claims == null) + { + // Invalid token - request new authentication + return Task.FromResult(AuthorizationResult.DenyInvalidToken("my-api", "Token is expired or invalid")); + } + + // Check if tool requires specific scope + var requiredScope = GetRequiredScope(tool.Name); + if (requiredScope != null && !HasScope(claims, requiredScope)) + { + // Insufficient scope - request higher privileges + return Task.FromResult(AuthorizationResult.DenyInsufficientScope(requiredScope, "my-api")); + } + + return Task.FromResult(AuthorizationResult.Allow("Valid token with sufficient scope")); + } + + private string? ExtractBearerToken(ToolAuthorizationContext context) + { + // Extract token from context (implementation depends on your setup) + // You might get this from HTTP headers, session data, etc. + return null; + } + + private ClaimsPrincipal? ValidateToken(string token) + { + // Validate JWT token and return claims + // Implementation depends on your OAuth2 provider + return null; + } + + private string? GetRequiredScope(string toolName) + { + // Return the required scope for the tool + return toolName.Contains("admin") ? "admin:tools" : "user:tools"; + } + + private bool HasScope(ClaimsPrincipal claims, string scope) + { + // Check if the claims contain the required scope + return claims.HasClaim("scope", scope); + } +} +``` + +### Role-Based Authorization + +```csharp +public class RoleBasedToolFilter : IToolFilter +{ + public int Priority => 50; + + public Task AuthorizeAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var userRole = GetUserRole(context); + var requiredRole = GetRequiredRole(tool.Name); + + if (!HasRole(userRole, requiredRole)) + { + // Create a custom challenge for role-based access + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "Role", + ("required_role", requiredRole), + ("user_role", userRole ?? "none")); + + return Task.FromResult(AuthorizationResult.DenyWithChallenge( + $"Tool '{tool.Name}' requires '{requiredRole}' role, but user has '{userRole}' role", + challenge)); + } + + return Task.FromResult(AuthorizationResult.Allow($"User has required role: {requiredRole}")); + } + + private string? GetUserRole(ToolAuthorizationContext context) + { + // Extract user role from context + return "user"; // Example + } + + private string GetRequiredRole(string toolName) + { + // Determine required role based on tool name + return toolName.StartsWith("admin_") ? "admin" : "user"; + } + + private bool HasRole(string? userRole, string requiredRole) + { + // Check if user has the required role + return userRole == requiredRole || (userRole == "admin" && requiredRole == "user"); + } +} +``` + +### API Key Authorization + +```csharp +public class ApiKeyToolFilter : IToolFilter +{ + private readonly Dictionary _apiKeyScopes; + + public ApiKeyToolFilter(Dictionary apiKeyScopes) + { + _apiKeyScopes = apiKeyScopes; + } + + public int Priority => 75; + + public Task AuthorizeAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var apiKey = ExtractApiKey(context); + + if (string.IsNullOrEmpty(apiKey)) + { + // No API key provided + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "ApiKey", + ("realm", "mcp-api"), + ("parameter", "X-API-Key")); + + return Task.FromResult(AuthorizationResult.DenyWithChallenge( + "API key required", challenge)); + } + + if (!_apiKeyScopes.TryGetValue(apiKey, out var scopes)) + { + // Invalid API key + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "ApiKey", + ("realm", "mcp-api"), + ("error", "invalid_key")); + + return Task.FromResult(AuthorizationResult.DenyWithChallenge( + "Invalid API key", challenge)); + } + + var requiredScope = GetRequiredScope(tool.Name); + if (requiredScope != null && !scopes.Contains(requiredScope)) + { + // Insufficient scope for API key + var challenge = AuthorizationChallenge.CreateCustomChallenge( + "ApiKey", + ("realm", "mcp-api"), + ("required_scope", requiredScope), + ("available_scopes", string.Join(",", scopes))); + + return Task.FromResult(AuthorizationResult.DenyWithChallenge( + $"API key does not have required scope: {requiredScope}", challenge)); + } + + return Task.FromResult(AuthorizationResult.Allow("Valid API key with sufficient scope")); + } + + private string? ExtractApiKey(ToolAuthorizationContext context) + { + // Extract API key from context (e.g., from HTTP headers) + return null; + } + + private string? GetRequiredScope(string toolName) + { + // Determine required scope for the tool + return toolName.Contains("write") ? "write" : "read"; + } +} +``` + +## HTTP Challenge Responses + +When authorization fails, the system automatically generates proper HTTP responses with WWW-Authenticate headers: + +### OAuth2 Bearer Token Challenge +``` +HTTP/1.1 401 Unauthorized +WWW-Authenticate: Bearer realm="mcp-api", scope="write:tools", error="insufficient_scope", error_description="The request requires higher privileges" +Content-Type: application/json + +{ + "error": { + "code": -32002, + "message": "Access denied for tool 'admin_tool': Insufficient scope", + "data": { + "ToolName": "admin_tool", + "Reason": "Insufficient scope", + "HttpStatusCode": 401, + "RequiresAuthentication": true + } + } +} +``` + +### Basic Authentication Challenge +``` +HTTP/1.1 401 Unauthorized +WWW-Authenticate: Basic realm="mcp-api" +Content-Type: application/json + +{ + "error": { + "code": -32002, + "message": "Access denied for tool 'secure_tool': Authentication required", + "data": { + "ToolName": "secure_tool", + "Reason": "Authentication required", + "HttpStatusCode": 401, + "RequiresAuthentication": true + } + } +} +``` + +### Custom Authentication Challenge +``` +HTTP/1.1 401 Unauthorized +WWW-Authenticate: ApiKey realm="mcp-api", parameter="X-API-Key" +Content-Type: application/json + +{ + "error": { + "code": -32002, + "message": "Access denied for tool 'api_tool': API key required", + "data": { + "ToolName": "api_tool", + "Reason": "API key required", + "HttpStatusCode": 401, + "RequiresAuthentication": true + } + } +} +``` + +## Filter Priority and Execution Order + +Filters are executed in priority order (highest to lowest). If any filter denies access, the authorization fails immediately. All filters must allow access for the tool to be authorized. + +```csharp +// Higher priority filters run first +authService.RegisterFilter(new SecurityFilter { Priority = 1000 }); +authService.RegisterFilter(new RoleFilter { Priority = 500 }); +authService.RegisterFilter(new ScopeFilter { Priority = 100 }); +``` + +## Best Practices + +1. **Use specific error messages** - Provide clear reasons for authorization failures +2. **Include proper challenges** - Always provide WWW-Authenticate headers for 401 responses +3. **Implement proper token validation** - Validate tokens securely and check expiration +4. **Use appropriate HTTP status codes** - 401 for authentication issues, 403 for authorization issues +5. **Log authorization events** - Track authorization successes and failures for security monitoring +6. **Cache authorization decisions** - Consider caching where appropriate to improve performance +7. **Handle errors gracefully** - Fail securely when authorization checks encounter errors + +## Security Considerations + +- Never expose sensitive information in error messages +- Use HTTPS in production to protect credentials +- Implement proper token storage and handling +- Consider rate limiting for authorization endpoints +- Regularly audit and rotate API keys and secrets +- Implement proper logging for security events \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/Authorization/RoleBasedToolFilterBuilder.cs b/src/ModelContextProtocol.Core/Server/Authorization/RoleBasedToolFilterBuilder.cs new file mode 100644 index 00000000..e42ca5cc --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/Authorization/RoleBasedToolFilterBuilder.cs @@ -0,0 +1,206 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server.Authorization; + +/// +/// Builder class for creating role-based tool filters. +/// +/// +/// This builder provides a fluent API for configuring role-based access control +/// for MCP tools, allowing developers to easily specify which roles can access +/// which tools. +/// +public sealed class RoleBasedToolFilterBuilder +{ + private readonly Dictionary> _toolRoles = new(); + private readonly Dictionary> _roleTools = new(); + private bool _defaultDeny = true; + private int _priority = 0; + + /// + /// Allows a specific role to access a specific tool. + /// + /// The role that should have access. + /// The name of the tool to allow access to. + /// This builder instance for method chaining. + /// + /// Thrown when or is null or empty. + /// + public RoleBasedToolFilterBuilder AllowRole(string role, string toolName) + { + if (string.IsNullOrEmpty(role)) + throw new ArgumentException("Role cannot be null or empty", nameof(role)); + if (string.IsNullOrEmpty(toolName)) + throw new ArgumentException("Tool name cannot be null or empty", nameof(toolName)); + + if (!_toolRoles.ContainsKey(toolName)) + { + _toolRoles[toolName] = new HashSet(); + } + _toolRoles[toolName].Add(role); + + if (!_roleTools.ContainsKey(role)) + { + _roleTools[role] = new HashSet(); + } + _roleTools[role].Add(toolName); + + return this; + } + + /// + /// Allows a specific role to access multiple tools. + /// + /// The role that should have access. + /// The names of the tools to allow access to. + /// This builder instance for method chaining. + /// + /// Thrown when is null or empty. + /// + /// + /// Thrown when is null. + /// + public RoleBasedToolFilterBuilder AllowRole(string role, params string[] toolNames) + { + if (string.IsNullOrEmpty(role)) + throw new ArgumentException("Role cannot be null or empty", nameof(role)); + + ArgumentNullException.ThrowIfNull(toolNames); + + foreach (var toolName in toolNames) + { + if (!string.IsNullOrEmpty(toolName)) + { + AllowRole(role, toolName); + } + } + + return this; + } + + /// + /// Allows multiple roles to access a specific tool. + /// + /// The name of the tool to allow access to. + /// The roles that should have access. + /// This builder instance for method chaining. + /// + /// Thrown when is null or empty. + /// + /// + /// Thrown when is null. + /// + public RoleBasedToolFilterBuilder AllowTool(string toolName, params string[] roles) + { + if (string.IsNullOrEmpty(toolName)) + throw new ArgumentException("Tool name cannot be null or empty", nameof(toolName)); + + ArgumentNullException.ThrowIfNull(roles); + + foreach (var role in roles) + { + if (!string.IsNullOrEmpty(role)) + { + AllowRole(role, toolName); + } + } + + return this; + } + + /// + /// Sets whether to deny access by default when no explicit rules match. + /// + /// + /// If (default), access is denied when no rules match. + /// If , access is allowed when no rules match. + /// + /// This builder instance for method chaining. + public RoleBasedToolFilterBuilder WithDefaultDeny(bool defaultDeny = true) + { + _defaultDeny = defaultDeny; + return this; + } + + /// + /// Sets the priority for this filter. + /// + /// + /// The priority value. Lower values indicate higher priority. + /// + /// This builder instance for method chaining. + public RoleBasedToolFilterBuilder WithPriority(int priority) + { + _priority = priority; + return this; + } + + /// + /// Builds the role-based tool filter with the configured rules. + /// + /// A configured implementation. + public IToolFilter Build() + { + return new RoleBasedToolFilter(_toolRoles, _defaultDeny, _priority); + } + + /// + /// Internal implementation of role-based tool filtering. + /// + private sealed class RoleBasedToolFilter : IToolFilter + { + private readonly IReadOnlyDictionary> _toolRoles; + private readonly bool _defaultDeny; + + public RoleBasedToolFilter(Dictionary> toolRoles, bool defaultDeny, int priority) + { + _toolRoles = toolRoles.ToDictionary( + kvp => kvp.Key, + kvp => kvp.Value, + StringComparer.OrdinalIgnoreCase); + _defaultDeny = defaultDeny; + Priority = priority; + } + + public int Priority { get; } + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(tool); + ArgumentNullException.ThrowIfNull(context); + + return Task.FromResult(IsAuthorized(tool.Name, context)); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (string.IsNullOrEmpty(toolName)) + throw new ArgumentException("Tool name cannot be null or empty", nameof(toolName)); + ArgumentNullException.ThrowIfNull(context); + + bool isAuthorized = IsAuthorized(toolName, context); + + return Task.FromResult(isAuthorized + ? AuthorizationResult.Allow($"Role-based access granted for tool '{toolName}'") + : AuthorizationResult.Deny($"Role-based access denied for tool '{toolName}' - insufficient permissions")); + } + + private bool IsAuthorized(string toolName, ToolAuthorizationContext context) + { + // If no roles are defined for this user, use default behavior + if (context.UserRoles.Count == 0) + { + return !_defaultDeny; + } + + // If no rules are defined for this tool, use default behavior + if (!_toolRoles.TryGetValue(toolName, out var allowedRoles)) + { + return !_defaultDeny; + } + + // Check if any of the user's roles are allowed for this tool + return context.UserRoles.Any(role => allowedRoles.Contains(role, StringComparer.OrdinalIgnoreCase)); + } + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/Authorization/SampleOAuthToolFilter.cs b/src/ModelContextProtocol.Core/Server/Authorization/SampleOAuthToolFilter.cs new file mode 100644 index 00000000..75f3fa67 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/Authorization/SampleOAuthToolFilter.cs @@ -0,0 +1,124 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server.Authorization; + +/// +/// Sample tool filter that demonstrates OAuth2 Bearer token authorization challenges. +/// This is an example implementation showing how to use the authorization challenge system. +/// +/// +/// This filter demonstrates how to implement proper OAuth2 authorization challenges +/// with WWW-Authenticate headers when tools require specific scopes or valid tokens. +/// +public sealed class SampleOAuthToolFilter : IToolFilter +{ + private readonly string _requiredScope; + private readonly string? _realm; + + /// + /// Initializes a new instance of the class. + /// + /// The OAuth2 scope required to access tools. + /// Optional realm for the WWW-Authenticate header. + public SampleOAuthToolFilter(string requiredScope, string? realm = null) + { + _requiredScope = requiredScope ?? throw new ArgumentNullException(nameof(requiredScope)); + _realm = realm; + } + + /// + public int Priority => 100; + + /// + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // For tool visibility, include all tools but execution authorization will handle access control + // In a more restrictive implementation, you might hide tools the user cannot access + return Task.FromResult(true); + } + + /// + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // Simulate checking if the user has the required scope + // In a real implementation, you would extract and validate the Bearer token + // from the authorization context and check its scopes + + if (IsHighPrivilegeTool(toolName)) + { + // For high-privilege tools, require specific scope + if (!HasRequiredScope(context)) + { + return Task.FromResult(AuthorizationResult.DenyInsufficientScope(_requiredScope, _realm)); + } + } + else if (RequiresAuthentication(toolName)) + { + // For tools that require authentication, check for valid token + if (!HasValidToken(context)) + { + return Task.FromResult(AuthorizationResult.DenyInvalidToken(_realm)); + } + } + + // Tool is authorized + return Task.FromResult(AuthorizationResult.Allow("Valid credentials")); + } + + /// + /// Determines if a tool is considered high-privilege and requires specific scopes. + /// + /// The name of the tool to check. + /// True if the tool requires elevated privileges, false otherwise. + private static bool IsHighPrivilegeTool(string toolName) + { + // Example: Tools that modify data or access sensitive information + return toolName.Contains("delete", StringComparison.OrdinalIgnoreCase) || + toolName.Contains("admin", StringComparison.OrdinalIgnoreCase) || + toolName.Contains("private", StringComparison.OrdinalIgnoreCase); + } + + /// + /// Determines if a tool requires authentication. + /// + /// The name of the tool to check. + /// True if the tool requires authentication, false otherwise. + private static bool RequiresAuthentication(string toolName) + { + // Example: Most tools require authentication except public read-only ones + return !toolName.Contains("public", StringComparison.OrdinalIgnoreCase) && + !toolName.Contains("read", StringComparison.OrdinalIgnoreCase); + } + + /// + /// Simulates checking if the current context has the required OAuth2 scope. + /// + /// The authorization context. + /// True if the required scope is present, false otherwise. + private bool HasRequiredScope(ToolAuthorizationContext context) + { + // In a real implementation, you would: + // 1. Extract the Bearer token from the authorization context + // 2. Validate the token with your OAuth2 provider + // 3. Check if the token includes the required scope + + // For this sample, simulate scope checking + return false; // Always deny for demonstration + } + + /// + /// Simulates checking if the current context has a valid authentication token. + /// + /// The authorization context. + /// True if a valid token is present, false otherwise. + private bool HasValidToken(ToolAuthorizationContext context) + { + // In a real implementation, you would: + // 1. Extract the Bearer token from the authorization context + // 2. Validate the token signature and expiration + // 3. Verify the token with your OAuth2 provider + + // For this sample, simulate token validation + return false; // Always deny for demonstration + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/Authorization/ToolAuthorizationContext.cs b/src/ModelContextProtocol.Core/Server/Authorization/ToolAuthorizationContext.cs new file mode 100644 index 00000000..bf615c9c --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/Authorization/ToolAuthorizationContext.cs @@ -0,0 +1,193 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server.Authorization; + +/// +/// Provides context information for tool authorization operations. +/// +/// +/// This class contains all the contextual information needed by tool filters +/// to make authorization decisions, including user identity, session information, +/// and request metadata. +/// +public sealed class ToolAuthorizationContext +{ + /// + /// Initializes a new instance of the class. + /// + /// The unique identifier for the current session. + /// Information about the client making the request. + /// The capabilities supported by the server. + public ToolAuthorizationContext( + string? sessionId, + Implementation? clientInfo, + ServerCapabilities? serverCapabilities) + { + SessionId = sessionId; + ClientInfo = clientInfo; + ServerCapabilities = serverCapabilities; + Properties = new Dictionary(); + } + + /// + /// Gets the unique identifier for the current session. + /// + /// + /// A string representing the session identifier, or + /// if no session identifier is available. + /// + /// + /// The session ID can be used to track and correlate authorization decisions + /// across multiple requests within the same session. + /// + public string? SessionId { get; } + + /// + /// Gets information about the client making the request. + /// + /// + /// An object containing client details, + /// or if client information is not available. + /// + /// + /// Client information includes details such as the client name, version, + /// and other metadata that may be relevant for authorization decisions. + /// + public Implementation? ClientInfo { get; } + + /// + /// Gets the capabilities supported by the server. + /// + /// + /// A object describing server capabilities, + /// or if capability information is not available. + /// + /// + /// Server capabilities can be used by authorization filters to make decisions + /// based on what features and operations the server supports. + /// + public ServerCapabilities? ServerCapabilities { get; } + + /// + /// Gets or sets the user identifier for the current request. + /// + /// + /// A string representing the user identifier, or + /// if no user identifier is available. + /// + /// + /// The user ID is typically set by authentication middleware or during + /// the session initialization process. It can be used by authorization + /// filters to make user-specific access control decisions. + /// + public string? UserId { get; set; } + + /// + /// Gets or sets the user roles for the current request. + /// + /// + /// A collection of strings representing user roles, or an empty collection + /// if no roles are assigned. + /// + /// + /// User roles can be used by authorization filters to implement role-based + /// access control (RBAC) for tool operations. + /// + public ICollection UserRoles { get; set; } = new List(); + + /// + /// Gets or sets the user permissions for the current request. + /// + /// + /// A collection of strings representing user permissions, or an empty collection + /// if no permissions are assigned. + /// + /// + /// User permissions provide fine-grained access control capabilities beyond + /// role-based access control, allowing for specific operation-level authorization. + /// + public ICollection UserPermissions { get; set; } = new List(); + + /// + /// Gets a dictionary of additional properties that can be used to store custom context data. + /// + /// + /// A dictionary containing key-value pairs of additional context data. + /// + /// + /// This property allows for extensibility by enabling custom authorization + /// filters to store and retrieve implementation-specific context data. + /// + public IDictionary Properties { get; } + + /// + /// Creates a new with basic session information. + /// + /// The session identifier. + /// A new instance. + public static ToolAuthorizationContext ForSession(string? sessionId) + => new(sessionId, null, null); + + /// + /// Creates a new with session and client information. + /// + /// The session identifier. + /// Information about the client. + /// A new instance. + public static ToolAuthorizationContext ForSessionAndClient(string? sessionId, Implementation? clientInfo) + => new(sessionId, clientInfo, null); + + /// + /// Creates a copy of this context with additional user information. + /// + /// The user identifier. + /// Optional user roles. + /// Optional user permissions. + /// A new instance with user information. + public ToolAuthorizationContext WithUser(string userId, IEnumerable? roles = null, IEnumerable? permissions = null) + { + var context = new ToolAuthorizationContext(SessionId, ClientInfo, ServerCapabilities) + { + UserId = userId, + UserRoles = roles?.ToList() ?? new List(), + UserPermissions = permissions?.ToList() ?? new List() + }; + + // Copy existing properties + foreach (var property in Properties) + { + context.Properties[property.Key] = property.Value; + } + + return context; + } + + /// + /// Creates a copy of this context with additional properties. + /// + /// Additional properties to include. + /// A new instance with additional properties. + public ToolAuthorizationContext WithProperties(IDictionary properties) + { + var context = new ToolAuthorizationContext(SessionId, ClientInfo, ServerCapabilities) + { + UserId = UserId, + UserRoles = UserRoles, + UserPermissions = UserPermissions + }; + + // Copy existing properties + foreach (var property in Properties) + { + context.Properties[property.Key] = property.Value; + } + + // Add new properties + foreach (var property in properties) + { + context.Properties[property.Key] = property.Value; + } + + return context; + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/Authorization/ToolAuthorizationService.cs b/src/ModelContextProtocol.Core/Server/Authorization/ToolAuthorizationService.cs new file mode 100644 index 00000000..256d612c --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/Authorization/ToolAuthorizationService.cs @@ -0,0 +1,196 @@ +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; +using System.Collections.Concurrent; + +namespace ModelContextProtocol.Server.Authorization; + +/// +/// Default implementation of that coordinates multiple tool filters. +/// +/// +/// This service manages a collection of tool filters and applies them in priority order +/// to make authorization decisions for tool visibility and execution. +/// +internal sealed class ToolAuthorizationService : IToolAuthorizationService +{ + private readonly ConcurrentBag _filters = new(); + private readonly ILogger? _logger; + + /// + /// Initializes a new instance of the class. + /// + /// Optional logger for diagnostic information. + public ToolAuthorizationService(ILogger? logger = null) + { + _logger = logger; + } + + /// + /// Initializes a new instance of the class with initial filters. + /// + /// Initial collection of tool filters to register. + /// Optional logger for diagnostic information. + /// + /// Thrown when is . + /// + public ToolAuthorizationService(IEnumerable filters, ILogger? logger = null) + : this(logger) + { + Throw.IfNull(filters); + + foreach (var filter in filters) + { + if (filter is not null) + { + _filters.Add(filter); + } + } + } + + /// + public async Task> FilterToolsAsync(IEnumerable tools, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + Throw.IfNull(tools); + Throw.IfNull(context); + + var filters = GetSortedFilters(); + if (filters.Count == 0) + { + _logger?.LogDebug("No tool filters registered, returning all tools"); + return tools; + } + + var filteredTools = new List(); + + foreach (var tool in tools) + { + bool shouldInclude = true; + + foreach (var filter in filters) + { + try + { + if (!await filter.ShouldIncludeToolAsync(tool, context, cancellationToken).ConfigureAwait(false)) + { + shouldInclude = false; + _logger?.LogDebug("Tool '{ToolName}' filtered out by filter '{FilterType}'", tool.Name, filter.GetType().Name); + break; + } + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + _logger?.LogError(ex, "Error in tool filter '{FilterType}' while evaluating tool '{ToolName}', denying access", filter.GetType().Name, tool.Name); + shouldInclude = false; + break; + } + } + + if (shouldInclude) + { + filteredTools.Add(tool); + _logger?.LogDebug("Tool '{ToolName}' included after filtering", tool.Name); + } + } + + _logger?.LogInformation("Filtered {OriginalCount} tools to {FilteredCount} tools", tools.Count(), filteredTools.Count); + return filteredTools; + } + + /// + public async Task AuthorizeToolExecutionAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (string.IsNullOrEmpty(toolName)) + { + throw new ArgumentException("Tool name cannot be null or empty", nameof(toolName)); + } + Throw.IfNull(context); + + var filters = GetSortedFilters(); + if (filters.Count == 0) + { + _logger?.LogDebug("No tool filters registered, allowing execution of tool '{ToolName}'", toolName); + return AuthorizationResult.Allow("No filters configured"); + } + + foreach (var filter in filters) + { + try + { + var result = await filter.CanExecuteToolAsync(toolName, context, cancellationToken).ConfigureAwait(false); + if (!result.IsAuthorized) + { + _logger?.LogWarning("Tool execution denied for '{ToolName}' by filter '{FilterType}': {Reason}", + toolName, filter.GetType().Name, result.Reason); + return result; + } + else + { + _logger?.LogDebug("Tool execution allowed for '{ToolName}' by filter '{FilterType}'", + toolName, filter.GetType().Name); + } + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + _logger?.LogError(ex, "Error in tool filter '{FilterType}' while authorizing tool '{ToolName}', denying access", + filter.GetType().Name, toolName); + return AuthorizationResult.Deny($"Filter error: {ex.Message}"); + } + } + + _logger?.LogInformation("Tool execution authorized for '{ToolName}' after all filters", toolName); + return AuthorizationResult.Allow("All filters passed"); + } + + /// + public void RegisterFilter(IToolFilter filter) + { + Throw.IfNull(filter); + + _filters.Add(filter); + _logger?.LogInformation("Registered tool filter '{FilterType}' with priority {Priority}", + filter.GetType().Name, filter.Priority); + } + + /// + public void UnregisterFilter(IToolFilter filter) + { + Throw.IfNull(filter); + + // ConcurrentBag doesn't support removal, so we'll need to create a new collection + // This is not the most efficient approach, but it's thread-safe and filters are typically + // registered once during startup rather than dynamically during runtime + var existingFilters = _filters.ToList(); + var newFilters = existingFilters.Where(f => !ReferenceEquals(f, filter)).ToList(); + + if (existingFilters.Count != newFilters.Count) + { + // Clear and re-add the remaining filters + while (_filters.TryTake(out _)) { } + foreach (var remainingFilter in newFilters) + { + _filters.Add(remainingFilter); + } + + _logger?.LogInformation("Unregistered tool filter '{FilterType}'", filter.GetType().Name); + } + else + { + _logger?.LogDebug("Tool filter '{FilterType}' was not found for unregistration", filter.GetType().Name); + } + } + + /// + public IReadOnlyCollection GetRegisteredFilters() + { + return _filters.ToList().AsReadOnly(); + } + + /// + /// Gets all registered filters sorted by priority. + /// + /// A list of filters sorted by priority (ascending order). + private List GetSortedFilters() + { + return _filters.OrderBy(f => f.Priority).ToList(); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/Authorization/ToolFilterAggregator.cs b/src/ModelContextProtocol.Core/Server/Authorization/ToolFilterAggregator.cs new file mode 100644 index 00000000..0c5d9a68 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/Authorization/ToolFilterAggregator.cs @@ -0,0 +1,155 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server.Authorization; + +/// +/// Aggregates multiple tool filters from dependency injection and applies them collectively. +/// +/// +/// This class automatically discovers and manages tool filters registered in the dependency +/// injection container, providing a convenient way to apply multiple filters without manual registration. +/// +internal sealed class ToolFilterAggregator : IToolFilter +{ + private readonly IServiceProvider _serviceProvider; + private readonly ILogger? _logger; + private IToolFilter[]? _cachedFilters; + + /// + /// Initializes a new instance of the class. + /// + /// The service provider to resolve tool filters from. + /// Optional logger for diagnostic information. + /// + /// Thrown when is . + /// + public ToolFilterAggregator(IServiceProvider serviceProvider, ILogger? logger = null) + { + _serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider)); + _logger = logger; + } + + /// + public int Priority => int.MinValue; // Aggregator runs with highest priority to coordinate other filters + + /// + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + Throw.IfNull(tool); + Throw.IfNull(context); + + var filters = GetFilters(); + if (filters.Length == 0) + { + _logger?.LogDebug("No tool filters found in DI container, allowing tool '{ToolName}'", tool.Name); + return true; + } + + foreach (var filter in filters) + { + try + { + if (!await filter.ShouldIncludeToolAsync(tool, context, cancellationToken).ConfigureAwait(false)) + { + _logger?.LogDebug("Tool '{ToolName}' filtered out by aggregated filter '{FilterType}'", + tool.Name, filter.GetType().Name); + return false; + } + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + _logger?.LogError(ex, "Error in aggregated tool filter '{FilterType}' while evaluating tool '{ToolName}', denying access", + filter.GetType().Name, tool.Name); + return false; + } + } + + _logger?.LogDebug("Tool '{ToolName}' passed all aggregated filters", tool.Name); + return true; + } + + /// + public async Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (string.IsNullOrEmpty(toolName)) + { + throw new ArgumentException("Tool name cannot be null or empty", nameof(toolName)); + } + Throw.IfNull(context); + + var filters = GetFilters(); + if (filters.Length == 0) + { + _logger?.LogDebug("No tool filters found in DI container, allowing execution of tool '{ToolName}'", toolName); + return AuthorizationResult.Allow("No filters configured"); + } + + foreach (var filter in filters) + { + try + { + var result = await filter.CanExecuteToolAsync(toolName, context, cancellationToken).ConfigureAwait(false); + if (!result.IsAuthorized) + { + _logger?.LogWarning("Tool execution denied for '{ToolName}' by aggregated filter '{FilterType}': {Reason}", + toolName, filter.GetType().Name, result.Reason); + return result; + } + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + _logger?.LogError(ex, "Error in aggregated tool filter '{FilterType}' while authorizing tool '{ToolName}', denying access", + filter.GetType().Name, toolName); + return AuthorizationResult.Deny($"Filter error in {filter.GetType().Name}: {ex.Message}"); + } + } + + _logger?.LogInformation("Tool execution authorized for '{ToolName}' by all aggregated filters", toolName); + return AuthorizationResult.Allow("All aggregated filters passed"); + } + + /// + /// Gets all tool filters from the dependency injection container, sorted by priority. + /// + /// An array of tool filters sorted by priority. + private IToolFilter[] GetFilters() + { + if (_cachedFilters is not null) + { + return _cachedFilters; + } + + try + { + var filters = _serviceProvider.GetServices() + .Where(f => f != this) // Exclude self to avoid infinite recursion + .OrderBy(f => f.Priority) + .ToArray(); + + _cachedFilters = filters; + _logger?.LogDebug("Discovered {FilterCount} tool filters from DI container", filters.Length); + + return filters; + } + catch (Exception ex) + { + _logger?.LogError(ex, "Error resolving tool filters from DI container"); + return Array.Empty(); + } + } + + /// + /// Clears the cached filters, forcing them to be re-resolved on the next access. + /// + /// + /// This method can be useful in scenarios where filters are registered dynamically + /// and the aggregator needs to pick up the changes. + /// + public void ClearCache() + { + _cachedFilters = null; + _logger?.LogDebug("Cleared tool filter cache"); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/Authorization/ToolNamePatternFilter.cs b/src/ModelContextProtocol.Core/Server/Authorization/ToolNamePatternFilter.cs new file mode 100644 index 00000000..e076211b --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/Authorization/ToolNamePatternFilter.cs @@ -0,0 +1,201 @@ +using ModelContextProtocol.Protocol; +using System.Text.RegularExpressions; + +namespace ModelContextProtocol.Server.Authorization; + +/// +/// A tool filter that allows or denies access based on tool name patterns. +/// +/// +/// This filter uses regular expressions or simple string matching to determine +/// which tools should be accessible. It supports both allow-list and deny-list patterns. +/// +public sealed class ToolNamePatternFilter : IToolFilter +{ + private readonly List _allowPatterns; + private readonly List _denyPatterns; + private readonly bool _defaultAllow; + + /// + /// Initializes a new instance of the class. + /// + /// The priority for this filter. + /// + /// Whether to allow access by default when no patterns match. + /// If , tools are allowed unless explicitly denied. + /// If , tools are denied unless explicitly allowed. + /// + public ToolNamePatternFilter(int priority = 100, bool defaultAllow = false) + { + Priority = priority; + _defaultAllow = defaultAllow; + _allowPatterns = new List(); + _denyPatterns = new List(); + } + + /// + public int Priority { get; } + + /// + /// Adds a pattern that allows access to matching tool names. + /// + /// The regular expression pattern to match tool names. + /// Optional regex options. Default is case-insensitive. + /// This filter instance for method chaining. + /// + /// Thrown when is null or empty. + /// + public ToolNamePatternFilter Allow(string pattern, RegexOptions options = RegexOptions.IgnoreCase) + { + if (string.IsNullOrEmpty(pattern)) + throw new ArgumentException("Pattern cannot be null or empty", nameof(pattern)); + + _allowPatterns.Add(new Regex(pattern, options | RegexOptions.Compiled)); + return this; + } + + /// + /// Adds a pattern that denies access to matching tool names. + /// + /// The regular expression pattern to match tool names. + /// Optional regex options. Default is case-insensitive. + /// This filter instance for method chaining. + /// + /// Thrown when is null or empty. + /// + public ToolNamePatternFilter Deny(string pattern, RegexOptions options = RegexOptions.IgnoreCase) + { + if (string.IsNullOrEmpty(pattern)) + throw new ArgumentException("Pattern cannot be null or empty", nameof(pattern)); + + _denyPatterns.Add(new Regex(pattern, options | RegexOptions.Compiled)); + return this; + } + + /// + /// Adds multiple patterns that allow access to matching tool names. + /// + /// The regular expression patterns to match tool names. + /// Optional regex options. Default is case-insensitive. + /// This filter instance for method chaining. + /// + /// Thrown when is null. + /// + public ToolNamePatternFilter AllowMany(IEnumerable patterns, RegexOptions options = RegexOptions.IgnoreCase) + { + ArgumentNullException.ThrowIfNull(patterns); + + foreach (var pattern in patterns) + { + if (!string.IsNullOrEmpty(pattern)) + { + Allow(pattern, options); + } + } + + return this; + } + + /// + /// Adds multiple patterns that deny access to matching tool names. + /// + /// The regular expression patterns to match tool names. + /// Optional regex options. Default is case-insensitive. + /// This filter instance for method chaining. + /// + /// Thrown when is null. + /// + public ToolNamePatternFilter DenyMany(IEnumerable patterns, RegexOptions options = RegexOptions.IgnoreCase) + { + ArgumentNullException.ThrowIfNull(patterns); + + foreach (var pattern in patterns) + { + if (!string.IsNullOrEmpty(pattern)) + { + Deny(pattern, options); + } + } + + return this; + } + + /// + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(tool); + ArgumentNullException.ThrowIfNull(context); + + bool isAllowed = EvaluateAccess(tool.Name); + return Task.FromResult(isAllowed); + } + + /// + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (string.IsNullOrEmpty(toolName)) + throw new ArgumentException("Tool name cannot be null or empty", nameof(toolName)); + ArgumentNullException.ThrowIfNull(context); + + bool isAllowed = EvaluateAccess(toolName); + + return Task.FromResult(isAllowed + ? AuthorizationResult.Allow($"Tool '{toolName}' matches allowed patterns") + : AuthorizationResult.Deny($"Tool '{toolName}' is not allowed by pattern filter")); + } + + /// + /// Evaluates whether access should be granted for the specified tool name. + /// + /// The tool name to evaluate. + /// if access should be granted; otherwise, . + private bool EvaluateAccess(string toolName) + { + // Check deny patterns first (they take precedence) + foreach (var denyPattern in _denyPatterns) + { + if (denyPattern.IsMatch(toolName)) + { + return false; + } + } + + // Check allow patterns + foreach (var allowPattern in _allowPatterns) + { + if (allowPattern.IsMatch(toolName)) + { + return true; + } + } + + // No patterns matched, return default behavior + return _defaultAllow; + } + + /// + /// Creates a filter that allows only tools matching the specified patterns. + /// + /// Patterns that allow tool access. + /// The priority for this filter. + /// A new instance. + public static ToolNamePatternFilter CreateAllowList(IEnumerable allowPatterns, int priority = 100) + { + var filter = new ToolNamePatternFilter(priority, defaultAllow: false); + filter.AllowMany(allowPatterns); + return filter; + } + + /// + /// Creates a filter that denies only tools matching the specified patterns. + /// + /// Patterns that deny tool access. + /// The priority for this filter. + /// A new instance. + public static ToolNamePatternFilter CreateDenyList(IEnumerable denyPatterns, int priority = 100) + { + var filter = new ToolNamePatternFilter(priority, defaultAllow: true); + filter.DenyMany(denyPatterns); + return filter; + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/McpServer.cs b/src/ModelContextProtocol.Core/Server/McpServer.cs index 6c5858f9..80c839e7 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.cs @@ -1,6 +1,7 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server.Authorization; using System.Runtime.CompilerServices; using System.Text.Json.Serialization.Metadata; @@ -434,7 +435,18 @@ private void ConfigureTools(McpServerOptions options) var tools = toolsCapability.ToolCollection; var listChanged = toolsCapability.ListChanged; - // Handle tools provided via DI by augmenting the handlers to incorporate them. + /* + This code implements a decorator pattern (not recursion) to layer multiple tool sources together. Here's what + the author accomplished: + + The Problem + + The MCP server needs to handle tools from two sources: + 1. User-provided handlers (via options.Capabilities.Tools) + 2. DI-registered tools (via tools collection) + + The Solution: Handler Decoration + */ if (tools is { IsEmpty: false }) { var originalListToolsHandler = listToolsHandler; @@ -446,25 +458,91 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) if (request.Params?.Cursor is null) { + // Add tools from this server's collection to the result foreach (var t in tools) { result.Tools.Add(t.ProtocolTool); } } + + // Apply tool filtering to the combined result (from all sources) regardless of pagination + var authorizationService = Services?.GetService(); + if (authorizationService is not null) + { + try + { + var authContext = CreateAuthorizationContext(request); + var allToolsInResult = result.Tools.ToList(); + var filteredTools = await authorizationService.FilterToolsAsync(allToolsInResult, authContext, cancellationToken).ConfigureAwait(false); + + // Replace the tools in the result with the filtered list + result.Tools.Clear(); + result.Tools.AddRange(filteredTools); + } + catch (Exception ex) + { + // Log error but keep the unfiltered result + Logger?.LogError(ex, "Error during tool filtering, returning unfiltered tools"); + } + } return result; }; var originalCallToolHandler = callToolHandler; - callToolHandler = (request, cancellationToken) => + callToolHandler = async (request, cancellationToken) => { - if (request.Params is not null && - tools.TryGetPrimitive(request.Params.Name, out var tool)) + if (request.Params is not null) { - return tool.InvokeAsync(request, cancellationToken); + // Check authorization before executing the tool + var authorizationService = Services?.GetService(); + if (authorizationService is not null) + { + try + { + var authContext = CreateAuthorizationContext(request); + var authResult = await authorizationService.AuthorizeToolExecutionAsync(request.Params.Name, authContext, cancellationToken).ConfigureAwait(false); + + if (!authResult.IsAuthorized) + { + // Check if the authorization result includes challenge information + if (authResult.AdditionalData is AuthorizationChallenge challenge) + { + throw new AuthorizationHttpException( + request.Params.Name, + authResult.Reason ?? "Access denied", + challenge.WwwAuthenticateValue, + challenge.HttpStatusCode); + } + else + { + // Default to a generic authorization exception + throw new AuthorizationHttpException( + request.Params.Name, + authResult.Reason ?? "Access denied"); + } + } + } + catch (McpException) + { + throw; // Re-throw MCP exceptions as-is + } + catch (Exception ex) + { + // Log error and deny access on authorization failure + Logger?.LogError(ex, "Error during tool authorization check for '{ToolName}', denying access", request.Params.Name); + throw new McpException($"Authorization error for tool '{request.Params.Name}'", McpErrorCode.InternalError); + } + } + + // Proceed with tool execution if authorized + if (tools.TryGetPrimitive(request.Params.Name, out var tool)) + { + return await tool.InvokeAsync(request, cancellationToken).ConfigureAwait(false); + } } - return originalCallToolHandler(request, cancellationToken); + return await originalCallToolHandler(request, cancellationToken).ConfigureAwait(false); }; listChanged = true; @@ -582,6 +660,19 @@ private void UpdateEndpointNameWithClientInfo() _endpointName = $"{_serverOnlyEndpointName}, Client ({ClientInfo.Name} {ClientInfo.Version})"; } + /// + /// Creates an authorization context for the current request. + /// + /// The request context containing session and client information. + /// A with current session information. + private ToolAuthorizationContext CreateAuthorizationContext(RequestContext request) + { + return new ToolAuthorizationContext( + SessionId, + ClientInfo, + ServerCapabilities); + } + /// Maps a to a . internal static LoggingLevel ToLoggingLevel(LogLevel level) => level switch diff --git a/src/ModelContextProtocol/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/McpServerBuilderExtensions.cs index d925b24f..4562fe3f 100644 --- a/src/ModelContextProtocol/McpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol/McpServerBuilderExtensions.cs @@ -4,6 +4,7 @@ using ModelContextProtocol; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using ModelContextProtocol.Server.Authorization; using System.Diagnostics.CodeAnalysis; using System.Reflection; using System.Text.Json; @@ -779,6 +780,252 @@ private static void AddSingleSessionServerDependencies(IServiceCollection servic } #endregion + #region Tool Authorization + /// + /// Adds tool authorization services to the MCP server. + /// + /// The builder instance. + /// The builder provided in . + /// is . + /// + /// This method registers the default implementation + /// that can coordinate multiple tool filters for access control. + /// + public static IMcpServerBuilder WithToolAuthorization(this IMcpServerBuilder builder) + { + Throw.IfNull(builder); + + builder.Services.TryAddSingleton(); + return builder; + } + + /// + /// Adds a specific tool filter to the MCP server. + /// + /// The type of tool filter to add. + /// The builder instance. + /// The builder provided in . + /// is . + /// + /// This method registers a tool filter that will be automatically discovered + /// and used by the tool authorization system. The filter must implement . + /// + public static IMcpServerBuilder WithToolFilter(this IMcpServerBuilder builder) + where TFilter : class, IToolFilter + { + Throw.IfNull(builder); + + builder.Services.AddTransient(); + return builder; + } + + /// + /// Adds a specific tool filter instance to the MCP server. + /// + /// The builder instance. + /// The tool filter instance to add. + /// The builder provided in . + /// is . + /// is . + /// + /// This method registers a specific tool filter instance that will be used + /// by the tool authorization system. + /// + public static IMcpServerBuilder WithToolFilter(this IMcpServerBuilder builder, IToolFilter filter) + { + Throw.IfNull(builder); + Throw.IfNull(filter); + + builder.Services.AddSingleton(filter); + return builder; + } + + /// + /// Adds multiple tool filters to the MCP server. + /// + /// The builder instance. + /// The tool filter instances to add. + /// The builder provided in . + /// is . + /// is . + /// + /// This method registers multiple tool filter instances that will be used + /// by the tool authorization system. + /// + public static IMcpServerBuilder WithToolFilters(this IMcpServerBuilder builder, IEnumerable filters) + { + Throw.IfNull(builder); + Throw.IfNull(filters); + + foreach (var filter in filters) + { + if (filter is not null) + { + builder.Services.AddSingleton(filter); + } + } + + return builder; + } + + /// + /// Adds tool filter aggregation to the MCP server, which automatically discovers + /// and applies all registered tool filters. + /// + /// The builder instance. + /// The builder provided in . + /// is . + /// + /// This method enables automatic discovery and coordination of all tool filters + /// registered in the dependency injection container. This is useful when you + /// have multiple filters and want them to be automatically applied without + /// manual coordination. + /// + public static IMcpServerBuilder WithToolFilterAggregation(this IMcpServerBuilder builder) + { + Throw.IfNull(builder); + + builder.Services.AddSingleton(); + return builder; + } + + /// + /// Configures role-based tool filtering for the MCP server. + /// + /// The builder instance. + /// A delegate to configure role-based tool access. + /// The builder provided in . + /// is . + /// is . + /// + /// This method provides a convenient way to set up role-based access control + /// for tools using a simple configuration approach. + /// + public static IMcpServerBuilder WithRoleBasedToolFiltering(this IMcpServerBuilder builder, Action configureRoles) + { + Throw.IfNull(builder); + Throw.IfNull(configureRoles); + + var roleBuilder = new RoleBasedToolFilterBuilder(); + configureRoles(roleBuilder); + + var filter = roleBuilder.Build(); + builder.Services.AddSingleton(filter); + + return builder; + } + + /// + /// Adds an allow-all tool filter that grants access to all tools without restrictions. + /// + /// The builder instance. + /// The priority for this filter. Default is (lowest priority). + /// The builder provided in . + /// is . + /// + /// This filter is useful for development environments or scenarios where + /// no access control is required. + /// + public static IMcpServerBuilder WithAllowAllToolFilter(this IMcpServerBuilder builder, int priority = int.MaxValue) + { + Throw.IfNull(builder); + + builder.Services.AddSingleton(new AllowAllToolFilter(priority)); + return builder; + } + + /// + /// Adds a deny-all tool filter that blocks access to all tools. + /// + /// The builder instance. + /// The priority for this filter. Default is 0 (highest priority). + /// The reason for denying access. Default is "All tools denied". + /// The builder provided in . + /// is . + /// + /// This filter is useful for lockdown scenarios or as a safety mechanism + /// to prevent any tool execution. + /// + public static IMcpServerBuilder WithDenyAllToolFilter(this IMcpServerBuilder builder, int priority = 0, string reason = "All tools denied") + { + Throw.IfNull(builder); + + builder.Services.AddSingleton(new DenyAllToolFilter(priority, reason)); + return builder; + } + + /// + /// Adds a pattern-based tool filter that allows access only to tools matching specified patterns. + /// + /// The builder instance. + /// Regular expression patterns that allow tool access. + /// The priority for this filter. Default is 100. + /// The builder provided in . + /// is . + /// is . + /// + /// This creates an allow-list filter where only tools matching the specified + /// regular expression patterns are allowed access. + /// + public static IMcpServerBuilder WithToolAllowListFilter(this IMcpServerBuilder builder, IEnumerable allowPatterns, int priority = 100) + { + Throw.IfNull(builder); + Throw.IfNull(allowPatterns); + + var filter = ToolNamePatternFilter.CreateAllowList(allowPatterns, priority); + builder.Services.AddSingleton(filter); + return builder; + } + + /// + /// Adds a pattern-based tool filter that denies access to tools matching specified patterns. + /// + /// The builder instance. + /// Regular expression patterns that deny tool access. + /// The priority for this filter. Default is 100. + /// The builder provided in . + /// is . + /// is . + /// + /// This creates a deny-list filter where tools matching the specified + /// regular expression patterns are denied access, but all others are allowed. + /// + public static IMcpServerBuilder WithToolDenyListFilter(this IMcpServerBuilder builder, IEnumerable denyPatterns, int priority = 100) + { + Throw.IfNull(builder); + Throw.IfNull(denyPatterns); + + var filter = ToolNamePatternFilter.CreateDenyList(denyPatterns, priority); + builder.Services.AddSingleton(filter); + return builder; + } + + /// + /// Adds a pattern-based tool filter with custom configuration. + /// + /// The builder instance. + /// A delegate to configure the pattern filter. + /// The priority for this filter. Default is 100. + /// Whether to allow access by default when no patterns match. Default is false. + /// The builder provided in . + /// is . + /// is . + /// + /// This method provides maximum flexibility for configuring pattern-based + /// tool filtering with custom allow and deny patterns. + /// + public static IMcpServerBuilder WithToolPatternFilter(this IMcpServerBuilder builder, Action configureFilter, int priority = 100, bool defaultAllow = false) + { + Throw.IfNull(builder); + Throw.IfNull(configureFilter); + + var filter = new ToolNamePatternFilter(priority, defaultAllow); + configureFilter(filter); + builder.Services.AddSingleton(filter); + return builder; + } + #endregion + #region Helpers /// Creates an instance of the target object. private static object CreateTarget( diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizationChallengeTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizationChallengeTests.cs new file mode 100644 index 00000000..f6089b0e --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizationChallengeTests.cs @@ -0,0 +1,181 @@ +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc.Testing; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using ModelContextProtocol.AspNetCore; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Server.Authorization; +using System.Net; +using System.Net.Http.Headers; +using System.Text; +using System.Text.Json; +using Xunit; + +namespace ModelContextProtocol.AspNetCore.Tests; + +/// +/// Tests for authorization challenge functionality in MCP HTTP transport. +/// +public class AuthorizationChallengeTests : IClassFixture> +{ + private readonly WebApplicationFactory _factory; + + public AuthorizationChallengeTests(WebApplicationFactory factory) + { + _factory = factory; + } + + [Fact] + public async Task CallTool_WithAuthorizationDenied_ReturnsWwwAuthenticateHeader() + { + // Arrange + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + services.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(); + // Add a filter that always denies with OAuth2 challenge + authService.RegisterFilter(new SampleOAuthToolFilter("write:tools", "mcp-server")); + return authService; + }); + + services.AddScoped(sp => new TestTool()); + }); + }).CreateClient(); + + // Create a call tool request + var callToolRequest = new JsonRpcRequest + { + Id = RequestId.FromString("test-1"), + Method = RequestMethods.ToolsCall, + Params = new CallToolRequestParams + { + Name = "admin_delete_tool", + Arguments = new Dictionary() + } + }; + + var requestJson = JsonSerializer.Serialize(callToolRequest, McpJsonUtilities.JsonContext.Default.JsonRpcRequest); + var content = new StringContent(requestJson, Encoding.UTF8, "application/json"); + + // Set required Accept headers for MCP + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); + + // Act + var response = await client.PostAsync("/mcp", content); + + // Assert + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + Assert.True(response.Headers.Contains("WWW-Authenticate")); + + var wwwAuthHeader = response.Headers.GetValues("WWW-Authenticate").First(); + Assert.Contains("Bearer", wwwAuthHeader); + Assert.Contains("scope=\"write:tools\"", wwwAuthHeader); + Assert.Contains("error=\"insufficient_scope\"", wwwAuthHeader); + Assert.Contains("realm=\"mcp-server\"", wwwAuthHeader); + + var responseContent = await response.Content.ReadAsStringAsync(); + var errorResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcError); + + Assert.NotNull(errorResponse.Error); + Assert.Equal((int)McpErrorCode.InvalidParams, errorResponse.Error.Code); + Assert.Contains("admin_delete_tool", errorResponse.Error.Message); + Assert.Contains("Insufficient scope", errorResponse.Error.Message); + } + + [Fact] + public async Task CallTool_WithInvalidToken_ReturnsWwwAuthenticateHeader() + { + // Arrange + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + services.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(); + // Add a filter that denies with invalid token challenge + authService.RegisterFilter(new AlwaysDenyFilter()); + return authService; + }); + + services.AddScoped(sp => new TestTool()); + }); + }).CreateClient(); + + // Create a call tool request for a tool that requires authentication + var callToolRequest = new JsonRpcRequest + { + Id = RequestId.FromString("test-2"), + Method = RequestMethods.ToolsCall, + Params = new CallToolRequestParams + { + Name = "secure_tool", + Arguments = new Dictionary() + } + }; + + var requestJson = JsonSerializer.Serialize(callToolRequest, McpJsonUtilities.JsonContext.Default.JsonRpcRequest); + var content = new StringContent(requestJson, Encoding.UTF8, "application/json"); + + // Set required Accept headers for MCP + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); + + // Act + var response = await client.PostAsync("/mcp", content); + + // Assert + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + Assert.True(response.Headers.Contains("WWW-Authenticate")); + + var wwwAuthHeader = response.Headers.GetValues("WWW-Authenticate").First(); + Assert.Contains("Bearer", wwwAuthHeader); + Assert.Contains("error=\"invalid_token\"", wwwAuthHeader); + + var responseContent = await response.Content.ReadAsStringAsync(); + var errorResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcError); + + Assert.NotNull(errorResponse.Error); + Assert.Equal((int)McpErrorCode.InvalidParams, errorResponse.Error.Code); + Assert.Contains("secure_tool", errorResponse.Error.Message); + } + + /// + /// Test tool for authorization challenge testing. + /// + private class TestTool : McpServerTool + { + public override Tool ProtocolTool => new() + { + Name = "test_tool", + Description = "A test tool for authorization testing" + }; + + public override ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextResourceContents { Text = "Tool executed successfully" }] + }); + } + } + + /// + /// Tool filter that always denies access with an invalid token challenge. + /// + private class AlwaysDenyFilter : IToolFilter + { + public int Priority => 1; + + public Task AuthorizeAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + return Task.FromResult(AuthorizationResult.DenyInvalidToken("mcp-server")); + } + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/Authorization/AspNetCoreAuthorizationIntegrationTests.cs b/tests/ModelContextProtocol.Tests/Server/Authorization/AspNetCoreAuthorizationIntegrationTests.cs new file mode 100644 index 00000000..8395cf98 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/Authorization/AspNetCoreAuthorizationIntegrationTests.cs @@ -0,0 +1,469 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc.Testing; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.AspNetCore; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Server.Authorization; +using ModelContextProtocol.Tests.Utils; +using System.Net; +using System.Net.Http.Headers; +using System.Security.Claims; +using System.Text; +using System.Text.Json; +using Xunit; + +namespace ModelContextProtocol.Tests.Server.Authorization; + +/// +/// Integration tests for ASP.NET Core authorization scenarios. +/// +public class AspNetCoreAuthorizationIntegrationTests : LoggedTest, IClassFixture> +{ + private readonly WebApplicationFactory _factory; + + public AspNetCoreAuthorizationIntegrationTests(WebApplicationFactory factory, ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + _factory = factory; + } + + [Fact] + public async Task Authorization_WithDependencyInjectedFilters_WorksCorrectly() + { + // Arrange + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + // Register filters via DI + services.AddSingleton(new AllowAllToolFilter(priority: 100)); + services.AddSingleton(new ToolNamePatternFilter(new[] { "admin_*" }, allowMatching: false, priority: 1)); + services.AddSingleton(); + + services.AddScoped(sp => new TestTool("admin_delete")); + services.AddScoped(sp => new TestTool("user_profile")); + }); + }).CreateClient(); + + // Test ListTools + var listRequest = CreateJsonRpcRequest(RequestMethods.ToolsList, new ListToolsRequestParams()); + var listResponse = await SendRequest(client, listRequest); + + Assert.Equal(HttpStatusCode.OK, listResponse.StatusCode); + var listContent = await listResponse.Content.ReadAsStringAsync(); + var listResult = JsonSerializer.Deserialize(listContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse); + + Assert.NotNull(listResult.Result); + var toolsResult = JsonSerializer.Deserialize(listResult.Result.ToString()!, McpJsonUtilities.JsonContext.Default.ListToolsResult); + + // Should only have user_profile tool (admin_delete filtered out) + Assert.Single(toolsResult.Tools); + Assert.Equal("user_profile", toolsResult.Tools[0].Name); + + // Test CallTool for allowed tool + var callRequest = CreateJsonRpcRequest(RequestMethods.ToolsCall, new CallToolRequestParams + { + Name = "user_profile", + Arguments = new Dictionary() + }); + var callResponse = await SendRequest(client, callRequest); + + Assert.Equal(HttpStatusCode.OK, callResponse.StatusCode); + + // Test CallTool for denied tool + var deniedRequest = CreateJsonRpcRequest(RequestMethods.ToolsCall, new CallToolRequestParams + { + Name = "admin_delete", + Arguments = new Dictionary() + }); + var deniedResponse = await SendRequest(client, deniedRequest); + + Assert.Equal(HttpStatusCode.BadRequest, deniedResponse.StatusCode); // Tool not found/accessible + } + + [Fact] + public async Task Authorization_WithRoleBasedFiltering_RespectsUserClaims() + { + // Arrange + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + services.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + var filter = RoleBasedToolFilterBuilder.Create() + .RequireRole("admin") + .ForToolsMatching("admin_*") + .Build(); + authService.RegisterFilter(filter); + return authService; + }); + + services.AddScoped(sp => new TestTool("admin_panel")); + services.AddScoped(sp => new TestTool("user_dashboard")); + + // Add authentication for setting up user context + services.AddAuthentication("Test") + .AddScheme("Test", _ => { }); + }); + + builder.Configure(app => + { + app.UseAuthentication(); + app.UseAuthorization(); + + app.Use(async (context, next) => + { + // Set up user context based on request headers + if (context.Request.Headers.TryGetValue("X-User-Role", out var roleHeader)) + { + var identity = new ClaimsIdentity("Test"); + identity.AddClaim(new Claim(ClaimTypes.Role, roleHeader.ToString())); + context.User = new ClaimsPrincipal(identity); + } + await next(); + }); + + app.MapMcp("/mcp"); + }); + }).CreateClient(); + + // Test without admin role + var listRequest = CreateJsonRpcRequest(RequestMethods.ToolsList, new ListToolsRequestParams()); + var listResponse = await SendRequest(client, listRequest); + + Assert.Equal(HttpStatusCode.OK, listResponse.StatusCode); + var listContent = await listResponse.Content.ReadAsStringAsync(); + var listResult = JsonSerializer.Deserialize(listContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse); + var toolsResult = JsonSerializer.Deserialize(listResult.Result.ToString()!, McpJsonUtilities.JsonContext.Default.ListToolsResult); + + // Should only have user_dashboard tool (admin_panel filtered out) + Assert.Single(toolsResult.Tools); + Assert.Equal("user_dashboard", toolsResult.Tools[0].Name); + + // Test with admin role + client.DefaultRequestHeaders.Add("X-User-Role", "admin"); + var adminListResponse = await SendRequest(client, listRequest); + + Assert.Equal(HttpStatusCode.OK, adminListResponse.StatusCode); + var adminListContent = await adminListResponse.Content.ReadAsStringAsync(); + var adminListResult = JsonSerializer.Deserialize(adminListContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse); + var adminToolsResult = JsonSerializer.Deserialize(adminListResult.Result.ToString()!, McpJsonUtilities.JsonContext.Default.ListToolsResult); + + // Should have both tools with admin role + Assert.Equal(2, adminToolsResult.Tools.Count); + Assert.Contains(adminToolsResult.Tools, t => t.Name == "admin_panel"); + Assert.Contains(adminToolsResult.Tools, t => t.Name == "user_dashboard"); + } + + [Fact] + public async Task Authorization_WithSessionContext_PassesCorrectContext() + { + // Arrange + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + services.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + authService.RegisterFilter(new TestContextCapturingFilter()); + return authService; + }); + + services.AddScoped(sp => new TestTool("context_tool")); + }); + }).CreateClient(); + + // Add session/user context headers + client.DefaultRequestHeaders.Add("X-Session-Id", "test-session-123"); + client.DefaultRequestHeaders.Add("X-User-Id", "user-456"); + + var listRequest = CreateJsonRpcRequest(RequestMethods.ToolsList, new ListToolsRequestParams()); + var listResponse = await SendRequest(client, listRequest); + + Assert.Equal(HttpStatusCode.OK, listResponse.StatusCode); + + // The filter should have captured the context + // In a real implementation, you'd verify the context was properly set + } + + [Fact] + public async Task Authorization_WithAuthenticationMiddleware_IntegratesCorrectly() + { + // Arrange + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + services.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + authService.RegisterFilter(new TestBearerTokenFilter()); + return authService; + }); + + services.AddScoped(sp => new TestTool("secure_api")); + + services.AddAuthentication("Bearer") + .AddScheme("Bearer", _ => { }); + }); + + builder.Configure(app => + { + app.UseAuthentication(); + app.UseAuthorization(); + app.MapMcp("/mcp"); + }); + }).CreateClient(); + + // Test without token - should fail + var request = CreateJsonRpcRequest(RequestMethods.ToolsCall, new CallToolRequestParams + { + Name = "secure_api", + Arguments = new Dictionary() + }); + var response = await SendRequest(client, request); + + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + Assert.True(response.Headers.Contains("WWW-Authenticate")); + + // Test with valid token - should succeed + client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", "valid-token"); + var authedResponse = await SendRequest(client, request); + + Assert.Equal(HttpStatusCode.OK, authedResponse.StatusCode); + } + + [Fact] + public async Task Authorization_WithMultipleEndpoints_IsolatesCorrectly() + { + // Arrange + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + services.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + authService.RegisterFilter(new ToolNamePatternFilter(new[] { "public_*" }, allowMatching: true)); + return authService; + }); + + services.AddScoped(sp => new TestTool("public_info")); + services.AddScoped(sp => new TestTool("private_data")); + }); + + builder.Configure(app => + { + // Map multiple MCP endpoints + app.MapMcp("/mcp/public"); + app.MapMcp("/mcp/admin"); + }); + }).CreateClient(); + + // Test public endpoint + var publicRequest = CreateJsonRpcRequest(RequestMethods.ToolsList, new ListToolsRequestParams()); + var publicResponse = await SendRequest(client, publicRequest, "/mcp/public"); + + Assert.Equal(HttpStatusCode.OK, publicResponse.StatusCode); + var publicContent = await publicResponse.Content.ReadAsStringAsync(); + var publicResult = JsonSerializer.Deserialize(publicContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse); + var publicToolsResult = JsonSerializer.Deserialize(publicResult.Result.ToString()!, McpJsonUtilities.JsonContext.Default.ListToolsResult); + + // Should only have public tools + Assert.Single(publicToolsResult.Tools); + Assert.Equal("public_info", publicToolsResult.Tools[0].Name); + + // Test admin endpoint (should have same filtering) + var adminResponse = await SendRequest(client, publicRequest, "/mcp/admin"); + + Assert.Equal(HttpStatusCode.OK, adminResponse.StatusCode); + var adminContent = await adminResponse.Content.ReadAsStringAsync(); + var adminResult = JsonSerializer.Deserialize(adminContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse); + var adminToolsResult = JsonSerializer.Deserialize(adminResult.Result.ToString()!, McpJsonUtilities.JsonContext.Default.ListToolsResult); + + // Should have same filtering applied + Assert.Single(adminToolsResult.Tools); + Assert.Equal("public_info", adminToolsResult.Tools[0].Name); + } + + private static JsonRpcRequest CreateJsonRpcRequest(string method, object? parameters) + { + return new JsonRpcRequest + { + Id = RequestId.FromString(Guid.NewGuid().ToString()), + Method = method, + Params = parameters + }; + } + + private static async Task SendRequest(HttpClient client, JsonRpcRequest request, string endpoint = "/mcp") + { + var requestJson = JsonSerializer.Serialize(request, McpJsonUtilities.JsonContext.Default.JsonRpcRequest); + var content = new StringContent(requestJson, Encoding.UTF8, "application/json"); + + client.DefaultRequestHeaders.Accept.Clear(); + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); + + return await client.PostAsync(endpoint, content); + } + + /// + /// Test tool for authorization testing. + /// + private class TestTool : McpServerTool + { + private readonly string _toolName; + + public TestTool(string toolName) + { + _toolName = toolName; + } + + public override Tool ProtocolTool => new() + { + Name = _toolName, + Description = $"Test tool: {_toolName}" + }; + + public override ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextResourceContents { Text = $"Tool {_toolName} executed" }] + }); + } + } + + /// + /// Test filter that captures authorization context for verification. + /// + private class TestContextCapturingFilter : IToolFilter + { + public int Priority => 100; + + public static ToolAuthorizationContext? LastContext { get; private set; } + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + LastContext = context; + return Task.FromResult(true); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + LastContext = context; + return Task.FromResult(AuthorizationResult.Allow("Context captured")); + } + } + + /// + /// Test filter that demonstrates Bearer token authentication. + /// + private class TestBearerTokenFilter : IToolFilter + { + public int Priority => 100; + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var result = await CanExecuteToolAsync(tool.Name, context, cancellationToken); + return result.IsAuthorized; + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // Check if user is authenticated (would be set by authentication middleware) + if (context.Properties.TryGetValue("IsAuthenticated", out var authValue) && authValue is true) + { + return Task.FromResult(AuthorizationResult.Allow("Token valid")); + } + + return Task.FromResult(AuthorizationResult.DenyInvalidToken("mcp-server")); + } + } +} + +/// +/// Test startup class for ASP.NET Core integration tests. +/// +public class TestStartup +{ + public void ConfigureServices(IServiceCollection services) + { + services.AddMcp(); + services.AddLogging(); + } + + public void Configure(IApplicationBuilder app, IWebHostEnvironment env) + { + app.MapMcp("/mcp"); + } +} + +/// +/// Test authentication scheme options. +/// +public class TestAuthenticationSchemeOptions : Microsoft.AspNetCore.Authentication.AuthenticationSchemeOptions +{ +} + +/// +/// Test authentication handler. +/// +public class TestAuthenticationHandler : Microsoft.AspNetCore.Authentication.AuthenticationHandler +{ + public TestAuthenticationHandler(Microsoft.AspNetCore.Authentication.IOptionsMonitor options, + ILoggerFactory logger, System.Text.Encodings.Web.UrlEncoder encoder) + : base(options, logger, encoder) + { + } + + protected override Task HandleAuthenticateAsync() + { + var identity = new ClaimsIdentity(Scheme.Name); + var principal = new ClaimsPrincipal(identity); + var ticket = new Microsoft.AspNetCore.Authentication.AuthenticationTicket(principal, Scheme.Name); + return Task.FromResult(Microsoft.AspNetCore.Authentication.AuthenticateResult.Success(ticket)); + } +} + +/// +/// Test Bearer token authentication handler. +/// +public class TestBearerAuthenticationHandler : Microsoft.AspNetCore.Authentication.AuthenticationHandler +{ + public TestBearerAuthenticationHandler(Microsoft.AspNetCore.Authentication.IOptionsMonitor options, + ILoggerFactory logger, System.Text.Encodings.Web.UrlEncoder encoder) + : base(options, logger, encoder) + { + } + + protected override Task HandleAuthenticateAsync() + { + if (Request.Headers.TryGetValue("Authorization", out var authHeader)) + { + var headerValue = authHeader.ToString(); + if (headerValue.StartsWith("Bearer ") && headerValue.Length > 7) + { + var token = headerValue.Substring(7); + if (token == "valid-token") + { + var identity = new ClaimsIdentity("Bearer"); + var principal = new ClaimsPrincipal(identity); + var ticket = new Microsoft.AspNetCore.Authentication.AuthenticationTicket(principal, "Bearer"); + return Task.FromResult(Microsoft.AspNetCore.Authentication.AuthenticateResult.Success(ticket)); + } + } + } + + return Task.FromResult(Microsoft.AspNetCore.Authentication.AuthenticateResult.Fail("Invalid token")); + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/Authorization/AuthorizationChallengeTests.cs b/tests/ModelContextProtocol.Tests/Server/Authorization/AuthorizationChallengeTests.cs new file mode 100644 index 00000000..2579a9da --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/Authorization/AuthorizationChallengeTests.cs @@ -0,0 +1,389 @@ +using ModelContextProtocol.Server.Authorization; +using Xunit; + +namespace ModelContextProtocol.Tests.Server.Authorization; + +/// +/// Unit tests for AuthorizationChallenge class. +/// +public class AuthorizationChallengeTests +{ + [Fact] + public void Constructor_WithValidParameters_SetsProperties() + { + // Arrange + const string wwwAuthenticateValue = "Bearer realm=\"test\""; + const int httpStatusCode = 401; + + // Act + var challenge = new AuthorizationChallenge(wwwAuthenticateValue, httpStatusCode); + + // Assert + Assert.Equal(wwwAuthenticateValue, challenge.WwwAuthenticateValue); + Assert.Equal(httpStatusCode, challenge.HttpStatusCode); + } + + [Fact] + public void Constructor_WithDefaultStatusCode_SetsDefault401() + { + // Arrange + const string wwwAuthenticateValue = "Bearer realm=\"test\""; + + // Act + var challenge = new AuthorizationChallenge(wwwAuthenticateValue); + + // Assert + Assert.Equal(wwwAuthenticateValue, challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void Constructor_WithNullWwwAuthenticateValue_ThrowsArgumentNullException() + { + // Act & Assert + Assert.Throws(() => new AuthorizationChallenge(null!)); + } + + [Fact] + public void CreateBearerChallenge_WithNoParameters_ReturnsBasicBearerChallenge() + { + // Act + var challenge = AuthorizationChallenge.CreateBearerChallenge(); + + // Assert + Assert.Equal("Bearer", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateBearerChallenge_WithRealm_ReturnsBearerChallengeWithRealm() + { + // Arrange + const string realm = "test-realm"; + + // Act + var challenge = AuthorizationChallenge.CreateBearerChallenge(realm: realm); + + // Assert + Assert.Equal($"Bearer realm=\"{realm}\"", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateBearerChallenge_WithScope_ReturnsBearerChallengeWithScope() + { + // Arrange + const string scope = "read:data"; + + // Act + var challenge = AuthorizationChallenge.CreateBearerChallenge(scope: scope); + + // Assert + Assert.Equal($"Bearer scope=\"{scope}\"", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateBearerChallenge_WithError_ReturnsBearerChallengeWithError() + { + // Arrange + const string error = "invalid_token"; + + // Act + var challenge = AuthorizationChallenge.CreateBearerChallenge(error: error); + + // Assert + Assert.Equal($"Bearer error=\"{error}\"", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateBearerChallenge_WithErrorDescription_ReturnsBearerChallengeWithErrorDescription() + { + // Arrange + const string errorDescription = "The token has expired"; + + // Act + var challenge = AuthorizationChallenge.CreateBearerChallenge(errorDescription: errorDescription); + + // Assert + Assert.Equal($"Bearer error_description=\"{errorDescription}\"", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateBearerChallenge_WithAllParameters_ReturnsBearerChallengeWithAllParameters() + { + // Arrange + const string realm = "test-realm"; + const string scope = "read:data write:data"; + const string error = "insufficient_scope"; + const string errorDescription = "The request requires higher privileges"; + + // Act + var challenge = AuthorizationChallenge.CreateBearerChallenge(realm, scope, error, errorDescription); + + // Assert + var expected = $"Bearer realm=\"{realm}\", scope=\"{scope}\", error=\"{error}\", error_description=\"{errorDescription}\""; + Assert.Equal(expected, challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateBearerChallenge_WithEmptyStrings_IgnoresEmptyParameters() + { + // Act + var challenge = AuthorizationChallenge.CreateBearerChallenge("", "", "", ""); + + // Assert + Assert.Equal("Bearer", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateBearerChallenge_WithWhitespaceStrings_IgnoresWhitespaceParameters() + { + // Act + var challenge = AuthorizationChallenge.CreateBearerChallenge(" ", " ", " ", " "); + + // Assert + Assert.Equal("Bearer", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateBasicChallenge_WithRealm_ReturnsBasicChallengeWithRealm() + { + // Arrange + const string realm = "test-realm"; + + // Act + var challenge = AuthorizationChallenge.CreateBasicChallenge(realm); + + // Assert + Assert.Equal($"Basic realm=\"{realm}\"", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateBasicChallenge_WithoutRealm_ReturnsBasicChallenge() + { + // Act + var challenge = AuthorizationChallenge.CreateBasicChallenge(); + + // Assert + Assert.Equal("Basic", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateBasicChallenge_WithNullRealm_ReturnsBasicChallenge() + { + // Act + var challenge = AuthorizationChallenge.CreateBasicChallenge(null); + + // Assert + Assert.Equal("Basic", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateBasicChallenge_WithEmptyRealm_ReturnsBasicChallenge() + { + // Act + var challenge = AuthorizationChallenge.CreateBasicChallenge(""); + + // Assert + Assert.Equal("Basic", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateCustomChallenge_WithSchemeOnly_ReturnsCustomChallenge() + { + // Arrange + const string scheme = "CustomAuth"; + + // Act + var challenge = AuthorizationChallenge.CreateCustomChallenge(scheme); + + // Assert + Assert.Equal(scheme, challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateCustomChallenge_WithSchemeAndParameters_ReturnsCustomChallengeWithParameters() + { + // Arrange + const string scheme = "CustomAuth"; + var parameters = new[] { ("realm", "test"), ("token", "abc123") }; + + // Act + var challenge = AuthorizationChallenge.CreateCustomChallenge(scheme, parameters); + + // Assert + Assert.Equal($"{scheme} realm=\"test\", token=\"abc123\"", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateCustomChallenge_WithEmptyParameters_ReturnsSchemeOnly() + { + // Arrange + const string scheme = "CustomAuth"; + var parameters = Array.Empty<(string, string)>(); + + // Act + var challenge = AuthorizationChallenge.CreateCustomChallenge(scheme, parameters); + + // Assert + Assert.Equal(scheme, challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateCustomChallenge_WithNullParameters_ReturnsSchemeOnly() + { + // Arrange + const string scheme = "CustomAuth"; + + // Act + var challenge = AuthorizationChallenge.CreateCustomChallenge(scheme, null); + + // Assert + Assert.Equal(scheme, challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateInsufficientScopeChallenge_WithRequiredScope_ReturnsInsufficientScopeChallenge() + { + // Arrange + const string requiredScope = "write:admin"; + + // Act + var challenge = AuthorizationChallenge.CreateInsufficientScopeChallenge(requiredScope); + + // Assert + Assert.Contains("Bearer", challenge.WwwAuthenticateValue); + Assert.Contains($"scope=\"{requiredScope}\"", challenge.WwwAuthenticateValue); + Assert.Contains("error=\"insufficient_scope\"", challenge.WwwAuthenticateValue); + Assert.Contains($"Required scope: {requiredScope}", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateInsufficientScopeChallenge_WithRequiredScopeAndRealm_ReturnsInsufficientScopeChallengeWithRealm() + { + // Arrange + const string requiredScope = "write:admin"; + const string realm = "test-realm"; + + // Act + var challenge = AuthorizationChallenge.CreateInsufficientScopeChallenge(requiredScope, realm); + + // Assert + Assert.Contains("Bearer", challenge.WwwAuthenticateValue); + Assert.Contains($"realm=\"{realm}\"", challenge.WwwAuthenticateValue); + Assert.Contains($"scope=\"{requiredScope}\"", challenge.WwwAuthenticateValue); + Assert.Contains("error=\"insufficient_scope\"", challenge.WwwAuthenticateValue); + Assert.Contains($"Required scope: {requiredScope}", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateInvalidTokenChallenge_WithDefaultParameters_ReturnsInvalidTokenChallenge() + { + // Act + var challenge = AuthorizationChallenge.CreateInvalidTokenChallenge(); + + // Assert + Assert.Contains("Bearer", challenge.WwwAuthenticateValue); + Assert.Contains("error=\"invalid_token\"", challenge.WwwAuthenticateValue); + Assert.Contains("expired, revoked, malformed, or invalid", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateInvalidTokenChallenge_WithRealm_ReturnsInvalidTokenChallengeWithRealm() + { + // Arrange + const string realm = "test-realm"; + + // Act + var challenge = AuthorizationChallenge.CreateInvalidTokenChallenge(realm); + + // Assert + Assert.Contains("Bearer", challenge.WwwAuthenticateValue); + Assert.Contains($"realm=\"{realm}\"", challenge.WwwAuthenticateValue); + Assert.Contains("error=\"invalid_token\"", challenge.WwwAuthenticateValue); + Assert.Contains("expired, revoked, malformed, or invalid", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Fact] + public void CreateInvalidTokenChallenge_WithCustomErrorDescription_ReturnsInvalidTokenChallengeWithCustomDescription() + { + // Arrange + const string realm = "test-realm"; + const string errorDescription = "Token signature validation failed"; + + // Act + var challenge = AuthorizationChallenge.CreateInvalidTokenChallenge(realm, errorDescription); + + // Assert + Assert.Contains("Bearer", challenge.WwwAuthenticateValue); + Assert.Contains($"realm=\"{realm}\"", challenge.WwwAuthenticateValue); + Assert.Contains("error=\"invalid_token\"", challenge.WwwAuthenticateValue); + Assert.Contains($"error_description=\"{errorDescription}\"", challenge.WwwAuthenticateValue); + Assert.DoesNotContain("expired, revoked, malformed, or invalid", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Theory] + [InlineData("realm with spaces")] + [InlineData("realm\"with\"quotes")] + [InlineData("realm,with,commas")] + public void CreateBearerChallenge_WithSpecialCharactersInRealm_HandlesCorrectly(string realm) + { + // Act + var challenge = AuthorizationChallenge.CreateBearerChallenge(realm: realm); + + // Assert + Assert.Contains($"realm=\"{realm}\"", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Theory] + [InlineData("read:data write:data")] + [InlineData("admin:full")] + [InlineData("user:profile user:email")] + public void CreateBearerChallenge_WithVariousScopes_HandlesCorrectly(string scope) + { + // Act + var challenge = AuthorizationChallenge.CreateBearerChallenge(scope: scope); + + // Assert + Assert.Contains($"scope=\"{scope}\"", challenge.WwwAuthenticateValue); + Assert.Equal(401, challenge.HttpStatusCode); + } + + [Theory] + [InlineData(400)] + [InlineData(401)] + [InlineData(403)] + [InlineData(500)] + public void Constructor_WithVariousStatusCodes_SetsCorrectStatusCode(int statusCode) + { + // Arrange + const string wwwAuthenticateValue = "Bearer"; + + // Act + var challenge = new AuthorizationChallenge(wwwAuthenticateValue, statusCode); + + // Assert + Assert.Equal(statusCode, challenge.HttpStatusCode); + Assert.Equal(wwwAuthenticateValue, challenge.WwwAuthenticateValue); + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/Authorization/AuthorizationResultTests.cs b/tests/ModelContextProtocol.Tests/Server/Authorization/AuthorizationResultTests.cs new file mode 100644 index 00000000..e1d083c5 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/Authorization/AuthorizationResultTests.cs @@ -0,0 +1,349 @@ +using ModelContextProtocol.Server.Authorization; +using Xunit; + +namespace ModelContextProtocol.Tests.Server.Authorization; + +/// +/// Unit tests for AuthorizationResult and AuthorizationChallenge classes. +/// +public class AuthorizationResultTests +{ + [Fact] + public void Constructor_WithValidParameters_SetsProperties() + { + // Arrange + const bool isAuthorized = true; + const string reason = "Test reason"; + const string additionalData = "Test data"; + + // Act + var result = new AuthorizationResult(isAuthorized, reason, additionalData); + + // Assert + Assert.Equal(isAuthorized, result.IsAuthorized); + Assert.Equal(reason, result.Reason); + Assert.Equal(additionalData, result.AdditionalData); + } + + [Fact] + public void Constructor_WithMinimalParameters_SetsDefaults() + { + // Arrange & Act + var result = new AuthorizationResult(true); + + // Assert + Assert.True(result.IsAuthorized); + Assert.Null(result.Reason); + Assert.Null(result.AdditionalData); + } + + [Fact] + public void Allow_WithoutParameters_ReturnsAuthorizedResult() + { + // Act + var result = AuthorizationResult.Allow(); + + // Assert + Assert.True(result.IsAuthorized); + Assert.Null(result.Reason); + Assert.Null(result.AdditionalData); + } + + [Fact] + public void Allow_WithReason_ReturnsAuthorizedResultWithReason() + { + // Arrange + const string reason = "User has permission"; + + // Act + var result = AuthorizationResult.Allow(reason); + + // Assert + Assert.True(result.IsAuthorized); + Assert.Equal(reason, result.Reason); + Assert.Null(result.AdditionalData); + } + + [Fact] + public void Allow_WithReasonAndData_ReturnsAuthorizedResultWithReasonAndData() + { + // Arrange + const string reason = "User has permission"; + const string data = "Additional context"; + + // Act + var result = AuthorizationResult.Allow(reason, data); + + // Assert + Assert.True(result.IsAuthorized); + Assert.Equal(reason, result.Reason); + Assert.Equal(data, result.AdditionalData); + } + + [Fact] + public void Deny_WithReason_ReturnsDeniedResultWithReason() + { + // Arrange + const string reason = "Insufficient permissions"; + + // Act + var result = AuthorizationResult.Deny(reason); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Equal(reason, result.Reason); + Assert.Null(result.AdditionalData); + } + + [Fact] + public void Deny_WithReasonAndData_ReturnsDeniedResultWithReasonAndData() + { + // Arrange + const string reason = "Insufficient permissions"; + const string data = "Additional context"; + + // Act + var result = AuthorizationResult.Deny(reason, data); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Equal(reason, result.Reason); + Assert.Equal(data, result.AdditionalData); + } + + [Fact] + public void Deny_WithOnlyData_ReturnsDeniedResultWithDefaultReason() + { + // Arrange + const string data = "Additional context"; + + // Act + var result = AuthorizationResult.Deny(data); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Equal("Access denied", result.Reason); + Assert.Equal(data, result.AdditionalData); + } + + [Fact] + public void DenyWithChallenge_WithReasonAndChallenge_ReturnsDeniedResultWithChallenge() + { + // Arrange + const string reason = "Authentication required"; + var challenge = AuthorizationChallenge.CreateBearerChallenge("test-realm"); + + // Act + var result = AuthorizationResult.DenyWithChallenge(reason, challenge); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Equal(reason, result.Reason); + Assert.Equal(challenge, result.AdditionalData); + } + + [Fact] + public void DenyWithBearerChallenge_WithBasicParameters_ReturnsDeniedResultWithBearerChallenge() + { + // Arrange + const string reason = "Invalid token"; + const string realm = "test-realm"; + + // Act + var result = AuthorizationResult.DenyWithBearerChallenge(reason, realm); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Equal(reason, result.Reason); + Assert.IsType(result.AdditionalData); + + var challenge = (AuthorizationChallenge)result.AdditionalData; + Assert.Contains("Bearer", challenge.WwwAuthenticateValue); + Assert.Contains($"realm=\"{realm}\"", challenge.WwwAuthenticateValue); + } + + [Fact] + public void DenyWithBearerChallenge_WithAllParameters_ReturnsDeniedResultWithFullBearerChallenge() + { + // Arrange + const string reason = "Insufficient scope"; + const string realm = "test-realm"; + const string scope = "read:data"; + const string error = "insufficient_scope"; + const string errorDescription = "Token lacks required scope"; + + // Act + var result = AuthorizationResult.DenyWithBearerChallenge(reason, realm, scope, error, errorDescription); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Equal(reason, result.Reason); + Assert.IsType(result.AdditionalData); + + var challenge = (AuthorizationChallenge)result.AdditionalData; + Assert.Contains("Bearer", challenge.WwwAuthenticateValue); + Assert.Contains($"realm=\"{realm}\"", challenge.WwwAuthenticateValue); + Assert.Contains($"scope=\"{scope}\"", challenge.WwwAuthenticateValue); + Assert.Contains($"error=\"{error}\"", challenge.WwwAuthenticateValue); + Assert.Contains($"error_description=\"{errorDescription}\"", challenge.WwwAuthenticateValue); + } + + [Fact] + public void DenyWithBasicChallenge_WithRealm_ReturnsDeniedResultWithBasicChallenge() + { + // Arrange + const string reason = "Authentication required"; + const string realm = "test-realm"; + + // Act + var result = AuthorizationResult.DenyWithBasicChallenge(reason, realm); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Equal(reason, result.Reason); + Assert.IsType(result.AdditionalData); + + var challenge = (AuthorizationChallenge)result.AdditionalData; + Assert.Contains("Basic", challenge.WwwAuthenticateValue); + Assert.Contains($"realm=\"{realm}\"", challenge.WwwAuthenticateValue); + } + + [Fact] + public void DenyWithBasicChallenge_WithoutRealm_ReturnsDeniedResultWithBasicChallenge() + { + // Arrange + const string reason = "Authentication required"; + + // Act + var result = AuthorizationResult.DenyWithBasicChallenge(reason); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Equal(reason, result.Reason); + Assert.IsType(result.AdditionalData); + + var challenge = (AuthorizationChallenge)result.AdditionalData; + Assert.Equal("Basic", challenge.WwwAuthenticateValue); + } + + [Fact] + public void DenyInsufficientScope_WithRequiredScope_ReturnsDeniedResultWithInsufficientScopeChallenge() + { + // Arrange + const string requiredScope = "write:admin"; + const string realm = "test-realm"; + + // Act + var result = AuthorizationResult.DenyInsufficientScope(requiredScope, realm); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Contains("Insufficient scope", result.Reason); + Assert.Contains(requiredScope, result.Reason); + Assert.IsType(result.AdditionalData); + + var challenge = (AuthorizationChallenge)result.AdditionalData; + Assert.Contains("Bearer", challenge.WwwAuthenticateValue); + Assert.Contains($"scope=\"{requiredScope}\"", challenge.WwwAuthenticateValue); + Assert.Contains("error=\"insufficient_scope\"", challenge.WwwAuthenticateValue); + Assert.Contains($"realm=\"{realm}\"", challenge.WwwAuthenticateValue); + } + + [Fact] + public void DenyInsufficientScope_WithoutRealm_ReturnsDeniedResultWithInsufficientScopeChallenge() + { + // Arrange + const string requiredScope = "write:admin"; + + // Act + var result = AuthorizationResult.DenyInsufficientScope(requiredScope); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Contains("Insufficient scope", result.Reason); + Assert.Contains(requiredScope, result.Reason); + Assert.IsType(result.AdditionalData); + + var challenge = (AuthorizationChallenge)result.AdditionalData; + Assert.Contains("Bearer", challenge.WwwAuthenticateValue); + Assert.Contains($"scope=\"{requiredScope}\"", challenge.WwwAuthenticateValue); + Assert.Contains("error=\"insufficient_scope\"", challenge.WwwAuthenticateValue); + Assert.DoesNotContain("realm=", challenge.WwwAuthenticateValue); + } + + [Fact] + public void DenyInvalidToken_WithRealm_ReturnsDeniedResultWithInvalidTokenChallenge() + { + // Arrange + const string realm = "test-realm"; + + // Act + var result = AuthorizationResult.DenyInvalidToken(realm); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Equal("Invalid or expired token", result.Reason); + Assert.IsType(result.AdditionalData); + + var challenge = (AuthorizationChallenge)result.AdditionalData; + Assert.Contains("Bearer", challenge.WwwAuthenticateValue); + Assert.Contains("error=\"invalid_token\"", challenge.WwwAuthenticateValue); + Assert.Contains($"realm=\"{realm}\"", challenge.WwwAuthenticateValue); + } + + [Fact] + public void DenyInvalidToken_WithCustomErrorDescription_ReturnsDeniedResultWithCustomDescription() + { + // Arrange + const string realm = "test-realm"; + const string errorDescription = "Token has expired"; + + // Act + var result = AuthorizationResult.DenyInvalidToken(realm, errorDescription); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Equal("Invalid or expired token", result.Reason); + Assert.IsType(result.AdditionalData); + + var challenge = (AuthorizationChallenge)result.AdditionalData; + Assert.Contains("Bearer", challenge.WwwAuthenticateValue); + Assert.Contains("error=\"invalid_token\"", challenge.WwwAuthenticateValue); + Assert.Contains($"error_description=\"{errorDescription}\"", challenge.WwwAuthenticateValue); + Assert.Contains($"realm=\"{realm}\"", challenge.WwwAuthenticateValue); + } + + [Fact] + public void DenyInvalidToken_WithoutParameters_ReturnsDeniedResultWithDefaultInvalidTokenChallenge() + { + // Act + var result = AuthorizationResult.DenyInvalidToken(); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Equal("Invalid or expired token", result.Reason); + Assert.IsType(result.AdditionalData); + + var challenge = (AuthorizationChallenge)result.AdditionalData; + Assert.Contains("Bearer", challenge.WwwAuthenticateValue); + Assert.Contains("error=\"invalid_token\"", challenge.WwwAuthenticateValue); + Assert.DoesNotContain("realm=", challenge.WwwAuthenticateValue); + } + + [Theory] + [InlineData(true, "Success", "Authorized: Success")] + [InlineData(true, null, "Authorized")] + [InlineData(false, "Failed", "Denied: Failed")] + [InlineData(false, null, "Denied")] + public void ToString_WithVariousStates_ReturnsExpectedString(bool isAuthorized, string? reason, string expected) + { + // Arrange + var result = new AuthorizationResult(isAuthorized, reason); + + // Act + var toString = result.ToString(); + + // Assert + Assert.Equal(expected, toString); + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/Authorization/HttpAuthorizationChallengeIntegrationTests.cs b/tests/ModelContextProtocol.Tests/Server/Authorization/HttpAuthorizationChallengeIntegrationTests.cs new file mode 100644 index 00000000..294570d3 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/Authorization/HttpAuthorizationChallengeIntegrationTests.cs @@ -0,0 +1,503 @@ +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Mvc.Testing; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.AspNetCore; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Server.Authorization; +using System.Net; +using System.Net.Http.Headers; +using System.Text; +using System.Text.Json; +using Xunit; + +namespace ModelContextProtocol.Tests.Server.Authorization; + +/// +/// Integration tests for CallTool authorization with HTTP challenges. +/// +public class HttpAuthorizationChallengeIntegrationTests : IClassFixture> +{ + private readonly WebApplicationFactory _factory; + + public HttpAuthorizationChallengeIntegrationTests(WebApplicationFactory factory) + { + _factory = factory; + } + + [Fact] + public async Task CallTool_WithInsufficientScopeFilter_ReturnsWwwAuthenticateHeader() + { + // Arrange + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + services.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + authService.RegisterFilter(new TestOAuthToolFilter("write:admin", "mcp-server")); + return authService; + }); + + services.AddScoped(sp => new TestTool("admin_delete_tool")); + }); + }).CreateClient(); + + var callToolRequest = new JsonRpcRequest + { + Id = RequestId.FromString("test-1"), + Method = RequestMethods.ToolsCall, + Params = new CallToolRequestParams + { + Name = "admin_delete_tool", + Arguments = new Dictionary() + } + }; + + var requestJson = JsonSerializer.Serialize(callToolRequest, McpJsonUtilities.JsonContext.Default.JsonRpcRequest); + var content = new StringContent(requestJson, Encoding.UTF8, "application/json"); + + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); + + // Act + var response = await client.PostAsync("/mcp", content); + + // Assert + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + Assert.True(response.Headers.Contains("WWW-Authenticate")); + + var wwwAuthHeader = response.Headers.GetValues("WWW-Authenticate").First(); + Assert.Contains("Bearer", wwwAuthHeader); + Assert.Contains("scope=\"write:admin\"", wwwAuthHeader); + Assert.Contains("error=\"insufficient_scope\"", wwwAuthHeader); + Assert.Contains("realm=\"mcp-server\"", wwwAuthHeader); + + var responseContent = await response.Content.ReadAsStringAsync(); + var errorResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcError); + + Assert.NotNull(errorResponse.Error); + Assert.Equal((int)McpErrorCode.InvalidParams, errorResponse.Error.Code); + Assert.Contains("admin_delete_tool", errorResponse.Error.Message); + Assert.Contains("Insufficient scope", errorResponse.Error.Message); + } + + [Fact] + public async Task CallTool_WithInvalidTokenFilter_ReturnsWwwAuthenticateHeader() + { + // Arrange + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + services.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + authService.RegisterFilter(new TestInvalidTokenFilter("mcp-server")); + return authService; + }); + + services.AddScoped(sp => new TestTool("secure_tool")); + }); + }).CreateClient(); + + var callToolRequest = new JsonRpcRequest + { + Id = RequestId.FromString("test-2"), + Method = RequestMethods.ToolsCall, + Params = new CallToolRequestParams + { + Name = "secure_tool", + Arguments = new Dictionary() + } + }; + + var requestJson = JsonSerializer.Serialize(callToolRequest, McpJsonUtilities.JsonContext.Default.JsonRpcRequest); + var content = new StringContent(requestJson, Encoding.UTF8, "application/json"); + + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); + + // Act + var response = await client.PostAsync("/mcp", content); + + // Assert + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + Assert.True(response.Headers.Contains("WWW-Authenticate")); + + var wwwAuthHeader = response.Headers.GetValues("WWW-Authenticate").First(); + Assert.Contains("Bearer", wwwAuthHeader); + Assert.Contains("error=\"invalid_token\"", wwwAuthHeader); + Assert.Contains("realm=\"mcp-server\"", wwwAuthHeader); + + var responseContent = await response.Content.ReadAsStringAsync(); + var errorResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcError); + + Assert.NotNull(errorResponse.Error); + Assert.Equal((int)McpErrorCode.InvalidParams, errorResponse.Error.Code); + Assert.Contains("secure_tool", errorResponse.Error.Message); + Assert.Contains("Invalid or expired token", errorResponse.Error.Message); + } + + [Fact] + public async Task CallTool_WithBasicAuthChallenge_ReturnsBasicWwwAuthenticateHeader() + { + // Arrange + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + services.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + authService.RegisterFilter(new TestBasicAuthFilter("secure-area")); + return authService; + }); + + services.AddScoped(sp => new TestTool("protected_tool")); + }); + }).CreateClient(); + + var callToolRequest = new JsonRpcRequest + { + Id = RequestId.FromString("test-3"), + Method = RequestMethods.ToolsCall, + Params = new CallToolRequestParams + { + Name = "protected_tool", + Arguments = new Dictionary() + } + }; + + var requestJson = JsonSerializer.Serialize(callToolRequest, McpJsonUtilities.JsonContext.Default.JsonRpcRequest); + var content = new StringContent(requestJson, Encoding.UTF8, "application/json"); + + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); + + // Act + var response = await client.PostAsync("/mcp", content); + + // Assert + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + Assert.True(response.Headers.Contains("WWW-Authenticate")); + + var wwwAuthHeader = response.Headers.GetValues("WWW-Authenticate").First(); + Assert.Contains("Basic", wwwAuthHeader); + Assert.Contains("realm=\"secure-area\"", wwwAuthHeader); + + var responseContent = await response.Content.ReadAsStringAsync(); + var errorResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcError); + + Assert.NotNull(errorResponse.Error); + Assert.Equal((int)McpErrorCode.InvalidParams, errorResponse.Error.Code); + Assert.Contains("protected_tool", errorResponse.Error.Message); + } + + [Fact] + public async Task CallTool_WithCustomAuthChallenge_ReturnsCustomWwwAuthenticateHeader() + { + // Arrange + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + services.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + authService.RegisterFilter(new TestCustomAuthFilter()); + return authService; + }); + + services.AddScoped(sp => new TestTool("api_tool")); + }); + }).CreateClient(); + + var callToolRequest = new JsonRpcRequest + { + Id = RequestId.FromString("test-4"), + Method = RequestMethods.ToolsCall, + Params = new CallToolRequestParams + { + Name = "api_tool", + Arguments = new Dictionary() + } + }; + + var requestJson = JsonSerializer.Serialize(callToolRequest, McpJsonUtilities.JsonContext.Default.JsonRpcRequest); + var content = new StringContent(requestJson, Encoding.UTF8, "application/json"); + + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); + + // Act + var response = await client.PostAsync("/mcp", content); + + // Assert + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + Assert.True(response.Headers.Contains("WWW-Authenticate")); + + var wwwAuthHeader = response.Headers.GetValues("WWW-Authenticate").First(); + Assert.Contains("ApiKey", wwwAuthHeader); + Assert.Contains("realm=\"api\"", wwwAuthHeader); + Assert.Contains("scope=\"full\"", wwwAuthHeader); + } + + [Fact] + public async Task CallTool_WithAllowedTool_ReturnsSuccess() + { + // Arrange + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + services.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + authService.RegisterFilter(new AllowAllToolFilter()); + return authService; + }); + + services.AddScoped(sp => new TestTool("allowed_tool")); + }); + }).CreateClient(); + + var callToolRequest = new JsonRpcRequest + { + Id = RequestId.FromString("test-5"), + Method = RequestMethods.ToolsCall, + Params = new CallToolRequestParams + { + Name = "allowed_tool", + Arguments = new Dictionary() + } + }; + + var requestJson = JsonSerializer.Serialize(callToolRequest, McpJsonUtilities.JsonContext.Default.JsonRpcRequest); + var content = new StringContent(requestJson, Encoding.UTF8, "application/json"); + + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); + + // Act + var response = await client.PostAsync("/mcp", content); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.False(response.Headers.Contains("WWW-Authenticate")); + + var responseContent = await response.Content.ReadAsStringAsync(); + var successResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse); + + Assert.NotNull(successResponse.Result); + Assert.Null(successResponse.Error); + } + + [Fact] + public async Task CallTool_WithMultipleFilters_ReturnsFirstDenyChallenge() + { + // Arrange + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + services.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + // First filter (higher priority) denies with OAuth challenge + authService.RegisterFilter(new TestOAuthToolFilter("write:admin", "mcp-server", priority: 1)); + // Second filter (lower priority) would deny with Basic auth, but shouldn't be reached + authService.RegisterFilter(new TestBasicAuthFilter("secure-area", priority: 2)); + return authService; + }); + + services.AddScoped(sp => new TestTool("admin_tool")); + }); + }).CreateClient(); + + var callToolRequest = new JsonRpcRequest + { + Id = RequestId.FromString("test-6"), + Method = RequestMethods.ToolsCall, + Params = new CallToolRequestParams + { + Name = "admin_tool", + Arguments = new Dictionary() + } + }; + + var requestJson = JsonSerializer.Serialize(callToolRequest, McpJsonUtilities.JsonContext.Default.JsonRpcRequest); + var content = new StringContent(requestJson, Encoding.UTF8, "application/json"); + + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); + + // Act + var response = await client.PostAsync("/mcp", content); + + // Assert + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + Assert.True(response.Headers.Contains("WWW-Authenticate")); + + var wwwAuthHeader = response.Headers.GetValues("WWW-Authenticate").First(); + // Should get OAuth challenge from first filter, not Basic from second filter + Assert.Contains("Bearer", wwwAuthHeader); + Assert.Contains("scope=\"write:admin\"", wwwAuthHeader); + Assert.DoesNotContain("Basic", wwwAuthHeader); + } + + /// + /// Test tool for authorization challenge testing. + /// + private class TestTool : McpServerTool + { + private readonly string _toolName; + + public TestTool(string toolName = "test_tool") + { + _toolName = toolName; + } + + public override Tool ProtocolTool => new() + { + Name = _toolName, + Description = $"A test tool for authorization testing: {_toolName}" + }; + + public override ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextResourceContents { Text = $"Tool {_toolName} executed successfully" }] + }); + } + } + + /// + /// Test OAuth tool filter that demonstrates insufficient scope challenges. + /// + private class TestOAuthToolFilter : IToolFilter + { + private readonly string _requiredScope; + private readonly string? _realm; + + public TestOAuthToolFilter(string requiredScope, string? realm = null, int priority = 100) + { + _requiredScope = requiredScope ?? throw new ArgumentNullException(nameof(requiredScope)); + _realm = realm; + Priority = priority; + } + + public int Priority { get; } + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var result = await CanExecuteToolAsync(tool.Name, context, cancellationToken); + return result.IsAuthorized; + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // For testing, always deny with insufficient scope challenge + return Task.FromResult(AuthorizationResult.DenyInsufficientScope(_requiredScope, _realm)); + } + } + + /// + /// Test filter that demonstrates invalid token challenges. + /// + private class TestInvalidTokenFilter : IToolFilter + { + private readonly string? _realm; + + public TestInvalidTokenFilter(string? realm = null) + { + _realm = realm; + } + + public int Priority => 100; + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var result = await CanExecuteToolAsync(tool.Name, context, cancellationToken); + return result.IsAuthorized; + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // For testing, always deny with invalid token challenge + return Task.FromResult(AuthorizationResult.DenyInvalidToken(_realm)); + } + } + + /// + /// Test filter that demonstrates Basic auth challenges. + /// + private class TestBasicAuthFilter : IToolFilter + { + private readonly string? _realm; + + public TestBasicAuthFilter(string? realm = null, int priority = 100) + { + _realm = realm; + Priority = priority; + } + + public int Priority { get; } + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var result = await CanExecuteToolAsync(tool.Name, context, cancellationToken); + return result.IsAuthorized; + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // For testing, always deny with Basic auth challenge + return Task.FromResult(AuthorizationResult.DenyWithBasicChallenge("Authentication required", _realm)); + } + } + + /// + /// Test filter that demonstrates custom auth challenges. + /// + private class TestCustomAuthFilter : IToolFilter + { + public int Priority => 100; + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var result = await CanExecuteToolAsync(tool.Name, context, cancellationToken); + return result.IsAuthorized; + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // For testing, always deny with custom auth challenge + var challenge = AuthorizationChallenge.CreateCustomChallenge("ApiKey", + ("realm", "api"), + ("scope", "full")); + return Task.FromResult(AuthorizationResult.DenyWithChallenge("API key required", challenge)); + } + } +} + +/// +/// Test program for the web application factory. +/// +public class TestProgram +{ + public static void Main(string[] args) + { + var builder = WebApplication.CreateBuilder(args); + builder.Services.AddMcp(); + + var app = builder.Build(); + app.MapMcp("/mcp"); + + app.Run(); + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/Authorization/MockToolImplementations.cs b/tests/ModelContextProtocol.Tests/Server/Authorization/MockToolImplementations.cs new file mode 100644 index 00000000..3a280247 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/Authorization/MockToolImplementations.cs @@ -0,0 +1,920 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Server.Authorization; +using System.Text.Json; +using System.Text.Json.Nodes; +using Xunit; + +namespace ModelContextProtocol.Tests.Server.Authorization; + +/// +/// Mock tool implementations for comprehensive authorization testing. +/// +public static class MockToolImplementations +{ + /// + /// Creates a collection of mock tools for testing various authorization scenarios. + /// + public static IEnumerable CreateTestToolSuite() + { + yield return new ReadOnlyTool(); + yield return new WriteDataTool(); + yield return new AdminDeleteTool(); + yield return new AdminCreateTool(); + yield return new PublicInfoTool(); + yield return new PrivateDataTool(); + yield return new BetaFeatureTool(); + yield return new HighRateOperationTool(); + yield return new SecureApiTool(); + yield return new UserProfileTool(); + yield return new SystemStatusTool(); + yield return new ComplexAnalysisTool(); + yield return new TimeSensitiveTool(); + yield return new QuotaConsumingTool(); + yield return new AuditableTool(); + } + + /// + /// Creates a collection of test tools with specific categories for testing. + /// + public static Dictionary> CreateCategorizedTools() + { + return new Dictionary> + { + ["read_only"] = new McpServerTool[] { new ReadOnlyTool(), new PublicInfoTool(), new SystemStatusTool() }, + ["write_operations"] = new McpServerTool[] { new WriteDataTool(), new UserProfileTool() }, + ["admin_tools"] = new McpServerTool[] { new AdminDeleteTool(), new AdminCreateTool() }, + ["beta_features"] = new McpServerTool[] { new BetaFeatureTool() }, + ["high_privilege"] = new McpServerTool[] { new PrivateDataTool(), new SecureApiTool() }, + ["resource_intensive"] = new McpServerTool[] { new ComplexAnalysisTool(), new HighRateOperationTool() }, + ["time_sensitive"] = new McpServerTool[] { new TimeSensitiveTool() }, + ["quota_limited"] = new McpServerTool[] { new QuotaConsumingTool() }, + ["auditable"] = new McpServerTool[] { new AuditableTool() } + }; + } +} + +/// +/// Mock tool for read-only operations - typically allowed by most filters. +/// +public class ReadOnlyTool : McpServerTool +{ + public override Tool ProtocolTool => new() + { + Name = "read_only_data", + Description = "Reads data without making any changes", + InputSchema = new JsonObject + { + ["type"] = "object", + ["properties"] = new JsonObject + { + ["resource_id"] = new JsonObject + { + ["type"] = "string", + ["description"] = "The ID of the resource to read" + } + }, + ["required"] = new JsonArray { "resource_id" } + } + }; + + public override ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + var resourceId = request.Params.Arguments?.GetValueOrDefault("resource_id")?.ToString() ?? "unknown"; + + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextResourceContents + { + Text = $"Successfully read resource: {resourceId}. Data: {{\"id\":\"{resourceId}\",\"status\":\"active\",\"created\":\"2024-01-01T00:00:00Z\"}}" + }] + }); + } +} + +/// +/// Mock tool for write operations - may require elevated permissions. +/// +public class WriteDataTool : McpServerTool +{ + public override Tool ProtocolTool => new() + { + Name = "write_data", + Description = "Writes data to the system", + InputSchema = new JsonObject + { + ["type"] = "object", + ["properties"] = new JsonObject + { + ["resource_id"] = new JsonObject + { + ["type"] = "string", + ["description"] = "The ID of the resource to write" + }, + ["data"] = new JsonObject + { + ["type"] = "object", + ["description"] = "The data to write" + } + }, + ["required"] = new JsonArray { "resource_id", "data" } + } + }; + + public override ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + var resourceId = request.Params.Arguments?.GetValueOrDefault("resource_id")?.ToString() ?? "unknown"; + var data = request.Params.Arguments?.GetValueOrDefault("data"); + + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextResourceContents + { + Text = $"Successfully wrote data to resource: {resourceId}. Data written: {JsonSerializer.Serialize(data)}" + }] + }); + } +} + +/// +/// Mock tool for admin delete operations - requires admin privileges. +/// +public class AdminDeleteTool : McpServerTool +{ + public override Tool ProtocolTool => new() + { + Name = "admin_delete_resource", + Description = "Permanently deletes a resource (admin only)", + InputSchema = new JsonObject + { + ["type"] = "object", + ["properties"] = new JsonObject + { + ["resource_id"] = new JsonObject + { + ["type"] = "string", + ["description"] = "The ID of the resource to delete" + }, + ["confirm"] = new JsonObject + { + ["type"] = "boolean", + ["description"] = "Confirmation flag to prevent accidental deletions" + } + }, + ["required"] = new JsonArray { "resource_id", "confirm" } + } + }; + + public override ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + var resourceId = request.Params.Arguments?.GetValueOrDefault("resource_id")?.ToString() ?? "unknown"; + var confirm = request.Params.Arguments?.GetValueOrDefault("confirm"); + + if (confirm is JsonElement element && element.GetBoolean()) + { + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextResourceContents + { + Text = $"Resource {resourceId} has been permanently deleted. This action cannot be undone." + }] + }); + } + + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextResourceContents + { + Text = "Delete operation cancelled - confirmation required." + }] + }); + } +} + +/// +/// Mock tool for admin create operations - requires admin privileges. +/// +public class AdminCreateTool : McpServerTool +{ + public override Tool ProtocolTool => new() + { + Name = "admin_create_resource", + Description = "Creates a new system resource (admin only)", + InputSchema = new JsonObject + { + ["type"] = "object", + ["properties"] = new JsonObject + { + ["resource_type"] = new JsonObject + { + ["type"] = "string", + ["enum"] = new JsonArray { "user", "organization", "project", "system" }, + ["description"] = "The type of resource to create" + }, + ["config"] = new JsonObject + { + ["type"] = "object", + ["description"] = "Configuration for the new resource" + } + }, + ["required"] = new JsonArray { "resource_type", "config" } + } + }; + + public override ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + var resourceType = request.Params.Arguments?.GetValueOrDefault("resource_type")?.ToString() ?? "unknown"; + var config = request.Params.Arguments?.GetValueOrDefault("config"); + var newId = Guid.NewGuid().ToString(); + + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextResourceContents + { + Text = $"Created new {resourceType} resource with ID: {newId}. Configuration: {JsonSerializer.Serialize(config)}" + }] + }); + } +} + +/// +/// Mock tool for public information - typically unrestricted. +/// +public class PublicInfoTool : McpServerTool +{ + public override Tool ProtocolTool => new() + { + Name = "public_info", + Description = "Retrieves publicly available information", + InputSchema = new JsonObject + { + ["type"] = "object", + ["properties"] = new JsonObject + { + ["info_type"] = new JsonObject + { + ["type"] = "string", + ["enum"] = new JsonArray { "status", "version", "features", "documentation" }, + ["description"] = "The type of public information to retrieve" + } + }, + ["required"] = new JsonArray { "info_type" } + } + }; + + public override ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + var infoType = request.Params.Arguments?.GetValueOrDefault("info_type")?.ToString() ?? "status"; + + var info = infoType switch + { + "status" => "System is operational", + "version" => "v1.2.3", + "features" => "Authentication, Authorization, Audit Logging", + "documentation" => "https://docs.example.com", + _ => "Unknown information type" + }; + + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextResourceContents { Text = $"Public {infoType}: {info}" }] + }); + } +} + +/// +/// Mock tool for private data access - requires authentication and authorization. +/// +public class PrivateDataTool : McpServerTool +{ + public override Tool ProtocolTool => new() + { + Name = "private_data_access", + Description = "Accesses private data (requires proper authorization)", + InputSchema = new JsonObject + { + ["type"] = "object", + ["properties"] = new JsonObject + { + ["user_id"] = new JsonObject + { + ["type"] = "string", + ["description"] = "The user ID to access private data for" + }, + ["data_type"] = new JsonObject + { + ["type"] = "string", + ["enum"] = new JsonArray { "profile", "settings", "history", "analytics" }, + ["description"] = "The type of private data to access" + } + }, + ["required"] = new JsonArray { "user_id", "data_type" } + } + }; + + public override ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + var userId = request.Params.Arguments?.GetValueOrDefault("user_id")?.ToString() ?? "unknown"; + var dataType = request.Params.Arguments?.GetValueOrDefault("data_type")?.ToString() ?? "profile"; + + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextResourceContents + { + Text = $"Private {dataType} data for user {userId}: {{\"sensitive\":\"data\",\"access_time\":\"{DateTime.UtcNow:O}\"}}" + }] + }); + } +} + +/// +/// Mock tool for beta features - may be restricted based on feature flags. +/// +public class BetaFeatureTool : McpServerTool +{ + public override Tool ProtocolTool => new() + { + Name = "beta_advanced_analytics", + Description = "Advanced analytics feature (beta)", + InputSchema = new JsonObject + { + ["type"] = "object", + ["properties"] = new JsonObject + { + ["dataset_id"] = new JsonObject + { + ["type"] = "string", + ["description"] = "The dataset to analyze" + }, + ["analysis_type"] = new JsonObject + { + ["type"] = "string", + ["enum"] = new JsonArray { "trend", "correlation", "prediction", "anomaly" }, + ["description"] = "Type of analysis to perform" + } + }, + ["required"] = new JsonArray { "dataset_id", "analysis_type" } + } + }; + + public override ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + var datasetId = request.Params.Arguments?.GetValueOrDefault("dataset_id")?.ToString() ?? "unknown"; + var analysisType = request.Params.Arguments?.GetValueOrDefault("analysis_type")?.ToString() ?? "trend"; + + // Simulate beta feature processing + await Task.Delay(100, cancellationToken); + + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextResourceContents + { + Text = $"Beta {analysisType} analysis completed for dataset {datasetId}. Results: {{\"confidence\":0.85,\"insights\":[\"pattern_detected\",\"seasonal_trend\"]}}" + }] + }); + } +} + +/// +/// Mock tool for high-rate operations - may be limited by rate limiting filters. +/// +public class HighRateOperationTool : McpServerTool +{ + public override Tool ProtocolTool => new() + { + Name = "high_rate_batch_process", + Description = "Processes data in high-frequency batches", + InputSchema = new JsonObject + { + ["type"] = "object", + ["properties"] = new JsonObject + { + ["batch_size"] = new JsonObject + { + ["type"] = "integer", + ["minimum"] = 1, + ["maximum"] = 10000, + ["description"] = "Number of items to process in the batch" + }, + ["processing_mode"] = new JsonObject + { + ["type"] = "string", + ["enum"] = new JsonArray { "fast", "thorough", "balanced" }, + ["description"] = "Processing mode" + } + }, + ["required"] = new JsonArray { "batch_size", "processing_mode" } + } + }; + + public override ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + var batchSize = request.Params.Arguments?.GetValueOrDefault("batch_size"); + var processingMode = request.Params.Arguments?.GetValueOrDefault("processing_mode")?.ToString() ?? "balanced"; + + var size = batchSize is JsonElement element ? element.GetInt32() : 100; + + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextResourceContents + { + Text = $"Batch processing completed: {size} items processed in {processingMode} mode. Rate: {size * 10} items/minute" + }] + }); + } +} + +/// +/// Mock tool for secure API operations - requires strong authentication. +/// +public class SecureApiTool : McpServerTool +{ + public override Tool ProtocolTool => new() + { + Name = "secure_api_operation", + Description = "Secure API operation requiring strong authentication", + InputSchema = new JsonObject + { + ["type"] = "object", + ["properties"] = new JsonObject + { + ["operation"] = new JsonObject + { + ["type"] = "string", + ["enum"] = new JsonArray { "encrypt", "decrypt", "sign", "verify" }, + ["description"] = "Cryptographic operation to perform" + }, + ["payload"] = new JsonObject + { + ["type"] = "string", + ["description"] = "Data to process" + } + }, + ["required"] = new JsonArray { "operation", "payload" } + } + }; + + public override ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + var operation = request.Params.Arguments?.GetValueOrDefault("operation")?.ToString() ?? "encrypt"; + var payload = request.Params.Arguments?.GetValueOrDefault("payload")?.ToString() ?? ""; + + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextResourceContents + { + Text = $"Secure {operation} operation completed. Processed {payload.Length} bytes. Result hash: {payload.GetHashCode():X8}" + }] + }); + } +} + +/// +/// Mock tool for user profile operations - context-dependent access. +/// +public class UserProfileTool : McpServerTool +{ + public override Tool ProtocolTool => new() + { + Name = "user_profile_update", + Description = "Updates user profile information", + InputSchema = new JsonObject + { + ["type"] = "object", + ["properties"] = new JsonObject + { + ["user_id"] = new JsonObject + { + ["type"] = "string", + ["description"] = "User ID to update" + }, + ["updates"] = new JsonObject + { + ["type"] = "object", + ["description"] = "Profile updates to apply" + } + }, + ["required"] = new JsonArray { "user_id", "updates" } + } + }; + + public override ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + var userId = request.Params.Arguments?.GetValueOrDefault("user_id")?.ToString() ?? "unknown"; + var updates = request.Params.Arguments?.GetValueOrDefault("updates"); + + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextResourceContents + { + Text = $"Profile updated for user {userId}. Changes: {JsonSerializer.Serialize(updates)}" + }] + }); + } +} + +/// +/// Mock tool for system status - usually publicly accessible. +/// +public class SystemStatusTool : McpServerTool +{ + public override Tool ProtocolTool => new() + { + Name = "system_status_check", + Description = "Checks system health and status", + InputSchema = new JsonObject + { + ["type"] = "object", + ["properties"] = new JsonObject + { + ["component"] = new JsonObject + { + ["type"] = "string", + ["enum"] = new JsonArray { "all", "database", "cache", "queue", "storage" }, + ["description"] = "System component to check" + } + } + } + }; + + public override ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + var component = request.Params.Arguments?.GetValueOrDefault("component")?.ToString() ?? "all"; + + var status = component switch + { + "database" => "Connected, 50ms latency", + "cache" => "Active, 95% hit rate", + "queue" => "Processing, 123 items pending", + "storage" => "Available, 78% capacity", + "all" => "All systems operational", + _ => "Component status unknown" + }; + + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextResourceContents + { + Text = $"Status for {component}: {status}" + }] + }); + } +} + +/// +/// Mock tool for complex analysis - may require premium access. +/// +public class ComplexAnalysisTool : McpServerTool +{ + public override Tool ProtocolTool => new() + { + Name = "complex_data_analysis", + Description = "Performs complex data analysis (resource intensive)", + InputSchema = new JsonObject + { + ["type"] = "object", + ["properties"] = new JsonObject + { + ["dataset"] = new JsonObject + { + ["type"] = "array", + ["items"] = new JsonObject { ["type"] = "object" }, + ["description"] = "Dataset to analyze" + }, + ["algorithms"] = new JsonObject + { + ["type"] = "array", + ["items"] = new JsonObject + { + ["type"] = "string", + ["enum"] = new JsonArray { "regression", "clustering", "classification", "neural_network" } + }, + ["description"] = "Analysis algorithms to apply" + } + }, + ["required"] = new JsonArray { "dataset", "algorithms" } + } + }; + + public override async ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + var dataset = request.Params.Arguments?.GetValueOrDefault("dataset"); + var algorithms = request.Params.Arguments?.GetValueOrDefault("algorithms"); + + // Simulate complex processing + await Task.Delay(500, cancellationToken); + + return new CallToolResult + { + Content = [new TextResourceContents + { + Text = $"Complex analysis completed using algorithms: {JsonSerializer.Serialize(algorithms)}. Dataset size: {(dataset as JsonArray)?.Count ?? 0} records. Processing time: 500ms" + }] + }; + } +} + +/// +/// Mock tool for time-sensitive operations - may be restricted by time-based filters. +/// +public class TimeSensitiveTool : McpServerTool +{ + public override Tool ProtocolTool => new() + { + Name = "time_sensitive_trading", + Description = "Time-sensitive trading operation (business hours only)", + InputSchema = new JsonObject + { + ["type"] = "object", + ["properties"] = new JsonObject + { + ["symbol"] = new JsonObject + { + ["type"] = "string", + ["description"] = "Trading symbol" + }, + ["action"] = new JsonObject + { + ["type"] = "string", + ["enum"] = new JsonArray { "buy", "sell", "quote" }, + ["description"] = "Trading action" + }, + ["quantity"] = new JsonObject + { + ["type"] = "number", + ["minimum"] = 0, + ["description"] = "Quantity to trade" + } + }, + ["required"] = new JsonArray { "symbol", "action" } + } + }; + + public override ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + var symbol = request.Params.Arguments?.GetValueOrDefault("symbol")?.ToString() ?? "UNKNOWN"; + var action = request.Params.Arguments?.GetValueOrDefault("action")?.ToString() ?? "quote"; + var quantity = request.Params.Arguments?.GetValueOrDefault("quantity"); + + var currentTime = DateTime.Now; + + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextResourceContents + { + Text = $"Trading {action} for {symbol} executed at {currentTime:HH:mm:ss}. Quantity: {quantity ?? "N/A"}. Market status: {(currentTime.Hour >= 9 && currentTime.Hour <= 17 ? "Open" : "Closed")}" + }] + }); + } +} + +/// +/// Mock tool for quota-consuming operations - limited by usage quotas. +/// +public class QuotaConsumingTool : McpServerTool +{ + public override Tool ProtocolTool => new() + { + Name = "quota_consuming_operation", + Description = "Operation that consumes user quota", + InputSchema = new JsonObject + { + ["type"] = "object", + ["properties"] = new JsonObject + { + ["operation_size"] = new JsonObject + { + ["type"] = "string", + ["enum"] = new JsonArray { "small", "medium", "large", "xlarge" }, + ["description"] = "Size of operation (affects quota consumption)" + }, + ["priority"] = new JsonObject + { + ["type"] = "string", + ["enum"] = new JsonArray { "low", "normal", "high", "urgent" }, + ["description"] = "Operation priority" + } + }, + ["required"] = new JsonArray { "operation_size" } + } + }; + + public override async ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + var operationSize = request.Params.Arguments?.GetValueOrDefault("operation_size")?.ToString() ?? "small"; + var priority = request.Params.Arguments?.GetValueOrDefault("priority")?.ToString() ?? "normal"; + + var quotaCost = operationSize switch + { + "small" => 1, + "medium" => 5, + "large" => 20, + "xlarge" => 100, + _ => 1 + }; + + // Simulate processing time based on size + var processingTime = quotaCost * 10; + await Task.Delay(processingTime, cancellationToken); + + return new CallToolResult + { + Content = [new TextResourceContents + { + Text = $"Operation completed. Size: {operationSize}, Priority: {priority}, Quota consumed: {quotaCost} units, Processing time: {processingTime}ms" + }] + }; + } +} + +/// +/// Mock tool for auditable operations - logs all access attempts. +/// +public class AuditableTool : McpServerTool +{ + public override Tool ProtocolTool => new() + { + Name = "auditable_financial_operation", + Description = "Financial operation that requires full audit trail", + InputSchema = new JsonObject + { + ["type"] = "object", + ["properties"] = new JsonObject + { + ["transaction_type"] = new JsonObject + { + ["type"] = "string", + ["enum"] = new JsonArray { "transfer", "deposit", "withdrawal", "balance_check" }, + ["description"] = "Type of financial transaction" + }, + ["amount"] = new JsonObject + { + ["type"] = "number", + ["minimum"] = 0, + ["description"] = "Transaction amount" + }, + ["currency"] = new JsonObject + { + ["type"] = "string", + ["pattern"] = "^[A-Z]{3}$", + ["description"] = "Currency code (ISO 4217)" + }, + ["reference"] = new JsonObject + { + ["type"] = "string", + ["description"] = "Transaction reference" + } + }, + ["required"] = new JsonArray { "transaction_type" } + } + }; + + public override ValueTask InvokeAsync(RequestContext request, CancellationToken cancellationToken = default) + { + var transactionType = request.Params.Arguments?.GetValueOrDefault("transaction_type")?.ToString() ?? "balance_check"; + var amount = request.Params.Arguments?.GetValueOrDefault("amount"); + var currency = request.Params.Arguments?.GetValueOrDefault("currency")?.ToString() ?? "USD"; + var reference = request.Params.Arguments?.GetValueOrDefault("reference")?.ToString() ?? Guid.NewGuid().ToString(); + + var auditTrail = new + { + TransactionId = Guid.NewGuid().ToString(), + Type = transactionType, + Amount = amount, + Currency = currency, + Reference = reference, + Timestamp = DateTime.UtcNow, + Status = "Completed", + AuditLevel = "Full" + }; + + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextResourceContents + { + Text = $"Financial operation completed with full audit trail: {JsonSerializer.Serialize(auditTrail)}" + }] + }); + } +} + +/// +/// Test utilities for working with mock tools. +/// +public static class MockToolTestUtilities +{ + /// + /// Creates a test authorization context with specified properties. + /// + public static ToolAuthorizationContext CreateTestContext( + string? sessionId = null, + string? userId = null, + IEnumerable? roles = null, + IEnumerable? permissions = null, + Dictionary? properties = null) + { + var context = ToolAuthorizationContext.ForSession(sessionId ?? "test-session-" + Guid.NewGuid().ToString("N")[..8]); + + if (userId != null) + { + context.UserId = userId; + } + + if (roles != null) + { + foreach (var role in roles) + { + context.UserRoles.Add(role); + } + } + + if (permissions != null) + { + foreach (var permission in permissions) + { + context.UserPermissions.Add(permission); + } + } + + if (properties != null) + { + foreach (var kvp in properties) + { + context.Properties[kvp.Key] = kvp.Value; + } + } + + return context; + } + + /// + /// Validates that a tool result contains expected content. + /// + public static void AssertToolResultContains(CallToolResult result, string expectedContent) + { + Assert.NotNull(result); + Assert.NotNull(result.Content); + Assert.NotEmpty(result.Content); + + var textContent = result.Content.OfType().FirstOrDefault(); + Assert.NotNull(textContent); + Assert.Contains(expectedContent, textContent.Text); + } + + /// + /// Creates a basic tool filter for testing that allows/denies based on tool name patterns. + /// + public static IToolFilter CreatePatternFilter(string[] patterns, bool allowMatching, int priority = 100) + { + return new ToolNamePatternFilter(patterns, allowMatching, priority); + } + + /// + /// Creates a role-based filter for testing. + /// + public static IToolFilter CreateRoleFilter(string requiredRole, string[] toolPatterns, int priority = 100) + { + return new TestRoleFilter(requiredRole, toolPatterns, priority); + } + + private class TestRoleFilter : IToolFilter + { + private readonly string _requiredRole; + private readonly string[] _toolPatterns; + + public TestRoleFilter(string requiredRole, string[] toolPatterns, int priority) + { + _requiredRole = requiredRole; + _toolPatterns = toolPatterns; + Priority = priority; + } + + public int Priority { get; } + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (MatchesPattern(tool.Name)) + { + return Task.FromResult(context.UserRoles.Contains(_requiredRole)); + } + return Task.FromResult(true); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (MatchesPattern(toolName) && !context.UserRoles.Contains(_requiredRole)) + { + return Task.FromResult(AuthorizationResult.Deny($"Required role: {_requiredRole}")); + } + return Task.FromResult(AuthorizationResult.Allow()); + } + + private bool MatchesPattern(string toolName) + { + return _toolPatterns.Any(pattern => + pattern.EndsWith("*") ? toolName.StartsWith(pattern[..^1]) : + pattern.StartsWith("*") ? toolName.EndsWith(pattern[1..]) : + toolName.Equals(pattern)); + } + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/Authorization/MultipleFilterCoordinationTests.cs b/tests/ModelContextProtocol.Tests/Server/Authorization/MultipleFilterCoordinationTests.cs new file mode 100644 index 00000000..7cc350e3 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/Authorization/MultipleFilterCoordinationTests.cs @@ -0,0 +1,729 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server.Authorization; +using ModelContextProtocol.Tests.Utils; +using System.Text.Json.Nodes; +using Xunit; + +namespace ModelContextProtocol.Tests.Server.Authorization; + +/// +/// Integration tests for multiple filter coordination scenarios. +/// +public class MultipleFilterCoordinationTests : LoggedTest +{ + public MultipleFilterCoordinationTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) + { + } + + [Fact] + public async Task ComplexFilterChain_WithVariousPriorities_AppliesCorrectOrder() + { + // Arrange + var service = new ToolAuthorizationService(LoggerFactory.CreateLogger()); + + // Add filters with different priorities and behaviors + service.RegisterFilter(new SecurityFilter(priority: 1)); // Highest priority - security checks + service.RegisterFilter(new RateLimitFilter(priority: 5)); // Rate limiting + service.RegisterFilter(new RoleBasedFilter(priority: 10)); // Role-based access + service.RegisterFilter(new FeatureFlagFilter(priority: 15)); // Feature flags + service.RegisterFilter(new AuditFilter(priority: 100)); // Audit logging (lowest priority) + + var tools = new[] + { + CreateTestTool("user_profile"), + CreateTestTool("admin_delete"), + CreateTestTool("beta_feature"), + CreateTestTool("high_rate_tool"), + CreateTestTool("secure_operation") + }; + + var context = CreateTestContext(); + + // Act + var filteredTools = await service.FilterToolsAsync(tools, context); + + // Assert + // Verify that filters are applied in priority order and expected tools are filtered + var toolNames = filteredTools.Select(t => t.Name).ToList(); + + // SecurityFilter should block secure_operation + Assert.DoesNotContain("secure_operation", toolNames); + + // RateLimitFilter should block high_rate_tool + Assert.DoesNotContain("high_rate_tool", toolNames); + + // RoleBasedFilter should block admin_delete (no admin role) + Assert.DoesNotContain("admin_delete", toolNames); + + // FeatureFlagFilter should block beta_feature + Assert.DoesNotContain("beta_feature", toolNames); + + // user_profile should be allowed through all filters + Assert.Contains("user_profile", toolNames); + + // Verify all filters were called in order + Assert.True(SecurityFilter.CallOrder < RateLimitFilter.CallOrder); + Assert.True(RateLimitFilter.CallOrder < RoleBasedFilter.CallOrder); + Assert.True(RoleBasedFilter.CallOrder < FeatureFlagFilter.CallOrder); + Assert.True(FeatureFlagFilter.CallOrder < AuditFilter.CallOrder); + } + + [Fact] + public async Task FilterChain_WithEarlyTermination_StopsProcessingCorrectly() + { + // Arrange + var service = new ToolAuthorizationService(LoggerFactory.CreateLogger()); + + var mockFilter1 = new ExecutionTrackingFilter("Filter1", priority: 1, shouldInclude: true); + var mockFilter2 = new ExecutionTrackingFilter("Filter2", priority: 2, shouldInclude: false); // This will deny + var mockFilter3 = new ExecutionTrackingFilter("Filter3", priority: 3, shouldInclude: true); + + service.RegisterFilter(mockFilter1); + service.RegisterFilter(mockFilter2); + service.RegisterFilter(mockFilter3); + + var tools = new[] { CreateTestTool("test_tool") }; + var context = CreateTestContext(); + + // Act + var filteredTools = await service.FilterToolsAsync(tools, context); + + // Assert + Assert.Empty(filteredTools); // Tool should be filtered out + + // Verify execution order + Assert.True(mockFilter1.WasCalled); + Assert.True(mockFilter2.WasCalled); // This denies, so processing stops here + Assert.False(mockFilter3.WasCalled); // This should not be called due to early termination + + Assert.True(mockFilter1.CallTime < mockFilter2.CallTime); + } + + [Fact] + public async Task FilterChain_WithExceptionHandling_IsolatesFailures() + { + // Arrange + var service = new ToolAuthorizationService(LoggerFactory.CreateLogger()); + + service.RegisterFilter(new AllowAllToolFilter(priority: 1)); + service.RegisterFilter(new ExceptionThrowingFilter(priority: 2)); + service.RegisterFilter(new AllowAllToolFilter(priority: 3)); + + var tools = new[] { CreateTestTool("test_tool") }; + var context = CreateTestContext(); + + // Act + var filteredTools = await service.FilterToolsAsync(tools, context); + + // Assert + // Exception in middle filter should cause tool to be filtered out + Assert.Empty(filteredTools); + } + + [Fact] + public async Task FilterChain_WithDifferentFilterTypes_CombinesLogicCorrectly() + { + // Arrange + var service = new ToolAuthorizationService(LoggerFactory.CreateLogger()); + + // Blacklist filter: deny admin tools + service.RegisterFilter(new ToolNamePatternFilter(new[] { "admin_*" }, allowMatching: false, priority: 1)); + + // Whitelist filter: only allow user and admin tools + service.RegisterFilter(new ToolNamePatternFilter(new[] { "user_*", "admin_*" }, allowMatching: true, priority: 2)); + + // Role filter: require admin role for admin tools (but admin tools already blocked above) + var roleFilter = RoleBasedToolFilterBuilder.Create() + .RequireRole("admin") + .ForToolsMatching("admin_*") + .Build(); + service.RegisterFilter(roleFilter); + + var tools = new[] + { + CreateTestTool("user_profile"), // Should pass (allowed by whitelist, not blocked by blacklist) + CreateTestTool("admin_delete"), // Should be blocked (blocked by blacklist) + CreateTestTool("system_status"), // Should be blocked (not in whitelist) + CreateTestTool("user_settings") // Should pass (allowed by whitelist, not blocked by blacklist) + }; + + var context = CreateTestContext(); + + // Act + var filteredTools = await service.FilterToolsAsync(tools, context); + + // Assert + var toolNames = filteredTools.Select(t => t.Name).ToList(); + + Assert.Equal(2, toolNames.Count); + Assert.Contains("user_profile", toolNames); + Assert.Contains("user_settings", toolNames); + Assert.DoesNotContain("admin_delete", toolNames); + Assert.DoesNotContain("system_status", toolNames); + } + + [Fact] + public async Task FilterChain_WithConditionalLogic_HandlesComplexScenarios() + { + // Arrange + var service = new ToolAuthorizationService(LoggerFactory.CreateLogger()); + + // Add filters that have conditional logic based on context + service.RegisterFilter(new TimeBasedFilter(priority: 1)); + service.RegisterFilter(new UserLevelFilter(priority: 2)); + service.RegisterFilter(new ToolComplexityFilter(priority: 3)); + + var tools = new[] + { + CreateTestTool("simple_read"), + CreateTestTool("complex_analysis"), + CreateTestTool("admin_operation"), + CreateTestTool("time_sensitive_task") + }; + + var context = CreateTestContext(); + context.Properties["UserLevel"] = "premium"; + context.Properties["CurrentHour"] = DateTime.Now.Hour; + + // Act + var filteredTools = await service.FilterToolsAsync(tools, context); + + // Assert + var toolNames = filteredTools.Select(t => t.Name).ToList(); + + // Specific assertions based on the filter logic + Assert.Contains("simple_read", toolNames); // Should always be allowed + + // Other tools depend on time, user level, and complexity rules + // The exact results depend on current time and context + } + + [Fact] + public async Task ToolExecution_WithMultipleFilters_RespectsAuthorizationChain() + { + // Arrange + var service = new ToolAuthorizationService(LoggerFactory.CreateLogger()); + + service.RegisterFilter(new AuthenticationFilter(priority: 1)); + service.RegisterFilter(new AuthorizationFilter(priority: 2)); + service.RegisterFilter(new QuotaFilter(priority: 3)); + + var context = CreateTestContext(); + + // Test 1: Unauthenticated user + var result1 = await service.AuthorizeToolExecutionAsync("secure_api", context); + Assert.False(result1.IsAuthorized); + Assert.IsType(result1.AdditionalData); + + // Test 2: Authenticated but unauthorized user + context.Properties["IsAuthenticated"] = true; + var result2 = await service.AuthorizeToolExecutionAsync("admin_tool", context); + Assert.False(result2.IsAuthorized); + Assert.Contains("insufficient permissions", result2.Reason.ToLowerInvariant()); + + // Test 3: Authorized user but quota exceeded + context.Properties["HasAdminPermission"] = true; + context.Properties["QuotaExceeded"] = true; + var result3 = await service.AuthorizeToolExecutionAsync("admin_tool", context); + Assert.False(result3.IsAuthorized); + Assert.Contains("quota", result3.Reason.ToLowerInvariant()); + + // Test 4: Fully authorized user with quota available + context.Properties["QuotaExceeded"] = false; + var result4 = await service.AuthorizeToolExecutionAsync("admin_tool", context); + Assert.True(result4.IsAuthorized); + } + + [Fact] + public async Task FilterChain_WithDynamicFilters_HandlesRuntimeChanges() + { + // Arrange + var service = new ToolAuthorizationService(LoggerFactory.CreateLogger()); + + var dynamicFilter = new DynamicRulesFilter(); + service.RegisterFilter(dynamicFilter); + + var tools = new[] { CreateTestTool("dynamic_tool") }; + var context = CreateTestContext(); + + // Act & Assert 1: Initially restrictive + dynamicFilter.SetMode(DynamicRulesFilter.FilterMode.Restrictive); + var result1 = await service.FilterToolsAsync(tools, context); + Assert.Empty(result1); + + // Act & Assert 2: Change to permissive + dynamicFilter.SetMode(DynamicRulesFilter.FilterMode.Permissive); + var result2 = await service.FilterToolsAsync(tools, context); + Assert.Single(result2); + + // Act & Assert 3: Change to selective + dynamicFilter.SetMode(DynamicRulesFilter.FilterMode.Selective); + dynamicFilter.AddAllowedTool("dynamic_tool"); + var result3 = await service.FilterToolsAsync(tools, context); + Assert.Single(result3); + + // Act & Assert 4: Remove from allowed list + dynamicFilter.RemoveAllowedTool("dynamic_tool"); + var result4 = await service.FilterToolsAsync(tools, context); + Assert.Empty(result4); + } + + private static Tool CreateTestTool(string name, string? description = null) + { + return new Tool + { + Name = name, + Description = description ?? $"Test tool: {name}", + InputSchema = new JsonObject() + }; + } + + private static ToolAuthorizationContext CreateTestContext(string? sessionId = null, string? userId = null) + { + var context = ToolAuthorizationContext.ForSession(sessionId ?? "test-session"); + if (userId != null) + { + context.UserId = userId; + } + return context; + } + + // Test filter implementations for complex scenarios + + private class SecurityFilter : IToolFilter + { + public static int CallOrder { get; private set; } + private static int _globalCallOrder = 0; + + public SecurityFilter(int priority) + { + Priority = priority; + } + + public int Priority { get; } + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + CallOrder = Interlocked.Increment(ref _globalCallOrder); + + // Block tools with "secure" in the name + return Task.FromResult(!tool.Name.Contains("secure", StringComparison.OrdinalIgnoreCase)); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (toolName.Contains("secure", StringComparison.OrdinalIgnoreCase)) + { + return Task.FromResult(AuthorizationResult.Deny("Security policy violation")); + } + return Task.FromResult(AuthorizationResult.Allow()); + } + } + + private class RateLimitFilter : IToolFilter + { + public static int CallOrder { get; private set; } + private static int _globalCallOrder = 0; + + public RateLimitFilter(int priority) + { + Priority = priority; + } + + public int Priority { get; } + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + CallOrder = Interlocked.Increment(ref _globalCallOrder); + + // Block high rate tools + return Task.FromResult(!tool.Name.Contains("high_rate", StringComparison.OrdinalIgnoreCase)); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (toolName.Contains("high_rate", StringComparison.OrdinalIgnoreCase)) + { + return Task.FromResult(AuthorizationResult.Deny("Rate limit exceeded")); + } + return Task.FromResult(AuthorizationResult.Allow()); + } + } + + private class RoleBasedFilter : IToolFilter + { + public static int CallOrder { get; private set; } + private static int _globalCallOrder = 0; + + public RoleBasedFilter(int priority) + { + Priority = priority; + } + + public int Priority { get; } + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + CallOrder = Interlocked.Increment(ref _globalCallOrder); + + // Block admin tools if not admin + if (tool.Name.Contains("admin", StringComparison.OrdinalIgnoreCase)) + { + return Task.FromResult(context.UserRoles.Contains("admin")); + } + return Task.FromResult(true); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (toolName.Contains("admin", StringComparison.OrdinalIgnoreCase) && !context.UserRoles.Contains("admin")) + { + return Task.FromResult(AuthorizationResult.Deny("Insufficient role")); + } + return Task.FromResult(AuthorizationResult.Allow()); + } + } + + private class FeatureFlagFilter : IToolFilter + { + public static int CallOrder { get; private set; } + private static int _globalCallOrder = 0; + + public FeatureFlagFilter(int priority) + { + Priority = priority; + } + + public int Priority { get; } + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + CallOrder = Interlocked.Increment(ref _globalCallOrder); + + // Block beta features + return Task.FromResult(!tool.Name.Contains("beta", StringComparison.OrdinalIgnoreCase)); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (toolName.Contains("beta", StringComparison.OrdinalIgnoreCase)) + { + return Task.FromResult(AuthorizationResult.Deny("Feature not enabled")); + } + return Task.FromResult(AuthorizationResult.Allow()); + } + } + + private class AuditFilter : IToolFilter + { + public static int CallOrder { get; private set; } + private static int _globalCallOrder = 0; + + public AuditFilter(int priority) + { + Priority = priority; + } + + public int Priority { get; } + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + CallOrder = Interlocked.Increment(ref _globalCallOrder); + + // Audit filter always allows (just logs) + return Task.FromResult(true); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // Log the access attempt (in real implementation) + return Task.FromResult(AuthorizationResult.Allow("Audited")); + } + } + + // Additional test filter implementations... + + private class ExecutionTrackingFilter : IToolFilter + { + private readonly string _name; + private readonly bool _shouldInclude; + + public ExecutionTrackingFilter(string name, int priority, bool shouldInclude) + { + _name = name; + Priority = priority; + _shouldInclude = shouldInclude; + } + + public int Priority { get; } + public bool WasCalled { get; private set; } + public DateTime CallTime { get; private set; } + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + WasCalled = true; + CallTime = DateTime.UtcNow; + return Task.FromResult(_shouldInclude); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + WasCalled = true; + CallTime = DateTime.UtcNow; + return Task.FromResult(_shouldInclude + ? AuthorizationResult.Allow($"Allowed by {_name}") + : AuthorizationResult.Deny($"Denied by {_name}")); + } + } + + private class ExceptionThrowingFilter : IToolFilter + { + public ExceptionThrowingFilter(int priority) + { + Priority = priority; + } + + public int Priority { get; } + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + throw new InvalidOperationException("Test exception in filter"); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + throw new InvalidOperationException("Test exception in filter"); + } + } + + // Additional complex filter implementations for advanced scenarios... + + private class TimeBasedFilter : IToolFilter + { + public TimeBasedFilter(int priority) + { + Priority = priority; + } + + public int Priority { get; } + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // Block time-sensitive tasks outside business hours + if (tool.Name.Contains("time_sensitive")) + { + var hour = DateTime.Now.Hour; + return Task.FromResult(hour >= 9 && hour <= 17); + } + return Task.FromResult(true); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (toolName.Contains("time_sensitive")) + { + var hour = DateTime.Now.Hour; + if (hour < 9 || hour > 17) + { + return Task.FromResult(AuthorizationResult.Deny("Outside business hours")); + } + } + return Task.FromResult(AuthorizationResult.Allow()); + } + } + + private class UserLevelFilter : IToolFilter + { + public UserLevelFilter(int priority) + { + Priority = priority; + } + + public int Priority { get; } + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // Complex tools require premium users + if (tool.Name.Contains("complex")) + { + return Task.FromResult(context.Properties.TryGetValue("UserLevel", out var level) && + level?.ToString() == "premium"); + } + return Task.FromResult(true); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (toolName.Contains("complex")) + { + if (!context.Properties.TryGetValue("UserLevel", out var level) || level?.ToString() != "premium") + { + return Task.FromResult(AuthorizationResult.Deny("Premium subscription required")); + } + } + return Task.FromResult(AuthorizationResult.Allow()); + } + } + + private class ToolComplexityFilter : IToolFilter + { + public ToolComplexityFilter(int priority) + { + Priority = priority; + } + + public int Priority { get; } + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // Admin operations require special handling + return Task.FromResult(!tool.Name.Contains("admin") || + context.Properties.ContainsKey("HasAdminPermission")); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (toolName.Contains("admin") && !context.Properties.ContainsKey("HasAdminPermission")) + { + return Task.FromResult(AuthorizationResult.Deny("Admin permission required")); + } + return Task.FromResult(AuthorizationResult.Allow()); + } + } + + // Filters for authorization chain testing + + private class AuthenticationFilter : IToolFilter + { + public AuthenticationFilter(int priority) + { + Priority = priority; + } + + public int Priority { get; } + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var result = await CanExecuteToolAsync(tool.Name, context, cancellationToken); + return result.IsAuthorized; + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (!context.Properties.TryGetValue("IsAuthenticated", out var auth) || !(bool)auth!) + { + return Task.FromResult(AuthorizationResult.DenyInvalidToken("mcp-server")); + } + return Task.FromResult(AuthorizationResult.Allow()); + } + } + + private class AuthorizationFilter : IToolFilter + { + public AuthorizationFilter(int priority) + { + Priority = priority; + } + + public int Priority { get; } + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var result = await CanExecuteToolAsync(tool.Name, context, cancellationToken); + return result.IsAuthorized; + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (toolName.Contains("admin") && + (!context.Properties.TryGetValue("HasAdminPermission", out var perm) || !(bool)perm!)) + { + return Task.FromResult(AuthorizationResult.Deny("Insufficient permissions")); + } + return Task.FromResult(AuthorizationResult.Allow()); + } + } + + private class QuotaFilter : IToolFilter + { + public QuotaFilter(int priority) + { + Priority = priority; + } + + public int Priority { get; } + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var result = await CanExecuteToolAsync(tool.Name, context, cancellationToken); + return result.IsAuthorized; + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + if (context.Properties.TryGetValue("QuotaExceeded", out var exceeded) && (bool)exceeded!) + { + return Task.FromResult(AuthorizationResult.Deny("Quota exceeded")); + } + return Task.FromResult(AuthorizationResult.Allow()); + } + } + + private class DynamicRulesFilter : IToolFilter + { + public enum FilterMode + { + Restrictive, + Permissive, + Selective + } + + private FilterMode _mode = FilterMode.Restrictive; + private readonly HashSet _allowedTools = new(); + + public int Priority => 100; + + public void SetMode(FilterMode mode) + { + _mode = mode; + } + + public void AddAllowedTool(string toolName) + { + _allowedTools.Add(toolName); + } + + public void RemoveAllowedTool(string toolName) + { + _allowedTools.Remove(toolName); + } + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + return _mode switch + { + FilterMode.Restrictive => Task.FromResult(false), + FilterMode.Permissive => Task.FromResult(true), + FilterMode.Selective => Task.FromResult(_allowedTools.Contains(tool.Name)), + _ => Task.FromResult(false) + }; + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var allowed = _mode switch + { + FilterMode.Restrictive => false, + FilterMode.Permissive => true, + FilterMode.Selective => _allowedTools.Contains(toolName), + _ => false + }; + + return Task.FromResult(allowed + ? AuthorizationResult.Allow($"Dynamic rule: {_mode}") + : AuthorizationResult.Deny($"Dynamic rule: {_mode}")); + } + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/Authorization/PerformanceAndThreadSafetyTests.cs b/tests/ModelContextProtocol.Tests/Server/Authorization/PerformanceAndThreadSafetyTests.cs new file mode 100644 index 00000000..1fa887d9 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/Authorization/PerformanceAndThreadSafetyTests.cs @@ -0,0 +1,745 @@ +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server.Authorization; +using ModelContextProtocol.Tests.Utils; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Text.Json.Nodes; +using Xunit; + +namespace ModelContextProtocol.Tests.Server.Authorization; + +/// +/// Performance and thread safety tests for the tool filtering system. +/// +public class PerformanceAndThreadSafetyTests : LoggedTest +{ + public PerformanceAndThreadSafetyTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) + { + } + + [Fact] + public async Task ToolAuthorizationService_HighVolumeFiltering_PerformsWithinExpectedTime() + { + // Arrange + var service = new ToolAuthorizationService(LoggerFactory.CreateLogger()); + service.RegisterFilter(new AllowAllToolFilter()); + service.RegisterFilter(new ToolNamePatternFilter(new[] { "admin_*" }, allowMatching: false)); + service.RegisterFilter(new PerformanceTestFilter()); + + // Create a large number of tools + const int toolCount = 10000; + var tools = Enumerable.Range(0, toolCount) + .Select(i => CreateTestTool($"tool_{i}", $"Test tool number {i}")) + .ToArray(); + + var context = CreateTestContext(); + var stopwatch = Stopwatch.StartNew(); + + // Act + var filteredTools = await service.FilterToolsAsync(tools, context); + + // Assert + stopwatch.Stop(); + var elapsedMs = stopwatch.ElapsedMilliseconds; + + TestOutputHelper.WriteLine($"Filtered {toolCount} tools in {elapsedMs}ms ({toolCount / (double)elapsedMs * 1000:F0} tools/second)"); + + // Performance assertion: should process at least 1000 tools per second + Assert.True(elapsedMs < toolCount / 10, $"Performance too slow: {elapsedMs}ms for {toolCount} tools"); + + // Verify correctness + Assert.True(filteredTools.Count() < toolCount, "Some tools should be filtered out"); + Assert.All(filteredTools, tool => Assert.DoesNotContain("admin_", tool.Name)); + } + + [Fact] + public async Task ToolAuthorizationService_HighVolumeExecution_PerformsWithinExpectedTime() + { + // Arrange + var service = new ToolAuthorizationService(LoggerFactory.CreateLogger()); + service.RegisterFilter(new AllowAllToolFilter()); + service.RegisterFilter(new PerformanceTestFilter()); + + const int operationCount = 10000; + var context = CreateTestContext(); + var stopwatch = Stopwatch.StartNew(); + + // Act + var tasks = Enumerable.Range(0, operationCount) + .Select(i => service.AuthorizeToolExecutionAsync($"tool_{i}", context)) + .ToArray(); + + var results = await Task.WhenAll(tasks); + + // Assert + stopwatch.Stop(); + var elapsedMs = stopwatch.ElapsedMilliseconds; + + TestOutputHelper.WriteLine($"Authorized {operationCount} tool executions in {elapsedMs}ms ({operationCount / (double)elapsedMs * 1000:F0} operations/second)"); + + // Performance assertion: should process at least 2000 operations per second + Assert.True(elapsedMs < operationCount / 20, $"Performance too slow: {elapsedMs}ms for {operationCount} operations"); + + // Verify correctness + Assert.All(results, result => Assert.True(result.IsAuthorized)); + } + + [Fact] + public async Task ToolAuthorizationService_ConcurrentRegistration_IsThreadSafe() + { + // Arrange + var service = new ToolAuthorizationService(LoggerFactory.CreateLogger()); + const int concurrentOperations = 100; + const int filtersPerOperation = 10; + + var allFilters = new List(); + var registrationTasks = new List(); + + // Act - Register filters concurrently + for (int i = 0; i < concurrentOperations; i++) + { + var operationIndex = i; + var task = Task.Run(() => + { + for (int j = 0; j < filtersPerOperation; j++) + { + var filter = new ConcurrentTestFilter($"Filter_{operationIndex}_{j}", priority: operationIndex * 100 + j); + service.RegisterFilter(filter); + lock (allFilters) + { + allFilters.Add(filter); + } + } + }); + registrationTasks.Add(task); + } + + await Task.WhenAll(registrationTasks); + + // Assert + var registeredFilters = service.GetRegisteredFilters(); + Assert.Equal(concurrentOperations * filtersPerOperation, registeredFilters.Count); + + // Verify all filters were registered + foreach (var filter in allFilters) + { + Assert.Contains(filter, registeredFilters); + } + } + + [Fact] + public async Task ToolAuthorizationService_ConcurrentFiltering_IsThreadSafe() + { + // Arrange + var service = new ToolAuthorizationService(LoggerFactory.CreateLogger()); + + // Add filters with different behaviors + service.RegisterFilter(new ConcurrentTestFilter("AllowFilter", allowAll: true, priority: 1)); + service.RegisterFilter(new ConcurrentTestFilter("PatternFilter", allowAll: false, priority: 2)); + service.RegisterFilter(new ThreadSafeCountingFilter(priority: 3)); + + var tools = Enumerable.Range(0, 1000) + .Select(i => CreateTestTool($"tool_{i}")) + .ToArray(); + + const int concurrentOperations = 50; + var contexts = Enumerable.Range(0, concurrentOperations) + .Select(i => CreateTestContext($"session_{i}", $"user_{i}")) + .ToArray(); + + // Act - Filter tools concurrently + var filteringTasks = contexts.Select(context => + service.FilterToolsAsync(tools, context)).ToArray(); + + var results = await Task.WhenAll(filteringTasks); + + // Assert + Assert.All(results, result => + { + Assert.NotNull(result); + Assert.True(result.Any(), "Should have some tools after filtering"); + }); + + // Verify thread-safe counter + var countingFilter = service.GetRegisteredFilters() + .OfType() + .First(); + + // Should have been called for each tool in each context + var expectedCallCount = tools.Length * concurrentOperations; + Assert.Equal(expectedCallCount, countingFilter.CallCount); + } + + [Fact] + public async Task ToolAuthorizationService_ConcurrentExecutionAuthorization_IsThreadSafe() + { + // Arrange + var service = new ToolAuthorizationService(LoggerFactory.CreateLogger()); + service.RegisterFilter(new ThreadSafeCountingFilter()); + service.RegisterFilter(new ConcurrentAuthorizationFilter()); + + const int concurrentOperations = 100; + const int operationsPerThread = 100; + + var contexts = Enumerable.Range(0, concurrentOperations) + .Select(i => CreateTestContext($"session_{i}", $"user_{i}")) + .ToArray(); + + // Act - Authorize tool executions concurrently + var authorizationTasks = new List>(); + + for (int i = 0; i < concurrentOperations; i++) + { + var context = contexts[i]; + var task = Task.Run(async () => + { + var results = new AuthorizationResult[operationsPerThread]; + for (int j = 0; j < operationsPerThread; j++) + { + results[j] = await service.AuthorizeToolExecutionAsync($"tool_{i}_{j}", context); + } + return results; + }); + authorizationTasks.Add(task); + } + + var allResults = await Task.WhenAll(authorizationTasks); + + // Assert + var flatResults = allResults.SelectMany(r => r).ToArray(); + Assert.Equal(concurrentOperations * operationsPerThread, flatResults.Length); + + // Verify all operations completed successfully (no exceptions) + Assert.All(flatResults, result => Assert.NotNull(result)); + + // Verify thread-safe counter + var countingFilter = service.GetRegisteredFilters() + .OfType() + .First(); + + Assert.Equal(concurrentOperations * operationsPerThread, countingFilter.ExecutionCallCount); + } + + [Fact] + public async Task ToolAuthorizationService_UnderLoad_MaintainsCorrectness() + { + // Arrange + var service = new ToolAuthorizationService(LoggerFactory.CreateLogger()); + + // Add filters with complex logic + service.RegisterFilter(new LoadTestFilter("admin", priority: 1)); + service.RegisterFilter(new LoadTestFilter("user", priority: 2)); + service.RegisterFilter(new LoadTestFilter("guest", priority: 3)); + + var tools = new[] + { + CreateTestTool("admin_tool"), + CreateTestTool("user_tool"), + CreateTestTool("guest_tool"), + CreateTestTool("public_tool") + }; + + const int concurrentUsers = 50; + const int operationsPerUser = 50; + + // Create contexts for different user types + var adminContexts = Enumerable.Range(0, concurrentUsers / 3) + .Select(i => CreateTestContext($"admin_session_{i}", $"admin_{i}", new[] { "admin" })) + .ToArray(); + + var userContexts = Enumerable.Range(0, concurrentUsers / 3) + .Select(i => CreateTestContext($"user_session_{i}", $"user_{i}", new[] { "user" })) + .ToArray(); + + var guestContexts = Enumerable.Range(0, concurrentUsers / 3) + .Select(i => CreateTestContext($"guest_session_{i}", $"guest_{i}", new[] { "guest" })) + .ToArray(); + + var allContexts = adminContexts.Concat(userContexts).Concat(guestContexts).ToArray(); + + // Act - Simulate load + var loadTasks = allContexts.Select(async context => + { + var results = new List(); + + for (int i = 0; i < operationsPerUser; i++) + { + foreach (var tool in tools) + { + var result = await service.AuthorizeToolExecutionAsync(tool.Name, context); + results.Add(result); + } + } + + return new { Context = context, Results = results }; + }).ToArray(); + + var loadResults = await Task.WhenAll(loadTasks); + + // Assert correctness under load + foreach (var userResult in loadResults) + { + var userRoles = userResult.Context.UserRoles; + + foreach (var result in userResult.Results) + { + // Verify authorization logic is correctly applied + if (userRoles.Contains("admin")) + { + Assert.True(result.IsAuthorized, "Admin should have access to all tools"); + } + else if (userRoles.Contains("user")) + { + // Users should not have access to admin tools + if (result.Reason?.Contains("admin_tool") == true) + { + Assert.False(result.IsAuthorized, "Users should not access admin tools"); + } + } + else if (userRoles.Contains("guest")) + { + // Guests should only have access to guest and public tools + if (result.Reason?.Contains("admin_tool") == true || result.Reason?.Contains("user_tool") == true) + { + Assert.False(result.IsAuthorized, "Guests should have limited access"); + } + } + } + } + } + + [Fact] + public void ToolAuthorizationService_MemoryUsage_StaysWithinReasonableBounds() + { + // Arrange + var service = new ToolAuthorizationService(LoggerFactory.CreateLogger()); + + // Force initial garbage collection + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + var initialMemory = GC.GetTotalMemory(false); + + // Act - Register many filters and perform operations + const int filterCount = 1000; + for (int i = 0; i < filterCount; i++) + { + service.RegisterFilter(new MemoryTestFilter(i)); + } + + // Perform operations that might leak memory + var tools = Enumerable.Range(0, 1000) + .Select(i => CreateTestTool($"tool_{i}")) + .ToArray(); + + var context = CreateTestContext(); + + // Run filtering operations multiple times + for (int iteration = 0; iteration < 10; iteration++) + { + _ = service.FilterToolsAsync(tools, context).GetAwaiter().GetResult(); + } + + // Force garbage collection + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + var finalMemory = GC.GetTotalMemory(false); + var memoryIncrease = finalMemory - initialMemory; + + // Assert - Memory increase should be reasonable + const long maxReasonableIncrease = 50 * 1024 * 1024; // 50MB + TestOutputHelper.WriteLine($"Memory increase: {memoryIncrease / 1024.0 / 1024.0:F2} MB"); + + Assert.True(memoryIncrease < maxReasonableIncrease, + $"Memory usage increased by {memoryIncrease / 1024.0 / 1024.0:F2} MB, which exceeds the limit of {maxReasonableIncrease / 1024.0 / 1024.0:F2} MB"); + } + + [Fact] + public async Task ToolFilterAggregator_ConcurrentServiceResolution_IsThreadSafe() + { + // This test would require dependency injection setup + // For now, we'll test the core thread safety of filter aggregation + + // Arrange + var service = new ToolAuthorizationService(LoggerFactory.CreateLogger()); + + // Add filters that will be accessed concurrently + var sharedFilter = new ConcurrentAccessTestFilter(); + service.RegisterFilter(sharedFilter); + service.RegisterFilter(new AllowAllToolFilter()); + + var tools = Enumerable.Range(0, 100) + .Select(i => CreateTestTool($"tool_{i}")) + .ToArray(); + + const int concurrentThreads = 20; + var contexts = Enumerable.Range(0, concurrentThreads) + .Select(i => CreateTestContext($"session_{i}")) + .ToArray(); + + // Act - Access the same filter from multiple threads + var concurrentTasks = contexts.Select(context => + Task.Run(async () => + { + for (int i = 0; i < 50; i++) + { + await service.FilterToolsAsync(tools, context); + await service.AuthorizeToolExecutionAsync($"test_tool_{i}", context); + } + })).ToArray(); + + await Task.WhenAll(concurrentTasks); + + // Assert - No exceptions should occur and state should be consistent + Assert.True(sharedFilter.TotalCalls > 0); + Assert.Equal(0, sharedFilter.ErrorCount); + } + + private static Tool CreateTestTool(string name, string? description = null) + { + return new Tool + { + Name = name, + Description = description ?? $"Test tool: {name}", + InputSchema = new JsonObject() + }; + } + + private static ToolAuthorizationContext CreateTestContext( + string? sessionId = null, + string? userId = null, + IEnumerable? roles = null) + { + var context = ToolAuthorizationContext.ForSession(sessionId ?? "test-session"); + if (userId != null) + { + context.UserId = userId; + } + if (roles != null) + { + foreach (var role in roles) + { + context.UserRoles.Add(role); + } + } + return context; + } + + // Test filter implementations for performance and concurrency testing + + private class PerformanceTestFilter : IToolFilter + { + public int Priority => 50; + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // Simulate some work without being too slow + var hash = tool.Name.GetHashCode(); + return Task.FromResult(hash % 10 != 0); // Filter out ~10% of tools + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var hash = toolName.GetHashCode(); + return Task.FromResult(hash % 10 != 0 + ? AuthorizationResult.Allow("Performance test passed") + : AuthorizationResult.Deny("Performance test filtered")); + } + } + + private class ConcurrentTestFilter : IToolFilter + { + private readonly string _name; + private readonly bool _allowAll; + + public ConcurrentTestFilter(string name, bool allowAll = true, int priority = 100) + { + _name = name; + _allowAll = allowAll; + Priority = priority; + } + + public int Priority { get; } + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + return Task.FromResult(_allowAll || !tool.Name.Contains("filtered")); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + return Task.FromResult(_allowAll || !toolName.Contains("filtered") + ? AuthorizationResult.Allow($"Allowed by {_name}") + : AuthorizationResult.Deny($"Denied by {_name}")); + } + } + + private class ThreadSafeCountingFilter : IToolFilter + { + private long _callCount; + private long _executionCallCount; + + public ThreadSafeCountingFilter(int priority = 100) + { + Priority = priority; + } + + public int Priority { get; } + public long CallCount => _callCount; + public long ExecutionCallCount => _executionCallCount; + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + Interlocked.Increment(ref _callCount); + return Task.FromResult(true); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + Interlocked.Increment(ref _executionCallCount); + return Task.FromResult(AuthorizationResult.Allow("Thread-safe counting")); + } + } + + private class ConcurrentAuthorizationFilter : IToolFilter + { + private readonly ConcurrentDictionary _callCounts = new(); + + public int Priority => 100; + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var result = await CanExecuteToolAsync(tool.Name, context, cancellationToken); + return result.IsAuthorized; + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + _callCounts.AddOrUpdate(toolName, 1, (key, count) => count + 1); + + // Simulate some processing time + Thread.SpinWait(100); + + return Task.FromResult(AuthorizationResult.Allow("Concurrent authorization completed")); + } + } + + private class LoadTestFilter : IToolFilter + { + private readonly string _requiredRole; + + public LoadTestFilter(string requiredRole, int priority = 100) + { + _requiredRole = requiredRole; + Priority = priority; + } + + public int Priority { get; } + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var result = await CanExecuteToolAsync(tool.Name, context, cancellationToken); + return result.IsAuthorized; + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // Complex authorization logic + if (toolName.StartsWith($"{_requiredRole}_") && !context.UserRoles.Contains(_requiredRole)) + { + return Task.FromResult(AuthorizationResult.Deny($"Tool {toolName} requires role: {_requiredRole}")); + } + + if (context.UserRoles.Contains("admin")) + { + return Task.FromResult(AuthorizationResult.Allow("Admin access")); + } + + if (toolName == "public_tool") + { + return Task.FromResult(AuthorizationResult.Allow("Public access")); + } + + return Task.FromResult(context.UserRoles.Contains(_requiredRole) + ? AuthorizationResult.Allow($"Role-based access: {_requiredRole}") + : AuthorizationResult.Deny($"Insufficient role for {toolName}")); + } + } + + private class MemoryTestFilter : IToolFilter + { + private readonly int _id; + private readonly byte[] _data; // Simulate some memory usage + + public MemoryTestFilter(int id) + { + _id = id; + Priority = id; + _data = new byte[1024]; // 1KB per filter + } + + public int Priority { get; } + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // Simulate some computation + var hash = tool.Name.GetHashCode() ^ _id; + return Task.FromResult(hash % 100 != 0); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var hash = toolName.GetHashCode() ^ _id; + return Task.FromResult(hash % 100 != 0 + ? AuthorizationResult.Allow($"Memory test {_id}") + : AuthorizationResult.Deny($"Memory test {_id} filtered")); + } + } + + private class ConcurrentAccessTestFilter : IToolFilter + { + private long _totalCalls; + private long _errorCount; + + public int Priority => 100; + public long TotalCalls => _totalCalls; + public long ErrorCount => _errorCount; + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + Interlocked.Increment(ref _totalCalls); + + try + { + // Simulate shared resource access + var result = ProcessSharedData(tool.Name); + return Task.FromResult(result); + } + catch + { + Interlocked.Increment(ref _errorCount); + return Task.FromResult(false); + } + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + Interlocked.Increment(ref _totalCalls); + + try + { + var result = ProcessSharedData(toolName); + return Task.FromResult(result + ? AuthorizationResult.Allow("Concurrent access successful") + : AuthorizationResult.Deny("Concurrent access filtered")); + } + catch + { + Interlocked.Increment(ref _errorCount); + return Task.FromResult(AuthorizationResult.Deny("Concurrent access error")); + } + } + + private bool ProcessSharedData(string input) + { + // Simulate some shared data processing + var hash = input.GetHashCode(); + Thread.SpinWait(10); // Small delay to increase chance of race conditions + return hash % 4 != 0; // Allow 75% of tools + } + } +} + +/// +/// Additional performance test utilities. +/// +public static class PerformanceTestUtilities +{ + /// + /// Measures the execution time of an async operation. + /// + public static async Task<(T Result, TimeSpan Duration)> MeasureAsync(Func> operation) + { + var stopwatch = Stopwatch.StartNew(); + var result = await operation(); + stopwatch.Stop(); + return (result, stopwatch.Elapsed); + } + + /// + /// Measures the execution time of a sync operation. + /// + public static (T Result, TimeSpan Duration) Measure(Func operation) + { + var stopwatch = Stopwatch.StartNew(); + var result = operation(); + stopwatch.Stop(); + return (result, stopwatch.Elapsed); + } + + /// + /// Runs a performance benchmark with multiple iterations. + /// + public static async Task BenchmarkAsync( + Func> operation, + int iterations = 100, + int warmupIterations = 10) + { + // Warmup + for (int i = 0; i < warmupIterations; i++) + { + await operation(); + } + + var durations = new List(); + + // Actual benchmark + for (int i = 0; i < iterations; i++) + { + var (_, duration) = await MeasureAsync(operation); + durations.Add(duration); + } + + return new BenchmarkResult(durations); + } + + public class BenchmarkResult + { + public BenchmarkResult(IEnumerable durations) + { + var sortedDurations = durations.OrderBy(d => d).ToArray(); + + Min = sortedDurations.First(); + Max = sortedDurations.Last(); + Average = TimeSpan.FromTicks((long)sortedDurations.Average(d => d.Ticks)); + Median = sortedDurations[sortedDurations.Length / 2]; + + // 95th percentile + var index95 = (int)(sortedDurations.Length * 0.95); + Percentile95 = sortedDurations[index95]; + + Iterations = sortedDurations.Length; + } + + public TimeSpan Min { get; } + public TimeSpan Max { get; } + public TimeSpan Average { get; } + public TimeSpan Median { get; } + public TimeSpan Percentile95 { get; } + public int Iterations { get; } + + public override string ToString() + { + return $"Benchmark Results ({Iterations} iterations):\n" + + $" Min: {Min.TotalMilliseconds:F2}ms\n" + + $" Max: {Max.TotalMilliseconds:F2}ms\n" + + $" Average: {Average.TotalMilliseconds:F2}ms\n" + + $" Median: {Median.TotalMilliseconds:F2}ms\n" + + $" 95th Percentile: {Percentile95.TotalMilliseconds:F2}ms"; + } + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/Authorization/ToolAuthorizationServiceTests.cs b/tests/ModelContextProtocol.Tests/Server/Authorization/ToolAuthorizationServiceTests.cs new file mode 100644 index 00000000..03ec294b --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/Authorization/ToolAuthorizationServiceTests.cs @@ -0,0 +1,610 @@ +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server.Authorization; +using ModelContextProtocol.Tests.Utils; +using Moq; +using System.Text.Json.Nodes; +using Xunit; + +namespace ModelContextProtocol.Tests.Server.Authorization; + +/// +/// Unit tests for ToolAuthorizationService functionality. +/// +public class ToolAuthorizationServiceTests +{ + [Fact] + public void Constructor_WithNullFilters_ThrowsArgumentNullException() + { + // Act & Assert + Assert.Throws(() => new ToolAuthorizationService((IEnumerable)null!)); + } + + [Fact] + public void Constructor_WithLogger_DoesNotThrow() + { + // Arrange + var logger = Mock.Of>(); + + // Act & Assert + var service = new ToolAuthorizationService(logger); + Assert.NotNull(service); + } + + [Fact] + public void Constructor_WithFiltersAndLogger_RegistersFilters() + { + // Arrange + var filter1 = new AllowAllToolFilter(1); + var filter2 = new DenyAllToolFilter(2); + var filters = new IToolFilter[] { filter1, filter2 }; + var logger = Mock.Of>(); + + // Act + var service = new ToolAuthorizationService(filters, logger); + + // Assert + var registeredFilters = service.GetRegisteredFilters(); + Assert.Equal(2, registeredFilters.Count); + Assert.Contains(filter1, registeredFilters); + Assert.Contains(filter2, registeredFilters); + } + + [Fact] + public void Constructor_WithFiltersContainingNull_IgnoresNullFilters() + { + // Arrange + var filter1 = new AllowAllToolFilter(); + var filters = new IToolFilter?[] { filter1, null, filter1 }; + + // Act + var service = new ToolAuthorizationService(filters.Where(f => f != null)!); + + // Assert + var registeredFilters = service.GetRegisteredFilters(); + Assert.Equal(2, registeredFilters.Count); + Assert.All(registeredFilters, f => Assert.Equal(filter1, f)); + } + + [Fact] + public async Task FilterToolsAsync_WithNullTools_ThrowsArgumentNullException() + { + // Arrange + var service = new ToolAuthorizationService(); + var context = CreateTestContext(); + + // Act & Assert + await Assert.ThrowsAsync(() => + service.FilterToolsAsync(null!, context)); + } + + [Fact] + public async Task FilterToolsAsync_WithNullContext_ThrowsArgumentNullException() + { + // Arrange + var service = new ToolAuthorizationService(); + var tools = new[] { CreateTestTool("test") }; + + // Act & Assert + await Assert.ThrowsAsync(() => + service.FilterToolsAsync(tools, null!)); + } + + [Fact] + public async Task FilterToolsAsync_WithNoFilters_ReturnsAllTools() + { + // Arrange + var service = new ToolAuthorizationService(); + var tools = new[] { CreateTestTool("tool1"), CreateTestTool("tool2") }; + var context = CreateTestContext(); + + // Act + var result = await service.FilterToolsAsync(tools, context); + + // Assert + Assert.Equal(2, result.Count()); + Assert.Equal(tools, result); + } + + [Fact] + public async Task FilterToolsAsync_WithAllowAllFilter_ReturnsAllTools() + { + // Arrange + var service = new ToolAuthorizationService(); + service.RegisterFilter(new AllowAllToolFilter()); + var tools = new[] { CreateTestTool("tool1"), CreateTestTool("tool2") }; + var context = CreateTestContext(); + + // Act + var result = await service.FilterToolsAsync(tools, context); + + // Assert + Assert.Equal(2, result.Count()); + Assert.Equal(tools, result); + } + + [Fact] + public async Task FilterToolsAsync_WithDenyAllFilter_ReturnsNoTools() + { + // Arrange + var service = new ToolAuthorizationService(); + service.RegisterFilter(new DenyAllToolFilter()); + var tools = new[] { CreateTestTool("tool1"), CreateTestTool("tool2") }; + var context = CreateTestContext(); + + // Act + var result = await service.FilterToolsAsync(tools, context); + + // Assert + Assert.Empty(result); + } + + [Fact] + public async Task FilterToolsAsync_WithSelectiveFilter_ReturnsFilteredTools() + { + // Arrange + var service = new ToolAuthorizationService(); + service.RegisterFilter(new ToolNamePatternFilter(new[] { "read_*" }, allowMatching: true)); + var tools = new[] + { + CreateTestTool("read_data"), + CreateTestTool("write_data"), + CreateTestTool("read_file") + }; + var context = CreateTestContext(); + + // Act + var result = await service.FilterToolsAsync(tools, context); + + // Assert + Assert.Equal(2, result.Count()); + Assert.Contains(result, t => t.Name == "read_data"); + Assert.Contains(result, t => t.Name == "read_file"); + Assert.DoesNotContain(result, t => t.Name == "write_data"); + } + + [Fact] + public async Task FilterToolsAsync_WithMultipleFilters_AppliesPriorityOrder() + { + // Arrange + var service = new ToolAuthorizationService(); + // Lower priority number = higher priority + service.RegisterFilter(new AllowAllToolFilter(priority: 100)); // Lower priority + service.RegisterFilter(new DenyAllToolFilter(priority: 1)); // Higher priority + var tools = new[] { CreateTestTool("tool1") }; + var context = CreateTestContext(); + + // Act + var result = await service.FilterToolsAsync(tools, context); + + // Assert + // DenyAllToolFilter should execute first and block all tools + Assert.Empty(result); + } + + [Fact] + public async Task FilterToolsAsync_WithMultipleFilters_StopsOnFirstDeny() + { + // Arrange + var mockFilter1 = new Mock(); + var mockFilter2 = new Mock(); + + mockFilter1.Setup(f => f.Priority).Returns(1); + mockFilter1.Setup(f => f.ShouldIncludeToolAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync(false); + + mockFilter2.Setup(f => f.Priority).Returns(2); + + var service = new ToolAuthorizationService(); + service.RegisterFilter(mockFilter1.Object); + service.RegisterFilter(mockFilter2.Object); + + var tools = new[] { CreateTestTool("tool1") }; + var context = CreateTestContext(); + + // Act + var result = await service.FilterToolsAsync(tools, context); + + // Assert + Assert.Empty(result); + mockFilter1.Verify(f => f.ShouldIncludeToolAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Once); + // Second filter should not be called since first filter denied + mockFilter2.Verify(f => f.ShouldIncludeToolAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Never); + } + + [Fact] + public async Task FilterToolsAsync_WithFilterException_DeniesAccess() + { + // Arrange + var service = new ToolAuthorizationService(); + service.RegisterFilter(new ExceptionThrowingFilter()); + var tools = new[] { CreateTestTool("tool1") }; + var context = CreateTestContext(); + + // Act + var result = await service.FilterToolsAsync(tools, context); + + // Assert + Assert.Empty(result); + } + + [Fact] + public async Task AuthorizeToolExecutionAsync_WithNullOrEmptyToolName_ThrowsArgumentException() + { + // Arrange + var service = new ToolAuthorizationService(); + var context = CreateTestContext(); + + // Act & Assert + await Assert.ThrowsAsync(() => + service.AuthorizeToolExecutionAsync(null!, context)); + await Assert.ThrowsAsync(() => + service.AuthorizeToolExecutionAsync("", context)); + await Assert.ThrowsAsync(() => + service.AuthorizeToolExecutionAsync(" ", context)); + } + + [Fact] + public async Task AuthorizeToolExecutionAsync_WithNullContext_ThrowsArgumentNullException() + { + // Arrange + var service = new ToolAuthorizationService(); + + // Act & Assert + await Assert.ThrowsAsync(() => + service.AuthorizeToolExecutionAsync("test_tool", null!)); + } + + [Fact] + public async Task AuthorizeToolExecutionAsync_WithNoFilters_ReturnsAllow() + { + // Arrange + var service = new ToolAuthorizationService(); + var context = CreateTestContext(); + + // Act + var result = await service.AuthorizeToolExecutionAsync("test_tool", context); + + // Assert + Assert.True(result.IsAuthorized); + Assert.Equal("No filters configured", result.Reason); + } + + [Fact] + public async Task AuthorizeToolExecutionAsync_WithAllowFilter_ReturnsAllow() + { + // Arrange + var service = new ToolAuthorizationService(); + service.RegisterFilter(new AllowAllToolFilter()); + var context = CreateTestContext(); + + // Act + var result = await service.AuthorizeToolExecutionAsync("test_tool", context); + + // Assert + Assert.True(result.IsAuthorized); + Assert.Equal("All filters passed", result.Reason); + } + + [Fact] + public async Task AuthorizeToolExecutionAsync_WithDenyFilter_ReturnsDeny() + { + // Arrange + var service = new ToolAuthorizationService(); + service.RegisterFilter(new DenyAllToolFilter()); + var context = CreateTestContext(); + + // Act + var result = await service.AuthorizeToolExecutionAsync("test_tool", context); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Equal("All tools denied", result.Reason); + } + + [Fact] + public async Task AuthorizeToolExecutionAsync_WithMultipleFilters_StopsOnFirstDeny() + { + // Arrange + var mockFilter1 = new Mock(); + var mockFilter2 = new Mock(); + + mockFilter1.Setup(f => f.Priority).Returns(1); + mockFilter1.Setup(f => f.CanExecuteToolAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync(AuthorizationResult.Deny("Access denied")); + + mockFilter2.Setup(f => f.Priority).Returns(2); + + var service = new ToolAuthorizationService(); + service.RegisterFilter(mockFilter1.Object); + service.RegisterFilter(mockFilter2.Object); + + var context = CreateTestContext(); + + // Act + var result = await service.AuthorizeToolExecutionAsync("test_tool", context); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Equal("Access denied", result.Reason); + mockFilter1.Verify(f => f.CanExecuteToolAsync("test_tool", context, It.IsAny()), Times.Once); + // Second filter should not be called since first filter denied + mockFilter2.Verify(f => f.CanExecuteToolAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Never); + } + + [Fact] + public async Task AuthorizeToolExecutionAsync_WithFilterException_ReturnsDeny() + { + // Arrange + var service = new ToolAuthorizationService(); + service.RegisterFilter(new ExceptionThrowingFilter()); + var context = CreateTestContext(); + + // Act + var result = await service.AuthorizeToolExecutionAsync("test_tool", context); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Contains("Filter error", result.Reason); + } + + [Fact] + public async Task AuthorizeToolExecutionAsync_WithCancellation_ThrowsOperationCanceledException() + { + // Arrange + var service = new ToolAuthorizationService(); + service.RegisterFilter(new SlowFilter()); + var context = CreateTestContext(); + var cts = new CancellationTokenSource(); + cts.Cancel(); + + // Act & Assert + await Assert.ThrowsAsync(() => + service.AuthorizeToolExecutionAsync("test_tool", context, cts.Token)); + } + + [Fact] + public void RegisterFilter_WithNullFilter_ThrowsArgumentNullException() + { + // Arrange + var service = new ToolAuthorizationService(); + + // Act & Assert + Assert.Throws(() => service.RegisterFilter(null!)); + } + + [Fact] + public void RegisterFilter_WithValidFilter_AddsToCollection() + { + // Arrange + var service = new ToolAuthorizationService(); + var filter = new AllowAllToolFilter(); + + // Act + service.RegisterFilter(filter); + + // Assert + var registeredFilters = service.GetRegisteredFilters(); + Assert.Single(registeredFilters); + Assert.Contains(filter, registeredFilters); + } + + [Fact] + public void RegisterFilter_WithMultipleFilters_AddsAll() + { + // Arrange + var service = new ToolAuthorizationService(); + var filter1 = new AllowAllToolFilter(); + var filter2 = new DenyAllToolFilter(); + + // Act + service.RegisterFilter(filter1); + service.RegisterFilter(filter2); + + // Assert + var registeredFilters = service.GetRegisteredFilters(); + Assert.Equal(2, registeredFilters.Count); + Assert.Contains(filter1, registeredFilters); + Assert.Contains(filter2, registeredFilters); + } + + [Fact] + public void UnregisterFilter_WithNullFilter_ThrowsArgumentNullException() + { + // Arrange + var service = new ToolAuthorizationService(); + + // Act & Assert + Assert.Throws(() => service.UnregisterFilter(null!)); + } + + [Fact] + public void UnregisterFilter_WithRegisteredFilter_RemovesFromCollection() + { + // Arrange + var service = new ToolAuthorizationService(); + var filter = new AllowAllToolFilter(); + service.RegisterFilter(filter); + + // Act + service.UnregisterFilter(filter); + + // Assert + var registeredFilters = service.GetRegisteredFilters(); + Assert.Empty(registeredFilters); + } + + [Fact] + public void UnregisterFilter_WithUnregisteredFilter_DoesNotThrow() + { + // Arrange + var service = new ToolAuthorizationService(); + var filter = new AllowAllToolFilter(); + + // Act & Assert + service.UnregisterFilter(filter); // Should not throw + Assert.Empty(service.GetRegisteredFilters()); + } + + [Fact] + public void UnregisterFilter_WithMultipleFilters_RemovesOnlySpecified() + { + // Arrange + var service = new ToolAuthorizationService(); + var filter1 = new AllowAllToolFilter(); + var filter2 = new DenyAllToolFilter(); + service.RegisterFilter(filter1); + service.RegisterFilter(filter2); + + // Act + service.UnregisterFilter(filter1); + + // Assert + var registeredFilters = service.GetRegisteredFilters(); + Assert.Single(registeredFilters); + Assert.Contains(filter2, registeredFilters); + Assert.DoesNotContain(filter1, registeredFilters); + } + + [Fact] + public void GetRegisteredFilters_ReturnsReadOnlyCollection() + { + // Arrange + var service = new ToolAuthorizationService(); + var filter = new AllowAllToolFilter(); + service.RegisterFilter(filter); + + // Act + var registeredFilters = service.GetRegisteredFilters(); + + // Assert + Assert.IsAssignableFrom>(registeredFilters); + Assert.Single(registeredFilters); + Assert.Contains(filter, registeredFilters); + } + + [Fact] + public async Task ConcurrentOperations_AreThreadSafe() + { + // Arrange + var service = new ToolAuthorizationService(); + var tools = Enumerable.Range(0, 100).Select(i => CreateTestTool($"tool{i}")).ToArray(); + var context = CreateTestContext(); + + // Act - Run multiple concurrent operations + var tasks = new List(); + + // Register filters concurrently + for (int i = 0; i < 10; i++) + { + var filter = new AllowAllToolFilter(i); + tasks.Add(Task.Run(() => service.RegisterFilter(filter))); + } + + // Filter tools concurrently + for (int i = 0; i < 10; i++) + { + tasks.Add(Task.Run(async () => await service.FilterToolsAsync(tools, context))); + } + + // Authorize tools concurrently + for (int i = 0; i < 10; i++) + { + var toolName = $"tool{i}"; + tasks.Add(Task.Run(async () => await service.AuthorizeToolExecutionAsync(toolName, context))); + } + + await Task.WhenAll(tasks); + + // Assert - No exceptions should be thrown and operations should complete + Assert.Equal(10, service.GetRegisteredFilters().Count); + } + + [Fact] + public async Task FilterToolsAsync_WithMixedFilters_AppliesCorrectLogic() + { + // Arrange + var service = new ToolAuthorizationService(); + + // Add filters with different priorities and behaviors + service.RegisterFilter(new ToolNamePatternFilter(new[] { "admin_*" }, allowMatching: false, priority: 1)); // Block admin tools first + service.RegisterFilter(new AllowAllToolFilter(priority: 100)); // Allow everything else + + var tools = new[] + { + CreateTestTool("admin_delete"), + CreateTestTool("user_profile"), + CreateTestTool("admin_create"), + CreateTestTool("public_info") + }; + var context = CreateTestContext(); + + // Act + var result = await service.FilterToolsAsync(tools, context); + + // Assert + Assert.Equal(2, result.Count()); + Assert.Contains(result, t => t.Name == "user_profile"); + Assert.Contains(result, t => t.Name == "public_info"); + Assert.DoesNotContain(result, t => t.Name == "admin_delete"); + Assert.DoesNotContain(result, t => t.Name == "admin_create"); + } + + private static Tool CreateTestTool(string name, string? description = null) + { + return new Tool + { + Name = name, + Description = description ?? $"Test tool: {name}", + InputSchema = new JsonObject() + }; + } + + private static ToolAuthorizationContext CreateTestContext(string? sessionId = null, string? userId = null) + { + var context = ToolAuthorizationContext.ForSession(sessionId ?? "test-session"); + if (userId != null) + { + context.UserId = userId; + } + return context; + } + + /// + /// Test filter that throws exceptions to test error handling. + /// + private class ExceptionThrowingFilter : IToolFilter + { + public int Priority => 1; + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + throw new InvalidOperationException("Test exception in ShouldIncludeToolAsync"); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + throw new InvalidOperationException("Test exception in CanExecuteToolAsync"); + } + } + + /// + /// Test filter that simulates slow operations for testing cancellation. + /// + private class SlowFilter : IToolFilter + { + public int Priority => 1; + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + await Task.Delay(5000, cancellationToken); + return true; + } + + public async Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + await Task.Delay(5000, cancellationToken); + return AuthorizationResult.Allow(); + } + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/Authorization/ToolFilterAggregatorTests.cs b/tests/ModelContextProtocol.Tests/Server/Authorization/ToolFilterAggregatorTests.cs new file mode 100644 index 00000000..2a382c63 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/Authorization/ToolFilterAggregatorTests.cs @@ -0,0 +1,532 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server.Authorization; +using ModelContextProtocol.Tests.Utils; +using Moq; +using System.Text.Json.Nodes; +using Xunit; + +namespace ModelContextProtocol.Tests.Server.Authorization; + +/// +/// Unit tests for ToolFilterAggregator and filter priority handling. +/// +public class ToolFilterAggregatorTests +{ + [Fact] + public void Constructor_WithNullServiceProvider_ThrowsArgumentNullException() + { + // Act & Assert + Assert.Throws(() => new ToolFilterAggregator(null!)); + } + + [Fact] + public void Constructor_WithValidServiceProvider_DoesNotThrow() + { + // Arrange + var serviceProvider = Mock.Of(); + var logger = Mock.Of>(); + + // Act & Assert + var aggregator = new ToolFilterAggregator(serviceProvider, logger); + Assert.NotNull(aggregator); + } + + [Fact] + public void Priority_ReturnsMinValue() + { + // Arrange + var serviceProvider = Mock.Of(); + var aggregator = new ToolFilterAggregator(serviceProvider); + + // Act & Assert + Assert.Equal(int.MinValue, aggregator.Priority); + } + + [Fact] + public async Task ShouldIncludeToolAsync_WithNullTool_ThrowsArgumentNullException() + { + // Arrange + var serviceProvider = Mock.Of(); + var aggregator = new ToolFilterAggregator(serviceProvider); + var context = CreateTestContext(); + + // Act & Assert + await Assert.ThrowsAsync(() => + aggregator.ShouldIncludeToolAsync(null!, context)); + } + + [Fact] + public async Task ShouldIncludeToolAsync_WithNullContext_ThrowsArgumentNullException() + { + // Arrange + var serviceProvider = Mock.Of(); + var aggregator = new ToolFilterAggregator(serviceProvider); + var tool = CreateTestTool("test"); + + // Act & Assert + await Assert.ThrowsAsync(() => + aggregator.ShouldIncludeToolAsync(tool, null!)); + } + + [Fact] + public async Task ShouldIncludeToolAsync_WithNoFilters_ReturnsTrue() + { + // Arrange + var mockServiceProvider = new Mock(); + mockServiceProvider.Setup(sp => sp.GetServices()) + .Returns(Array.Empty()); + + var aggregator = new ToolFilterAggregator(mockServiceProvider.Object); + var tool = CreateTestTool("test_tool"); + var context = CreateTestContext(); + + // Act + var result = await aggregator.ShouldIncludeToolAsync(tool, context); + + // Assert + Assert.True(result); + } + + [Fact] + public async Task ShouldIncludeToolAsync_WithAllowAllFilter_ReturnsTrue() + { + // Arrange + var filters = new IToolFilter[] { new AllowAllToolFilter() }; + var mockServiceProvider = new Mock(); + mockServiceProvider.Setup(sp => sp.GetServices()) + .Returns(filters); + + var aggregator = new ToolFilterAggregator(mockServiceProvider.Object); + var tool = CreateTestTool("test_tool"); + var context = CreateTestContext(); + + // Act + var result = await aggregator.ShouldIncludeToolAsync(tool, context); + + // Assert + Assert.True(result); + } + + [Fact] + public async Task ShouldIncludeToolAsync_WithDenyAllFilter_ReturnsFalse() + { + // Arrange + var filters = new IToolFilter[] { new DenyAllToolFilter() }; + var mockServiceProvider = new Mock(); + mockServiceProvider.Setup(sp => sp.GetServices()) + .Returns(filters); + + var aggregator = new ToolFilterAggregator(mockServiceProvider.Object); + var tool = CreateTestTool("test_tool"); + var context = CreateTestContext(); + + // Act + var result = await aggregator.ShouldIncludeToolAsync(tool, context); + + // Assert + Assert.False(result); + } + + [Fact] + public async Task ShouldIncludeToolAsync_WithMultipleFilters_AppliesPriorityOrder() + { + // Arrange + var filter1 = new AllowAllToolFilter(priority: 100); // Lower priority + var filter2 = new DenyAllToolFilter(priority: 1); // Higher priority (should execute first) + var filters = new IToolFilter[] { filter1, filter2 }; + + var mockServiceProvider = new Mock(); + mockServiceProvider.Setup(sp => sp.GetServices()) + .Returns(filters); + + var aggregator = new ToolFilterAggregator(mockServiceProvider.Object); + var tool = CreateTestTool("test_tool"); + var context = CreateTestContext(); + + // Act + var result = await aggregator.ShouldIncludeToolAsync(tool, context); + + // Assert + // DenyAllToolFilter should execute first due to higher priority and block the tool + Assert.False(result); + } + + [Fact] + public async Task ShouldIncludeToolAsync_WithFilterException_ReturnsFalse() + { + // Arrange + var filters = new IToolFilter[] { new ExceptionThrowingFilter() }; + var mockServiceProvider = new Mock(); + mockServiceProvider.Setup(sp => sp.GetServices()) + .Returns(filters); + + var aggregator = new ToolFilterAggregator(mockServiceProvider.Object); + var tool = CreateTestTool("test_tool"); + var context = CreateTestContext(); + + // Act + var result = await aggregator.ShouldIncludeToolAsync(tool, context); + + // Assert + Assert.False(result); + } + + [Fact] + public async Task ShouldIncludeToolAsync_ExcludesSelfFromFilters() + { + // Arrange + var otherFilter = new AllowAllToolFilter(); + var mockServiceProvider = new Mock(); + + var aggregator = new ToolFilterAggregator(mockServiceProvider.Object); + + // Setup service provider to return the aggregator itself along with other filters + var filters = new IToolFilter[] { aggregator, otherFilter }; + mockServiceProvider.Setup(sp => sp.GetServices()) + .Returns(filters); + + var tool = CreateTestTool("test_tool"); + var context = CreateTestContext(); + + // Act + var result = await aggregator.ShouldIncludeToolAsync(tool, context); + + // Assert + // Should not cause infinite recursion and should work with the other filter + Assert.True(result); + } + + [Fact] + public async Task CanExecuteToolAsync_WithNullOrEmptyToolName_ThrowsArgumentException() + { + // Arrange + var serviceProvider = Mock.Of(); + var aggregator = new ToolFilterAggregator(serviceProvider); + var context = CreateTestContext(); + + // Act & Assert + await Assert.ThrowsAsync(() => + aggregator.CanExecuteToolAsync(null!, context)); + await Assert.ThrowsAsync(() => + aggregator.CanExecuteToolAsync("", context)); + await Assert.ThrowsAsync(() => + aggregator.CanExecuteToolAsync(" ", context)); + } + + [Fact] + public async Task CanExecuteToolAsync_WithNullContext_ThrowsArgumentNullException() + { + // Arrange + var serviceProvider = Mock.Of(); + var aggregator = new ToolFilterAggregator(serviceProvider); + + // Act & Assert + await Assert.ThrowsAsync(() => + aggregator.CanExecuteToolAsync("test_tool", null!)); + } + + [Fact] + public async Task CanExecuteToolAsync_WithNoFilters_ReturnsAllow() + { + // Arrange + var mockServiceProvider = new Mock(); + mockServiceProvider.Setup(sp => sp.GetServices()) + .Returns(Array.Empty()); + + var aggregator = new ToolFilterAggregator(mockServiceProvider.Object); + var context = CreateTestContext(); + + // Act + var result = await aggregator.CanExecuteToolAsync("test_tool", context); + + // Assert + Assert.True(result.IsAuthorized); + Assert.Equal("No filters configured", result.Reason); + } + + [Fact] + public async Task CanExecuteToolAsync_WithAllowAllFilter_ReturnsAllow() + { + // Arrange + var filters = new IToolFilter[] { new AllowAllToolFilter() }; + var mockServiceProvider = new Mock(); + mockServiceProvider.Setup(sp => sp.GetServices()) + .Returns(filters); + + var aggregator = new ToolFilterAggregator(mockServiceProvider.Object); + var context = CreateTestContext(); + + // Act + var result = await aggregator.CanExecuteToolAsync("test_tool", context); + + // Assert + Assert.True(result.IsAuthorized); + Assert.Equal("All aggregated filters passed", result.Reason); + } + + [Fact] + public async Task CanExecuteToolAsync_WithDenyAllFilter_ReturnsDeny() + { + // Arrange + var filters = new IToolFilter[] { new DenyAllToolFilter() }; + var mockServiceProvider = new Mock(); + mockServiceProvider.Setup(sp => sp.GetServices()) + .Returns(filters); + + var aggregator = new ToolFilterAggregator(mockServiceProvider.Object); + var context = CreateTestContext(); + + // Act + var result = await aggregator.CanExecuteToolAsync("test_tool", context); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Equal("All tools denied", result.Reason); + } + + [Fact] + public async Task CanExecuteToolAsync_WithMultipleFilters_StopsOnFirstDeny() + { + // Arrange + var mockFilter1 = new Mock(); + var mockFilter2 = new Mock(); + + mockFilter1.Setup(f => f.Priority).Returns(1); + mockFilter1.Setup(f => f.CanExecuteToolAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync(AuthorizationResult.Deny("Access denied")); + + mockFilter2.Setup(f => f.Priority).Returns(2); + + var filters = new IToolFilter[] { mockFilter1.Object, mockFilter2.Object }; + var mockServiceProvider = new Mock(); + mockServiceProvider.Setup(sp => sp.GetServices()) + .Returns(filters); + + var aggregator = new ToolFilterAggregator(mockServiceProvider.Object); + var context = CreateTestContext(); + + // Act + var result = await aggregator.CanExecuteToolAsync("test_tool", context); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Equal("Access denied", result.Reason); + + mockFilter1.Verify(f => f.CanExecuteToolAsync("test_tool", context, It.IsAny()), Times.Once); + // Second filter should not be called since first filter denied + mockFilter2.Verify(f => f.CanExecuteToolAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Never); + } + + [Fact] + public async Task CanExecuteToolAsync_WithFilterException_ReturnsDeny() + { + // Arrange + var filters = new IToolFilter[] { new ExceptionThrowingFilter() }; + var mockServiceProvider = new Mock(); + mockServiceProvider.Setup(sp => sp.GetServices()) + .Returns(filters); + + var aggregator = new ToolFilterAggregator(mockServiceProvider.Object); + var context = CreateTestContext(); + + // Act + var result = await aggregator.CanExecuteToolAsync("test_tool", context); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Contains("Filter error", result.Reason); + Assert.Contains("ExceptionThrowingFilter", result.Reason); + } + + [Fact] + public async Task FilterOperations_WithServiceProviderException_HandleGracefully() + { + // Arrange + var mockServiceProvider = new Mock(); + mockServiceProvider.Setup(sp => sp.GetServices()) + .Throws(new InvalidOperationException("Service resolution failed")); + + var aggregator = new ToolFilterAggregator(mockServiceProvider.Object); + var tool = CreateTestTool("test_tool"); + var context = CreateTestContext(); + + // Act + var includeResult = await aggregator.ShouldIncludeToolAsync(tool, context); + var executeResult = await aggregator.CanExecuteToolAsync("test_tool", context); + + // Assert + // Should return safe defaults when service resolution fails + Assert.True(includeResult); + Assert.True(executeResult.IsAuthorized); + Assert.Equal("No filters configured", executeResult.Reason); + } + + [Fact] + public void ClearCache_ClearsFilterCache() + { + // Arrange + var filters = new IToolFilter[] { new AllowAllToolFilter() }; + var mockServiceProvider = new Mock(); + + // Setup to return different results on subsequent calls + var setupSequence = mockServiceProvider.SetupSequence(sp => sp.GetServices()); + setupSequence.Returns(filters); + setupSequence.Returns(Array.Empty()); + + var aggregator = new ToolFilterAggregator(mockServiceProvider.Object); + + // Act & Assert + // First call should cache the filters + _ = aggregator.ShouldIncludeToolAsync(CreateTestTool("test"), CreateTestContext()); + + // Clear cache + aggregator.ClearCache(); + + // Second call should get new filters from service provider + _ = aggregator.ShouldIncludeToolAsync(CreateTestTool("test"), CreateTestContext()); + + // Verify service provider was called twice (once for initial, once after cache clear) + mockServiceProvider.Verify(sp => sp.GetServices(), Times.Exactly(2)); + } + + [Fact] + public async Task FilterAggregator_WithComplexPriorityScenario_AppliesCorrectOrder() + { + // Arrange + var filters = new IToolFilter[] + { + new TestOrderedFilter("Filter1", priority: 10, allow: true), + new TestOrderedFilter("Filter2", priority: 5, allow: true), // Should execute first + new TestOrderedFilter("Filter3", priority: 15, allow: false), // Should execute last but won't be reached + new TestOrderedFilter("Filter4", priority: 7, allow: false) // Should execute second and deny + }; + + var mockServiceProvider = new Mock(); + mockServiceProvider.Setup(sp => sp.GetServices()) + .Returns(filters); + + var aggregator = new ToolFilterAggregator(mockServiceProvider.Object); + var tool = CreateTestTool("test_tool"); + var context = CreateTestContext(); + + // Act + var result = await aggregator.ShouldIncludeToolAsync(tool, context); + + // Assert + Assert.False(result); // Should be denied by Filter4 + + // Verify execution order by checking which filters were called + var filter2 = (TestOrderedFilter)filters[1]; // Priority 5 + var filter4 = (TestOrderedFilter)filters[3]; // Priority 7 + var filter1 = (TestOrderedFilter)filters[0]; // Priority 10 + var filter3 = (TestOrderedFilter)filters[2]; // Priority 15 + + Assert.True(filter2.WasCalled); // Should be called first + Assert.True(filter4.WasCalled); // Should be called second and deny + Assert.False(filter1.WasCalled); // Should not be called due to earlier denial + Assert.False(filter3.WasCalled); // Should not be called due to earlier denial + } + + [Fact] + public async Task FilterAggregator_WithRealWorldDIContainer_WorksCorrectly() + { + // Arrange + var services = new ServiceCollection(); + services.AddSingleton(new AllowAllToolFilter(priority: 100)); + services.AddSingleton(new ToolNamePatternFilter(new[] { "admin_*" }, allowMatching: false, priority: 1)); + services.AddSingleton(new DenyAllToolFilter(priority: 50)); + + var serviceProvider = services.BuildServiceProvider(); + var aggregator = new ToolFilterAggregator(serviceProvider); + + var adminTool = CreateTestTool("admin_delete"); + var userTool = CreateTestTool("user_profile"); + var context = CreateTestContext(); + + // Act + var adminResult = await aggregator.ShouldIncludeToolAsync(adminTool, context); + var userResult = await aggregator.ShouldIncludeToolAsync(userTool, context); + + // Assert + // Admin tool should be blocked by pattern filter (priority 1) + Assert.False(adminResult); + + // User tool should be blocked by deny all filter (priority 50, executes after pattern filter allows it) + Assert.False(userResult); + } + + private static Tool CreateTestTool(string name, string? description = null) + { + return new Tool + { + Name = name, + Description = description ?? $"Test tool: {name}", + InputSchema = new JsonObject() + }; + } + + private static ToolAuthorizationContext CreateTestContext(string? sessionId = null, string? userId = null) + { + var context = ToolAuthorizationContext.ForSession(sessionId ?? "test-session"); + if (userId != null) + { + context.UserId = userId; + } + return context; + } + + /// + /// Test filter that throws exceptions to test error handling. + /// + private class ExceptionThrowingFilter : IToolFilter + { + public int Priority => 1; + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + throw new InvalidOperationException("Test exception in ShouldIncludeToolAsync"); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + throw new InvalidOperationException("Test exception in CanExecuteToolAsync"); + } + } + + /// + /// Test filter that tracks execution order and allows testing priority handling. + /// + private class TestOrderedFilter : IToolFilter + { + private readonly string _name; + private readonly bool _allowResult; + + public TestOrderedFilter(string name, int priority, bool allow) + { + _name = name; + Priority = priority; + _allowResult = allow; + } + + public int Priority { get; } + public bool WasCalled { get; private set; } + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + WasCalled = true; + return Task.FromResult(_allowResult); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + WasCalled = true; + return Task.FromResult(_allowResult + ? AuthorizationResult.Allow($"Allowed by {_name}") + : AuthorizationResult.Deny($"Denied by {_name}")); + } + + public override string ToString() => $"{_name} (Priority: {Priority})"; + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/Authorization/ToolFilterTests.cs b/tests/ModelContextProtocol.Tests/Server/Authorization/ToolFilterTests.cs new file mode 100644 index 00000000..cf7bcaf2 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/Authorization/ToolFilterTests.cs @@ -0,0 +1,561 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server.Authorization; +using Xunit; + +namespace ModelContextProtocol.Tests.Server.Authorization; + +/// +/// Unit tests for IToolFilter implementations. +/// +public class ToolFilterTests +{ + [Fact] + public async Task AllowAllToolFilter_ShouldIncludeToolAsync_ReturnsTrue() + { + // Arrange + var filter = new AllowAllToolFilter(); + var tool = CreateTestTool("test_tool"); + var context = CreateTestContext(); + + // Act + var result = await filter.ShouldIncludeToolAsync(tool, context); + + // Assert + Assert.True(result); + } + + [Fact] + public async Task AllowAllToolFilter_CanExecuteToolAsync_ReturnsAllow() + { + // Arrange + var filter = new AllowAllToolFilter(); + var context = CreateTestContext(); + + // Act + var result = await filter.CanExecuteToolAsync("test_tool", context); + + // Assert + Assert.True(result.IsAuthorized); + Assert.Equal("All tools allowed", result.Reason); + } + + [Fact] + public void AllowAllToolFilter_Priority_DefaultsToMaxValue() + { + // Arrange & Act + var filter = new AllowAllToolFilter(); + + // Assert + Assert.Equal(int.MaxValue, filter.Priority); + } + + [Fact] + public void AllowAllToolFilter_Priority_CanBeSet() + { + // Arrange + const int expectedPriority = 100; + + // Act + var filter = new AllowAllToolFilter(expectedPriority); + + // Assert + Assert.Equal(expectedPriority, filter.Priority); + } + + [Fact] + public async Task DenyAllToolFilter_ShouldIncludeToolAsync_ReturnsFalse() + { + // Arrange + var filter = new DenyAllToolFilter(); + var tool = CreateTestTool("test_tool"); + var context = CreateTestContext(); + + // Act + var result = await filter.ShouldIncludeToolAsync(tool, context); + + // Assert + Assert.False(result); + } + + [Fact] + public async Task DenyAllToolFilter_CanExecuteToolAsync_ReturnsDeny() + { + // Arrange + var filter = new DenyAllToolFilter(); + var context = CreateTestContext(); + + // Act + var result = await filter.CanExecuteToolAsync("test_tool", context); + + // Assert + Assert.False(result.IsAuthorized); + Assert.Equal("All tools denied", result.Reason); + } + + [Fact] + public void DenyAllToolFilter_Priority_DefaultsToZero() + { + // Arrange & Act + var filter = new DenyAllToolFilter(); + + // Assert + Assert.Equal(0, filter.Priority); + } + + [Fact] + public void DenyAllToolFilter_Priority_CanBeSet() + { + // Arrange + const int expectedPriority = 50; + + // Act + var filter = new DenyAllToolFilter(expectedPriority); + + // Assert + Assert.Equal(expectedPriority, filter.Priority); + } + + [Theory] + [InlineData("admin_delete_tool")] + [InlineData("admin_modify_user")] + [InlineData("private_data_access")] + [InlineData("delete_all_files")] + public async Task OAuthToolFilter_HighPrivilegeTools_RequiresScope(string toolName) + { + // Arrange + var filter = new TestOAuthToolFilter("write:admin", "test-realm"); + var tool = CreateTestTool(toolName); + var context = CreateTestContext(); + + // Act + var includeResult = await filter.ShouldIncludeToolAsync(tool, context); + var executeResult = await filter.CanExecuteToolAsync(toolName, context); + + // Assert + Assert.False(includeResult); + Assert.False(executeResult.IsAuthorized); + Assert.Contains("Insufficient scope", executeResult.Reason); + Assert.IsType(executeResult.AdditionalData); + + var challenge = (AuthorizationChallenge)executeResult.AdditionalData; + Assert.Contains("Bearer", challenge.WwwAuthenticateValue); + Assert.Contains("scope=\"write:admin\"", challenge.WwwAuthenticateValue); + Assert.Contains("error=\"insufficient_scope\"", challenge.WwwAuthenticateValue); + Assert.Contains("realm=\"test-realm\"", challenge.WwwAuthenticateValue); + } + + [Theory] + [InlineData("secure_tool")] + [InlineData("user_data_tool")] + [InlineData("private_tool")] + public async Task OAuthToolFilter_SecureTools_RequiresValidToken(string toolName) + { + // Arrange + var filter = new TestOAuthToolFilter("read:basic", "test-realm"); + var tool = CreateTestTool(toolName); + var context = CreateTestContext(); + + // Act + var includeResult = await filter.ShouldIncludeToolAsync(tool, context); + var executeResult = await filter.CanExecuteToolAsync(toolName, context); + + // Assert + Assert.False(includeResult); + Assert.False(executeResult.IsAuthorized); + Assert.Contains("Invalid or expired token", executeResult.Reason); + Assert.IsType(executeResult.AdditionalData); + + var challenge = (AuthorizationChallenge)executeResult.AdditionalData; + Assert.Contains("Bearer", challenge.WwwAuthenticateValue); + Assert.Contains("error=\"invalid_token\"", challenge.WwwAuthenticateValue); + Assert.Contains("realm=\"test-realm\"", challenge.WwwAuthenticateValue); + } + + [Theory] + [InlineData("public_info_tool")] + [InlineData("read_only_tool")] + [InlineData("public_read_tool")] + public async Task OAuthToolFilter_PublicTools_AllowsAccess(string toolName) + { + // Arrange + var filter = new TestOAuthToolFilter("read:basic", "test-realm"); + var tool = CreateTestTool(toolName); + var context = CreateTestContext(); + + // Act + var includeResult = await filter.ShouldIncludeToolAsync(tool, context); + var executeResult = await filter.CanExecuteToolAsync(toolName, context); + + // Assert + Assert.True(includeResult); + Assert.True(executeResult.IsAuthorized); + Assert.Equal("Valid credentials", executeResult.Reason); + } + + [Fact] + public async Task ToolNamePatternFilter_MatchingPattern_AllowsAccess() + { + // Arrange + var filter = new ToolNamePatternFilter(new[] { "read_*", "get_*" }, allowMatching: true); + var tool = CreateTestTool("read_data"); + var context = CreateTestContext(); + + // Act + var includeResult = await filter.ShouldIncludeToolAsync(tool, context); + var executeResult = await filter.CanExecuteToolAsync("get_info", context); + + // Assert + Assert.True(includeResult); + Assert.True(executeResult.IsAuthorized); + } + + [Fact] + public async Task ToolNamePatternFilter_NonMatchingPattern_DeniesAccess() + { + // Arrange + var filter = new ToolNamePatternFilter(new[] { "read_*", "get_*" }, allowMatching: true); + var tool = CreateTestTool("delete_data"); + var context = CreateTestContext(); + + // Act + var includeResult = await filter.ShouldIncludeToolAsync(tool, context); + var executeResult = await filter.CanExecuteToolAsync("delete_data", context); + + // Assert + Assert.False(includeResult); + Assert.False(executeResult.IsAuthorized); + } + + [Fact] + public async Task ToolNamePatternFilter_BlockPattern_DeniesMatchingTools() + { + // Arrange + var filter = new ToolNamePatternFilter(new[] { "delete_*", "remove_*" }, allowMatching: false); + var tool = CreateTestTool("delete_file"); + var context = CreateTestContext(); + + // Act + var includeResult = await filter.ShouldIncludeToolAsync(tool, context); + var executeResult = await filter.CanExecuteToolAsync("remove_user", context); + + // Assert + Assert.False(includeResult); + Assert.False(executeResult.IsAuthorized); + } + + [Fact] + public async Task ToolNamePatternFilter_BlockPattern_AllowsNonMatchingTools() + { + // Arrange + var filter = new ToolNamePatternFilter(new[] { "delete_*", "remove_*" }, allowMatching: false); + var tool = CreateTestTool("read_file"); + var context = CreateTestContext(); + + // Act + var includeResult = await filter.ShouldIncludeToolAsync(tool, context); + var executeResult = await filter.CanExecuteToolAsync("create_user", context); + + // Assert + Assert.True(includeResult); + Assert.True(executeResult.IsAuthorized); + } + + [Fact] + public async Task RoleBasedToolFilter_UserWithRequiredRole_AllowsAccess() + { + // Arrange + var filter = RoleBasedToolFilterBuilder.Create() + .RequireRole("admin") + .ForToolsMatching("admin_*") + .Build(); + + var tool = CreateTestTool("admin_panel"); + var context = CreateTestContext(); + context.UserRoles.Add("admin"); + + // Act + var includeResult = await filter.ShouldIncludeToolAsync(tool, context); + var executeResult = await filter.CanExecuteToolAsync("admin_delete", context); + + // Assert + Assert.True(includeResult); + Assert.True(executeResult.IsAuthorized); + } + + [Fact] + public async Task RoleBasedToolFilter_UserWithoutRequiredRole_DeniesAccess() + { + // Arrange + var filter = RoleBasedToolFilterBuilder.Create() + .RequireRole("admin") + .ForToolsMatching("admin_*") + .Build(); + + var tool = CreateTestTool("admin_panel"); + var context = CreateTestContext(); + context.UserRoles.Add("user"); + + // Act + var includeResult = await filter.ShouldIncludeToolAsync(tool, context); + var executeResult = await filter.CanExecuteToolAsync("admin_delete", context); + + // Assert + Assert.False(includeResult); + Assert.False(executeResult.IsAuthorized); + Assert.Contains("Required role", executeResult.Reason); + } + + [Fact] + public async Task RoleBasedToolFilter_NonMatchingTool_AllowsAccess() + { + // Arrange + var filter = RoleBasedToolFilterBuilder.Create() + .RequireRole("admin") + .ForToolsMatching("admin_*") + .Build(); + + var tool = CreateTestTool("user_profile"); + var context = CreateTestContext(); + context.UserRoles.Add("user"); + + // Act + var includeResult = await filter.ShouldIncludeToolAsync(tool, context); + var executeResult = await filter.CanExecuteToolAsync("user_profile", context); + + // Assert + Assert.True(includeResult); + Assert.True(executeResult.IsAuthorized); + } + + [Fact] + public async Task RoleBasedToolFilter_MultipleRoles_AllowsAccessWithAnyRole() + { + // Arrange + var filter = RoleBasedToolFilterBuilder.Create() + .RequireAnyRole("admin", "moderator") + .ForToolsMatching("*_manage") + .Build(); + + var tool = CreateTestTool("user_manage"); + var context = CreateTestContext(); + context.UserRoles.Add("moderator"); + + // Act + var includeResult = await filter.ShouldIncludeToolAsync(tool, context); + var executeResult = await filter.CanExecuteToolAsync("content_manage", context); + + // Assert + Assert.True(includeResult); + Assert.True(executeResult.IsAuthorized); + } + + [Fact] + public async Task RoleBasedToolFilter_AllRolesRequired_RequiresAllRoles() + { + // Arrange + var filter = RoleBasedToolFilterBuilder.Create() + .RequireAllRoles("admin", "security") + .ForToolsMatching("security_*") + .Build(); + + var tool = CreateTestTool("security_audit"); + var context = CreateTestContext(); + context.UserRoles.Add("admin"); + // Missing "security" role + + // Act + var includeResult = await filter.ShouldIncludeToolAsync(tool, context); + var executeResult = await filter.CanExecuteToolAsync("security_audit", context); + + // Assert + Assert.False(includeResult); + Assert.False(executeResult.IsAuthorized); + } + + [Fact] + public async Task RoleBasedToolFilter_AllRolesRequired_AllowsWithAllRoles() + { + // Arrange + var filter = RoleBasedToolFilterBuilder.Create() + .RequireAllRoles("admin", "security") + .ForToolsMatching("security_*") + .Build(); + + var tool = CreateTestTool("security_audit"); + var context = CreateTestContext(); + context.UserRoles.Add("admin"); + context.UserRoles.Add("security"); + + // Act + var includeResult = await filter.ShouldIncludeToolAsync(tool, context); + var executeResult = await filter.CanExecuteToolAsync("security_audit", context); + + // Assert + Assert.True(includeResult); + Assert.True(executeResult.IsAuthorized); + } + + [Fact] + public async Task CustomToolFilter_WithException_DeniesAccess() + { + // Arrange + var filter = new ExceptionThrowingFilter(); + var tool = CreateTestTool("test_tool"); + var context = CreateTestContext(); + + // Act & Assert + await Assert.ThrowsAsync(() => + filter.ShouldIncludeToolAsync(tool, context)); + + await Assert.ThrowsAsync(() => + filter.CanExecuteToolAsync("test_tool", context)); + } + + [Fact] + public async Task AsyncToolFilter_HandlesCancellation() + { + // Arrange + var filter = new SlowFilter(); + var tool = CreateTestTool("test_tool"); + var context = CreateTestContext(); + var cts = new CancellationTokenSource(); + cts.Cancel(); + + // Act & Assert + await Assert.ThrowsAsync(() => + filter.ShouldIncludeToolAsync(tool, context, cts.Token)); + + await Assert.ThrowsAsync(() => + filter.CanExecuteToolAsync("test_tool", context, cts.Token)); + } + + private static Tool CreateTestTool(string name, string? description = null) + { + return new Tool + { + Name = name, + Description = description ?? $"Test tool: {name}", + InputSchema = new JsonObject() + }; + } + + private static ToolAuthorizationContext CreateTestContext(string? sessionId = null, string? userId = null) + { + var context = ToolAuthorizationContext.ForSession(sessionId ?? "test-session"); + if (userId != null) + { + context.UserId = userId; + } + return context; + } + + /// + /// Test filter that throws exceptions to test error handling. + /// + private class ExceptionThrowingFilter : IToolFilter + { + public int Priority => 1; + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + throw new InvalidOperationException("Test exception"); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + throw new InvalidOperationException("Test exception"); + } + } + + /// + /// Test filter that simulates slow operations for testing cancellation. + /// + private class SlowFilter : IToolFilter + { + public int Priority => 1; + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + await Task.Delay(5000, cancellationToken); + return true; + } + + public async Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + await Task.Delay(5000, cancellationToken); + return AuthorizationResult.Allow(); + } + } + + /// + /// Test implementation of OAuth tool filter demonstrating proper authorization challenge handling. + /// + private class TestOAuthToolFilter : IToolFilter + { + private readonly string _requiredScope; + private readonly string? _realm; + + public TestOAuthToolFilter(string requiredScope, string? realm = null) + { + _requiredScope = requiredScope ?? throw new ArgumentNullException(nameof(requiredScope)); + _realm = realm; + } + + public int Priority => 100; + + public async Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + var result = await CanExecuteToolAsync(tool.Name, context, cancellationToken); + return result.IsAuthorized; + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + // Simulate checking if the user has the required scope + if (IsHighPrivilegeTool(toolName)) + { + // For high-privilege tools, require specific scope + if (!HasRequiredScope(context)) + { + return Task.FromResult(AuthorizationResult.DenyInsufficientScope(_requiredScope, _realm)); + } + } + else if (RequiresAuthentication(toolName)) + { + // For tools that require authentication, check for valid token + if (!HasValidToken(context)) + { + return Task.FromResult(AuthorizationResult.DenyInvalidToken(_realm)); + } + } + + // Tool is authorized + return Task.FromResult(AuthorizationResult.Allow("Valid credentials")); + } + + private static bool IsHighPrivilegeTool(string toolName) + { + return toolName.Contains("delete", StringComparison.OrdinalIgnoreCase) || + toolName.Contains("admin", StringComparison.OrdinalIgnoreCase) || + toolName.Contains("private", StringComparison.OrdinalIgnoreCase); + } + + private static bool RequiresAuthentication(string toolName) + { + return !toolName.Contains("public", StringComparison.OrdinalIgnoreCase) && + !toolName.Contains("read", StringComparison.OrdinalIgnoreCase); + } + + private bool HasRequiredScope(ToolAuthorizationContext context) + { + // For testing, always deny high-privilege tools to demonstrate challenge + return false; + } + + private bool HasValidToken(ToolAuthorizationContext context) + { + // For testing, always deny to demonstrate invalid token challenge + return false; + } + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/Authorization/ToolFilteringIntegrationTests.cs b/tests/ModelContextProtocol.Tests/Server/Authorization/ToolFilteringIntegrationTests.cs new file mode 100644 index 00000000..3b052799 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/Authorization/ToolFilteringIntegrationTests.cs @@ -0,0 +1,439 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Server.Authorization; +using ModelContextProtocol.Tests.Utils; +using System.Text.Json.Nodes; +using Xunit; + +namespace ModelContextProtocol.Tests.Server.Authorization; + +/// +/// Integration tests for end-to-end tool filtering in MCP server operations. +/// +public class ToolFilteringIntegrationTests : LoggedTest +{ + public ToolFilteringIntegrationTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) + { + } + + [Fact] + public async Task ListTools_WithNoFilters_ReturnsAllTools() + { + // Arrange + using var serverTransport = new TestServerTransport(); + using var clientTransport = serverTransport.GetClientTransport(); + + var serverOptions = CreateServerOptions(); + AddTestTools(serverOptions); + + using var server = McpServerFactory.CreateServer(serverTransport, serverOptions, LoggerFactory); + using var client = McpClientFactory.CreateClient(clientTransport, LoggerFactory); + + await server.StartAsync(TestContext.Current.CancellationToken); + await client.ConnectAsync(TestContext.Current.CancellationToken); + + // Act + var result = await client.ListToolsAsync(TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(result); + Assert.Equal(3, result.Tools.Count); + Assert.Contains(result.Tools, t => t.Name == "read_tool"); + Assert.Contains(result.Tools, t => t.Name == "write_tool"); + Assert.Contains(result.Tools, t => t.Name == "admin_tool"); + } + + [Fact] + public async Task ListTools_WithDenyAllFilter_ReturnsNoTools() + { + // Arrange + using var serverTransport = new TestServerTransport(); + using var clientTransport = serverTransport.GetClientTransport(); + + var serverOptions = CreateServerOptions(); + AddTestTools(serverOptions); + + // Add authorization service with deny all filter + serverOptions.ServiceCollection?.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + authService.RegisterFilter(new DenyAllToolFilter()); + return authService; + }); + + using var server = McpServerFactory.CreateServer(serverTransport, serverOptions, LoggerFactory); + using var client = McpClientFactory.CreateClient(clientTransport, LoggerFactory); + + await server.StartAsync(TestContext.Current.CancellationToken); + await client.ConnectAsync(TestContext.Current.CancellationToken); + + // Act + var result = await client.ListToolsAsync(TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(result); + Assert.Empty(result.Tools); + } + + [Fact] + public async Task ListTools_WithPatternFilter_ReturnsFilteredTools() + { + // Arrange + using var serverTransport = new TestServerTransport(); + using var clientTransport = serverTransport.GetClientTransport(); + + var serverOptions = CreateServerOptions(); + AddTestTools(serverOptions); + + // Add authorization service with pattern filter + serverOptions.ServiceCollection?.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + // Only allow read tools + authService.RegisterFilter(new ToolNamePatternFilter(new[] { "read_*" }, allowMatching: true)); + return authService; + }); + + using var server = McpServerFactory.CreateServer(serverTransport, serverOptions, LoggerFactory); + using var client = McpClientFactory.CreateClient(clientTransport, LoggerFactory); + + await server.StartAsync(TestContext.Current.CancellationToken); + await client.ConnectAsync(TestContext.Current.CancellationToken); + + // Act + var result = await client.ListToolsAsync(TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(result); + Assert.Single(result.Tools); + Assert.Equal("read_tool", result.Tools[0].Name); + } + + [Fact] + public async Task ListTools_WithDecoratorPattern_FiltersAllToolsFromAllSources() + { + // Arrange + using var serverTransport = new TestServerTransport(); + using var clientTransport = serverTransport.GetClientTransport(); + + var serverOptions = CreateServerOptions(); + + // Add tools through different sources to test decorator pattern + AddTestTools(serverOptions); + + // Add a custom list tools handler that adds additional tools (simulating another source) + serverOptions.ListToolsHandler = async (request, cancellationToken) => + { + return new ListToolsResult + { + Tools = { new Tool { Name = "external_tool", Description = "Tool from external source" } } + }; + }; + + // Add filter that blocks admin tools (should affect tools from all sources) + serverOptions.ServiceCollection?.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + authService.RegisterFilter(new ToolNamePatternFilter( + new[] { "^(?!admin_).*" }, // Allow all except admin_ tools + Array.Empty(), + priority: 100)); + return authService; + }); + + using var server = McpServerFactory.CreateServer(serverTransport, serverOptions, LoggerFactory); + using var client = McpClientFactory.CreateClient(clientTransport, LoggerFactory); + + await server.StartAsync(TestContext.Current.CancellationToken); + await client.ConnectAsync(TestContext.Current.CancellationToken); + + // Act + var result = await client.ListToolsAsync(TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(result); + // Should have tools from both sources, but admin tools should be filtered out + Assert.Contains(result.Tools, t => t.Name == "read_tool"); // From AddTestTools + Assert.Contains(result.Tools, t => t.Name == "external_tool"); // From custom handler + Assert.DoesNotContain(result.Tools, t => t.Name.StartsWith("admin_")); // Filtered out + } + + [Fact] + public async Task ListTools_WithMultipleFilters_AppliesPriorityOrder() + { + // Arrange + using var serverTransport = new TestServerTransport(); + using var clientTransport = serverTransport.GetClientTransport(); + + var serverOptions = CreateServerOptions(); + AddTestTools(serverOptions); + + // Add authorization service with multiple filters + serverOptions.ServiceCollection?.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + // Higher priority filter blocks admin tools + authService.RegisterFilter(new ToolNamePatternFilter(new[] { "admin_*" }, allowMatching: false, priority: 1)); + // Lower priority filter would allow all, but won't affect admin tools + authService.RegisterFilter(new AllowAllToolFilter(priority: 100)); + return authService; + }); + + using var server = McpServerFactory.CreateServer(serverTransport, serverOptions, LoggerFactory); + using var client = McpClientFactory.CreateClient(clientTransport, LoggerFactory); + + await server.StartAsync(TestContext.Current.CancellationToken); + await client.ConnectAsync(TestContext.Current.CancellationToken); + + // Act + var result = await client.ListToolsAsync(TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(result); + Assert.Equal(2, result.Tools.Count); + Assert.Contains(result.Tools, t => t.Name == "read_tool"); + Assert.Contains(result.Tools, t => t.Name == "write_tool"); + Assert.DoesNotContain(result.Tools, t => t.Name == "admin_tool"); + } + + [Fact] + public async Task CallTool_WithAllowAllFilter_Succeeds() + { + // Arrange + using var serverTransport = new TestServerTransport(); + using var clientTransport = serverTransport.GetClientTransport(); + + var serverOptions = CreateServerOptions(); + AddTestTools(serverOptions); + + serverOptions.ServiceCollection?.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + authService.RegisterFilter(new AllowAllToolFilter()); + return authService; + }); + + using var server = McpServerFactory.CreateServer(serverTransport, serverOptions, LoggerFactory); + using var client = McpClientFactory.CreateClient(clientTransport, LoggerFactory); + + await server.StartAsync(TestContext.Current.CancellationToken); + await client.ConnectAsync(TestContext.Current.CancellationToken); + + // Act + var result = await client.CallToolAsync("read_tool", new { }, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(result); + Assert.NotNull(result.Content); + Assert.Single(result.Content); + var textContent = Assert.IsType(result.Content[0]); + Assert.Equal("Read operation completed", textContent.Text); + } + + [Fact] + public async Task CallTool_WithDenyAllFilter_ThrowsException() + { + // Arrange + using var serverTransport = new TestServerTransport(); + using var clientTransport = serverTransport.GetClientTransport(); + + var serverOptions = CreateServerOptions(); + AddTestTools(serverOptions); + + serverOptions.ServiceCollection?.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + authService.RegisterFilter(new DenyAllToolFilter()); + return authService; + }); + + using var server = McpServerFactory.CreateServer(serverTransport, serverOptions, LoggerFactory); + using var client = McpClientFactory.CreateClient(clientTransport, LoggerFactory); + + await server.StartAsync(TestContext.Current.CancellationToken); + await client.ConnectAsync(TestContext.Current.CancellationToken); + + // Act & Assert + var exception = await Assert.ThrowsAsync(async () => + await client.CallToolAsync("read_tool", new { }, TestContext.Current.CancellationToken)); + + Assert.Contains("denied", exception.Message.ToLowerInvariant()); + } + + [Fact] + public async Task CallTool_WithRoleBasedFilter_RequiresCorrectRole() + { + // Arrange + using var serverTransport = new TestServerTransport(); + using var clientTransport = serverTransport.GetClientTransport(); + + var serverOptions = CreateServerOptions(); + AddTestTools(serverOptions); + + serverOptions.ServiceCollection?.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + // Require admin role for admin tools + var filter = RoleBasedToolFilterBuilder.Create() + .RequireRole("admin") + .ForToolsMatching("admin_*") + .Build(); + authService.RegisterFilter(filter); + return authService; + }); + + using var server = McpServerFactory.CreateServer(serverTransport, serverOptions, LoggerFactory); + using var client = McpClientFactory.CreateClient(clientTransport, LoggerFactory); + + await server.StartAsync(TestContext.Current.CancellationToken); + await client.ConnectAsync(TestContext.Current.CancellationToken); + + // Act & Assert - Should fail without admin role + var exception = await Assert.ThrowsAsync(async () => + await client.CallToolAsync("admin_tool", new { }, TestContext.Current.CancellationToken)); + + Assert.Contains("role", exception.Message.ToLowerInvariant()); + + // Should succeed for non-admin tools + var result = await client.CallToolAsync("read_tool", new { }, TestContext.Current.CancellationToken); + Assert.NotNull(result); + } + + [Fact] + public async Task ToolFiltering_WithFilterException_HandlesGracefully() + { + // Arrange + using var serverTransport = new TestServerTransport(); + using var clientTransport = serverTransport.GetClientTransport(); + + var serverOptions = CreateServerOptions(); + AddTestTools(serverOptions); + + serverOptions.ServiceCollection?.AddSingleton(sp => + { + var authService = new ToolAuthorizationService(sp.GetService>()); + authService.RegisterFilter(new ExceptionThrowingFilter()); + return authService; + }); + + using var server = McpServerFactory.CreateServer(serverTransport, serverOptions, LoggerFactory); + using var client = McpClientFactory.CreateClient(clientTransport, LoggerFactory); + + await server.StartAsync(TestContext.Current.CancellationToken); + await client.ConnectAsync(TestContext.Current.CancellationToken); + + // Act + var listResult = await client.ListToolsAsync(TestContext.Current.CancellationToken); + + // Assert - Should handle filter exceptions gracefully + Assert.NotNull(listResult); + Assert.Empty(listResult.Tools); // Tools should be filtered out due to exception + + // CallTool should also fail gracefully + var exception = await Assert.ThrowsAsync(async () => + await client.CallToolAsync("read_tool", new { }, TestContext.Current.CancellationToken)); + + Assert.Contains("error", exception.Message.ToLowerInvariant()); + } + + [Fact] + public async Task ToolFiltering_WithDIRegisteredFilters_WorksCorrectly() + { + // Arrange + using var serverTransport = new TestServerTransport(); + using var clientTransport = serverTransport.GetClientTransport(); + + var serverOptions = CreateServerOptions(); + AddTestTools(serverOptions); + + // Register filters via DI + serverOptions.ServiceCollection?.AddSingleton(new AllowAllToolFilter(priority: 100)); + serverOptions.ServiceCollection?.AddSingleton(new ToolNamePatternFilter(new[] { "admin_*" }, allowMatching: false, priority: 1)); + serverOptions.ServiceCollection?.AddSingleton(); + + using var server = McpServerFactory.CreateServer(serverTransport, serverOptions, LoggerFactory); + using var client = McpClientFactory.CreateClient(clientTransport, LoggerFactory); + + await server.StartAsync(TestContext.Current.CancellationToken); + await client.ConnectAsync(TestContext.Current.CancellationToken); + + // Act + var result = await client.ListToolsAsync(TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(result); + Assert.Equal(2, result.Tools.Count); // Should filter out admin_tool + Assert.Contains(result.Tools, t => t.Name == "read_tool"); + Assert.Contains(result.Tools, t => t.Name == "write_tool"); + Assert.DoesNotContain(result.Tools, t => t.Name == "admin_tool"); + } + + private static McpServerOptions CreateServerOptions() + { + var options = new McpServerOptions + { + ServerInfo = new Implementation { Name = "TestServer", Version = "1.0.0" }, + ServiceCollection = new ServiceCollection() + }; + + // Add basic logging + options.ServiceCollection.AddLogging(); + + return options; + } + + private static void AddTestTools(McpServerOptions options) + { + // Add test tools to the server + options.Capabilities.Tools ??= new(); + options.Capabilities.Tools.ToolCollection ??= new(); + + // Add read tool + var readTool = McpServerTool.Create(() => new CallToolResult + { + Content = [new TextResourceContents { Text = "Read operation completed" }] + }); + readTool.ProtocolTool.Name = "read_tool"; + readTool.ProtocolTool.Description = "Reads data"; + options.Capabilities.Tools.ToolCollection.Add(readTool); + + // Add write tool + var writeTool = McpServerTool.Create(() => new CallToolResult + { + Content = [new TextResourceContents { Text = "Write operation completed" }] + }); + writeTool.ProtocolTool.Name = "write_tool"; + writeTool.ProtocolTool.Description = "Writes data"; + options.Capabilities.Tools.ToolCollection.Add(writeTool); + + // Add admin tool + var adminTool = McpServerTool.Create(() => new CallToolResult + { + Content = [new TextResourceContents { Text = "Admin operation completed" }] + }); + adminTool.ProtocolTool.Name = "admin_tool"; + adminTool.ProtocolTool.Description = "Administrative operations"; + options.Capabilities.Tools.ToolCollection.Add(adminTool); + } + + /// + /// Test filter that throws exceptions to test error handling. + /// + private class ExceptionThrowingFilter : IToolFilter + { + public int Priority => 1; + + public Task ShouldIncludeToolAsync(Tool tool, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + throw new InvalidOperationException("Test exception in ShouldIncludeToolAsync"); + } + + public Task CanExecuteToolAsync(string toolName, ToolAuthorizationContext context, CancellationToken cancellationToken = default) + { + throw new InvalidOperationException("Test exception in CanExecuteToolAsync"); + } + } +} \ No newline at end of file