diff --git a/Foundatio.Mediator.slnx b/Foundatio.Mediator.slnx index 2b7557c..8787efa 100644 --- a/Foundatio.Mediator.slnx +++ b/Foundatio.Mediator.slnx @@ -8,12 +8,13 @@ - - + + + diff --git a/samples/ConsoleSample/ConsoleSample.csproj b/samples/ConsoleSample/ConsoleSample.csproj index f0d5613..41c3305 100644 --- a/samples/ConsoleSample/ConsoleSample.csproj +++ b/samples/ConsoleSample/ConsoleSample.csproj @@ -34,6 +34,12 @@ + + + + + + diff --git a/samples/ConsoleSample/Messages/Messages.cs b/samples/ConsoleSample/Messages/Messages.cs index af686cb..a5e7c35 100644 --- a/samples/ConsoleSample/Messages/Messages.cs +++ b/samples/ConsoleSample/Messages/Messages.cs @@ -3,11 +3,14 @@ namespace ConsoleSample.Messages; #region Simple +/// Ping endpoint used to test connectivity. public record Ping(string Text); +/// Gets a personalized greeting for the specified user. public record GetGreeting(string Name); #endregion // Order CRUD messages +/// Creates a new order for the specified customer. public record CreateOrder( [Required(ErrorMessage = "Customer ID is required")] [StringLength(50, MinimumLength = 3, ErrorMessage = "Customer ID must be between 3 and 50 characters")] @@ -20,8 +23,11 @@ public record CreateOrder( [Required(ErrorMessage = "Description is required")] [StringLength(200, MinimumLength = 5, ErrorMessage = "Description must be between 5 and 200 characters")] string Description) : IValidatable; +/// Gets an existing order by identifier. public record GetOrder(string OrderId); +/// Updates mutable fields on an existing order. public record UpdateOrder(string OrderId, decimal? Amount, string? Description); +/// Deletes an existing order. public record DeleteOrder(string OrderId); // Event messages (for publish pattern) diff --git a/samples/ConsoleSample/Middleware/TransactionMiddleware.cs b/samples/ConsoleSample/Middleware/TransactionMiddleware.cs index 5e41712..d229212 100644 --- a/samples/ConsoleSample/Middleware/TransactionMiddleware.cs +++ b/samples/ConsoleSample/Middleware/TransactionMiddleware.cs @@ -23,10 +23,9 @@ public void After(CreateOrder cmd, IDbTransaction transaction, ILogger logger) { - if (transaction == null) + if (transaction is not FakeTransaction tx) return; - var tx = (FakeTransaction)transaction; if (result?.IsSuccess == true) return; diff --git a/samples/ConsoleSample/Program.cs b/samples/ConsoleSample/Program.cs index 7f0f6f1..fa8cd05 100644 --- a/samples/ConsoleSample/Program.cs +++ b/samples/ConsoleSample/Program.cs @@ -1,7 +1,15 @@ using ConsoleSample; using Foundatio.Mediator; +using Microsoft.AspNetCore.Builder; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; +using Scalar.AspNetCore; + +if (args.Any(a => string.Equals(a, "--web", StringComparison.OrdinalIgnoreCase))) +{ + await RunMinimalApiAsync(args); + return; +} // Create application host var builder = Host.CreateApplicationBuilder(args); @@ -16,3 +24,21 @@ var sampleRunner = new SampleRunner(mediator, host.Services); await sampleRunner.RunAllSamplesAsync(); + +static async Task RunMinimalApiAsync(string[] args) +{ + var builder = WebApplication.CreateBuilder(args); + builder.Services.ConfigureServices(); + builder.Services.AddEndpointsApiExplorer(); + builder.Services.AddOpenApi(); + + var app = builder.Build(); + app.MapOpenApi(); + app.MapScalarApiReference(options => + { + options.Title = "Foundatio.Mediator Console Sample"; + }); + app.MapMediatorEndpoints(); + + await app.RunAsync(); +} diff --git a/src/Foundatio.Mediator.Abstractions/MediatorExtensions.cs b/src/Foundatio.Mediator.Abstractions/MediatorExtensions.cs index d24759d..2b578f2 100644 --- a/src/Foundatio.Mediator.Abstractions/MediatorExtensions.cs +++ b/src/Foundatio.Mediator.Abstractions/MediatorExtensions.cs @@ -16,7 +16,7 @@ public static IServiceCollection AddMediator(this IServiceCollection services, M if (configuration.Assemblies == null) { - configuration.Assemblies = AppDomain.CurrentDomain.GetAssemblies().Where(a => !a.IsDynamic && !a.FullName.StartsWith("System.")).ToList(); + configuration.Assemblies = AppDomain.CurrentDomain.GetAssemblies().Where(a => a is { IsDynamic: false, FullName: not null } && !a.FullName.StartsWith("System.")).ToList(); } services.Add(ServiceDescriptor.Describe(typeof(IMediator), typeof(Mediator), configuration.MediatorLifetime)); diff --git a/src/Foundatio.Mediator/EndpointGenerator.cs b/src/Foundatio.Mediator/EndpointGenerator.cs new file mode 100644 index 0000000..881e809 --- /dev/null +++ b/src/Foundatio.Mediator/EndpointGenerator.cs @@ -0,0 +1,345 @@ +using System.Collections.Generic; +using System.Text; +using Foundatio.Mediator.Models; +using Foundatio.Mediator.Utility; +using Microsoft.CodeAnalysis; + +namespace Foundatio.Mediator; + +internal static class EndpointGenerator +{ + private static readonly string[] SupportedPrefixes = ["Get", "Create", "Update", "Delete"]; + + public static void Execute(SourceProductionContext context, List handlers, Compilation compilation) + { + if (!SupportsMinimalApis(compilation)) + return; + + var endpointHandlers = handlers + .Where(IsEndpointCandidate) + .OrderBy(h => h.MessageType.FullName, StringComparer.Ordinal) + .ToList(); + + if (endpointHandlers.Count == 0) + return; + + bool hasAsParameters = compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Mvc.AsParametersAttribute") is not null + || compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Http.AsParametersAttribute") is not null; + bool hasFromBody = compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Mvc.FromBodyAttribute") is not null; + bool hasFromQuery = compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Mvc.FromQueryAttribute") is not null; + bool hasFromServices = compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Mvc.FromServicesAttribute") is not null; + bool hasOpenApi = compilation.GetTypeByMetadataName("Microsoft.AspNetCore.OpenApi.OpenApiRouteHandlerBuilderExtensions") is not null; + + var source = BuildSource(endpointHandlers, hasAsParameters, hasFromBody, hasFromQuery, hasFromServices, hasOpenApi); + context.AddSource("MediatorMinimalApiEndpoints.g.cs", source); + } + + private static string BuildSource(IReadOnlyList handlers, bool hasAsParameters, bool hasFromBody, bool hasFromQuery, bool hasFromServices, bool hasOpenApi) + { + var source = new IndentedStringBuilder(); + source.AddGeneratedFileHeader(); + + source.AppendLines(""" + using System; + using System.Collections.Generic; + using System.Linq; + using System.Runtime.CompilerServices; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.AspNetCore.Builder; + using Microsoft.AspNetCore.Http; + """); + + if (hasAsParameters || hasFromBody || hasFromQuery || hasFromServices) + source.AppendLine("using Microsoft.AspNetCore.Mvc;"); + + source.AppendLines(""" + using Microsoft.AspNetCore.Routing; + using Microsoft.Extensions.DependencyInjection; + + namespace Foundatio.Mediator; + + public static class MediatorEndpointExtensions + { + private const string DefaultBasePath = "/api/messages"; + + public static IEndpointRouteBuilder MapMediatorEndpoints(this IEndpointRouteBuilder app, string? basePath = null) + { + ArgumentNullException.ThrowIfNull(app); + var group = app.MapGroup(string.IsNullOrWhiteSpace(basePath) ? DefaultBasePath : basePath!); + RegisterMediatorEndpoints(group); + return app; + } + + private static void RegisterMediatorEndpoints(RouteGroupBuilder group) + { + ArgumentNullException.ThrowIfNull(group); + """); + + source.IncrementIndent(); + source.IncrementIndent(); + + var categoryOrder = new List(); + var handlersByCategory = new Dictionary>(StringComparer.Ordinal); + var categoryLabels = new Dictionary(StringComparer.Ordinal); + + foreach (var handler in handlers) + { + var categoryKey = handler.Category?.Trim() ?? string.Empty; + if (!handlersByCategory.TryGetValue(categoryKey, out var list)) + { + list = new List(); + handlersByCategory[categoryKey] = list; + categoryOrder.Add(categoryKey); + categoryLabels[categoryKey] = handler.Category?.Trim(); + } + + list.Add(handler); + } + + int endpointIndex = 0; + int categoryIndex = 0; + + for (int categoryOrderIndex = 0; categoryOrderIndex < categoryOrder.Count; categoryOrderIndex++) + { + var categoryKey = categoryOrder[categoryOrderIndex]; + var categoryHandlers = handlersByCategory[categoryKey]; + string? categoryLabel = categoryLabels[categoryKey]; + string groupAccessor = "group"; + + if (!string.IsNullOrEmpty(categoryKey) && !string.IsNullOrWhiteSpace(categoryLabel)) + { + string categoryVar = $"categoryGroup{categoryIndex++}"; + string routeSegment = ToCategoryRouteSegmentLiteral(categoryLabel!); + source.AppendLine($"var {categoryVar} = group.MapGroup(\"/{routeSegment}\");"); + source.AppendLine($"{categoryVar}.WithGroupName(\"{EscapeString(categoryLabel!)}\");"); + source.AppendLine(); + groupAccessor = categoryVar; + } + + foreach (var handler in categoryHandlers) + { + string handlerClassName = HandlerGenerator.GetHandlerClassName(handler); + string methodName = HandlerGenerator.GetHandlerDefaultMethodName(handler); + string mapMethod = GetMapMethod(handler); + string routeSegment = ToRouteSegment(handler.MessageType.Name); + string requestParameter = $"{handler.MessageType.FullName} request"; + bool handlerDefaultReturnsValue = HandlerGenerator.HandlerDefaultReturnsValue(handler); + bool handlerDefaultIsAsync = HandlerGenerator.HandlerDefaultIsAsync(handler); + + bool requiresQueryBinding = RequiresQueryBinding(mapMethod); + bool requiresBodyBinding = RequiresBodyBinding(mapMethod); + + if (requiresQueryBinding) + { + if (hasAsParameters) + { + requestParameter = "[AsParameters] " + requestParameter; + } + else if (hasFromQuery) + { + requestParameter = "[FromQuery] " + requestParameter; + } + } + else if (requiresBodyBinding && hasFromBody) + { + requestParameter = "[FromBody] " + requestParameter; + } + + string servicesParameter = hasFromServices ? "[FromServices] IServiceProvider services" : "IServiceProvider services"; + string mediatorParameter = hasFromServices ? "[FromServices] Foundatio.Mediator.IMediator mediator" : "Foundatio.Mediator.IMediator mediator"; + + string invocation = handlerDefaultIsAsync + ? $"await {handlerClassName}.{methodName}(mediator, services, request, cancellationToken).ConfigureAwait(false)" + : $"{handlerClassName}.{methodName}(mediator, services, request, cancellationToken)"; + + string endpointVar = $"endpoint{endpointIndex++}"; + + source.AppendLine($"var {endpointVar} = {groupAccessor}.{mapMethod}(\"/{routeSegment}\", async ({requestParameter}, {servicesParameter}, {mediatorParameter}, CancellationToken cancellationToken) =>"); + source.AppendLine("{"); + source.IncrementIndent(); + if (handlerDefaultReturnsValue) + { + source.AppendLine($"var handlerResult = {invocation};"); + source.AppendLine("return MediatorEndpointResultMapper.ToHttpResult(handlerResult);"); + } + else + { + source.AppendLine($"{invocation};"); + source.AppendLine("return MediatorEndpointResultMapper.ToHttpResult(null);"); + } + source.DecrementIndent(); + source.AppendLine("});"); + source.AppendLine($"{endpointVar}.WithName(\"{handler.MessageType.Name}\");"); + if (!string.IsNullOrWhiteSpace(handler.MessageSummary)) + { + var escapedSummary = EscapeString(handler.MessageSummary!); + source.AppendLine($"{endpointVar}.WithSummary(\"{escapedSummary}\");"); + source.AppendLine($"{endpointVar}.WithDescription(\"{escapedSummary}\");"); + } + if (hasOpenApi) + source.AppendLine($"{endpointVar}.WithOpenApi();"); + source.AppendLine(); + } + + if (categoryOrderIndex < categoryOrder.Count - 1) + source.AppendLine(); + } + + source.DecrementIndent(); + source.DecrementIndent(); + + source.AppendLines(""" + } + + private static bool RequiresQueryBinding(string mapMethod) => mapMethod is "MapGet" or "MapDelete"; + private static bool RequiresBodyBinding(string mapMethod) => mapMethod is "MapPost" or "MapPut"; + + private static class MediatorEndpointResultMapper + { + public static global::Microsoft.AspNetCore.Http.IResult ToHttpResult(object? value) + { + if (value is null) + return Results.NoContent(); + + if (value is Foundatio.Mediator.IResult mediatorResult) + return MapMediatorResult(mediatorResult); + + return Results.Ok(value); + } + + private static global::Microsoft.AspNetCore.Http.IResult MapMediatorResult(Foundatio.Mediator.IResult result) + { + return result.Status switch + { + ResultStatus.Success => result.GetValue() is { } value ? Results.Ok(value) : Results.Ok(), + ResultStatus.Created => Results.Created(string.IsNullOrWhiteSpace(result.Location) ? DefaultBasePath : result.Location, result.GetValue()), + ResultStatus.NoContent => Results.NoContent(), + ResultStatus.BadRequest => Results.BadRequest(result.Message), + ResultStatus.Invalid => Results.ValidationProblem(ToValidationState(result.ValidationErrors), detail: result.Message), + ResultStatus.NotFound => string.IsNullOrWhiteSpace(result.Message) ? Results.NotFound() : Results.NotFound(result.Message), + ResultStatus.Unauthorized => Results.Unauthorized(), + ResultStatus.Forbidden => Results.Forbid(), + ResultStatus.Conflict => Results.Conflict(result.Message), + ResultStatus.Error => Results.Problem(detail: result.Message), + ResultStatus.CriticalError => Results.Problem(detail: result.Message, statusCode: StatusCodes.Status500InternalServerError), + ResultStatus.Unavailable => Results.StatusCode(StatusCodes.Status503ServiceUnavailable), + _ => Results.Ok(result.GetValue()) + }; + } + + private static IDictionary ToValidationState(IEnumerable errors) + { + var dictionary = new Dictionary>(StringComparer.OrdinalIgnoreCase); + + foreach (var error in errors) + { + var key = string.IsNullOrWhiteSpace(error.Identifier) ? string.Empty : error.Identifier; + if (!dictionary.TryGetValue(key, out var list)) + { + list = new List(); + dictionary[key] = list; + } + + if (!string.IsNullOrWhiteSpace(error.ErrorMessage)) + list.Add(error.ErrorMessage); + } + + return dictionary.ToDictionary(kvp => kvp.Key, kvp => kvp.Value.ToArray()); + } + } + } + """); + + return source.ToString(); + } + + private static bool SupportsMinimalApis(Compilation compilation) + { + return compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Routing.IEndpointRouteBuilder") is not null + && compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Http.Results") is not null; + } + + private static bool IsEndpointCandidate(HandlerInfo handler) + { + foreach (var prefix in SupportedPrefixes) + { + if (handler.MessageType.Name.StartsWith(prefix, StringComparison.OrdinalIgnoreCase)) + return true; + } + + return false; + } + + private static string GetMapMethod(HandlerInfo handler) + { + var name = handler.MessageType.Name; + if (name.StartsWith("Get", StringComparison.OrdinalIgnoreCase)) + return "MapGet"; + if (name.StartsWith("Update", StringComparison.OrdinalIgnoreCase)) + return "MapPut"; + if (name.StartsWith("Create", StringComparison.OrdinalIgnoreCase)) + return "MapPost"; + if (name.StartsWith("Delete", StringComparison.OrdinalIgnoreCase)) + return "MapDelete"; + + return "MapPost"; + } + + private static string ToRouteSegment(string name) + { + if (string.IsNullOrEmpty(name)) + return name; + + var builder = new StringBuilder(); + for (int i = 0; i < name.Length; i++) + { + char c = name[i]; + if (char.IsUpper(c) && i > 0) + builder.Append('-'); + + builder.Append(char.ToLowerInvariant(c)); + } + + return builder.ToString(); + } + + private static string EscapeString(string value) + { + if (string.IsNullOrEmpty(value)) + return value; + + return value + .Replace("\\", "\\\\") + .Replace("\"", "\\\"") + .Replace('\r', ' ') + .Replace('\n', ' '); + } + + private static string ToCategoryRouteSegmentLiteral(string category) + { + if (string.IsNullOrWhiteSpace(category)) + return "uncategorized"; + + var builder = new StringBuilder(); + foreach (var c in category) + { + if (char.IsLetterOrDigit(c)) + { + builder.Append(char.ToLowerInvariant(c)); + continue; + } + + if (builder.Length == 0 || builder[builder.Length - 1] == '-') + continue; + + builder.Append('-'); + } + + var result = builder.ToString().Trim('-'); + return string.IsNullOrEmpty(result) ? "uncategorized" : result; + } + + private static bool RequiresQueryBinding(string mapMethod) => mapMethod is "MapGet" or "MapDelete"; + private static bool RequiresBodyBinding(string mapMethod) => mapMethod is "MapPost" or "MapPut"; +} diff --git a/src/Foundatio.Mediator/Foundatio.Mediator.csproj b/src/Foundatio.Mediator/Foundatio.Mediator.csproj index 0fc725f..9e630d8 100644 --- a/src/Foundatio.Mediator/Foundatio.Mediator.csproj +++ b/src/Foundatio.Mediator/Foundatio.Mediator.csproj @@ -1,4 +1,4 @@ - + @@ -14,7 +14,6 @@ - diff --git a/src/Foundatio.Mediator/HandlerAnalyzer.cs b/src/Foundatio.Mediator/HandlerAnalyzer.cs index 76631fd..539609d 100644 --- a/src/Foundatio.Mediator/HandlerAnalyzer.cs +++ b/src/Foundatio.Mediator/HandlerAnalyzer.cs @@ -1,3 +1,5 @@ +using System.Text; +using System.Xml.Linq; using Foundatio.Mediator.Models; using Foundatio.Mediator.Utility; @@ -208,12 +210,17 @@ public static List GetHandlers(GeneratorSyntaxContext context) bool hasConstructorParameters = !handlerMethod.IsStatic && classSymbol.InstanceConstructors.Any(c => c.Parameters.Length > 0); + string? messageSummary = GetSummaryComment(messageType); + string? category = GetCategory(messageType) ?? GetCategory(classSymbol); + handlers.Add(new HandlerInfo { Identifier = classSymbol.Name.ToIdentifier(), FullName = classSymbol.ToDisplayString(), MethodName = handlerMethod.Name, MessageType = TypeSymbolInfo.From(messageType, context.SemanticModel.Compilation), + MessageSummary = messageSummary, + Category = category, MessageInterfaces = new(messageInterfaces.ToArray()), MessageBaseClasses = new(messageBaseClasses.ToArray()), ReturnType = TypeSymbolInfo.From(handlerMethod.ReturnType, context.SemanticModel.Compilation), @@ -237,6 +244,204 @@ public static List GetHandlers(GeneratorSyntaxContext context) return handlers; } + private static string? GetSummaryComment(ISymbol symbol) + { + var summary = ParseSummaryFromXml(symbol.GetDocumentationCommentXml(expandIncludes: true)); + if (!String.IsNullOrWhiteSpace(summary)) + return summary; + + foreach (var syntaxRef in symbol.DeclaringSyntaxReferences) + { + if (syntaxRef.GetSyntax() is not CSharpSyntaxNode syntaxNode) + continue; + + var docTrivia = syntaxNode.GetLeadingTrivia() + .Select(t => t.GetStructure()) + .OfType() + .FirstOrDefault(); + + if (docTrivia == null) + { + var summaryFromSourceText = ParseSummaryFromSourceText(syntaxRef); + if (!String.IsNullOrWhiteSpace(summaryFromSourceText)) + return summaryFromSourceText; + + continue; + } + + var summaryElement = docTrivia.Content + .OfType() + .FirstOrDefault(e => String.Equals(e.StartTag?.Name.ToString(), "summary", StringComparison.Ordinal)); + + if (summaryElement != null) + { + var summaryText = String.Join(" ", summaryElement.Content + .OfType() + .SelectMany(t => t.TextTokens) + .Select(tt => tt.Text.Trim()) + .Where(t => t.Length > 0)); + + if (!String.IsNullOrWhiteSpace(summaryText)) + return summaryText; + } + + var summaryFromRaw = ParseSummaryFromRaw(docTrivia.ToFullString()); + if (!String.IsNullOrWhiteSpace(summaryFromRaw)) + return summaryFromRaw; + + var summaryFromSource = ParseSummaryFromSourceText(syntaxRef); + if (!String.IsNullOrWhiteSpace(summaryFromSource)) + return summaryFromSource; + } + + return null; + } + + private static string? GetCategory(ISymbol? symbol) + { + if (symbol == null) + return null; + + foreach (var attribute in symbol.GetAttributes()) + { + if (attribute.AttributeClass is not { } attributeClass) + continue; + + if (!IsCategoryAttribute(attributeClass)) + continue; + + if (attribute.ConstructorArguments.Length > 0) + { + foreach (var arg in attribute.ConstructorArguments) + { + if (arg.Value is string category && !String.IsNullOrWhiteSpace(category)) + return category; + } + } + + if (attribute.NamedArguments.Length > 0) + { + foreach (var arg in attribute.NamedArguments) + { + if (arg.Value.Value is string namedCategory && !String.IsNullOrWhiteSpace(namedCategory)) + return namedCategory; + } + } + } + + return null; + } + + private static bool IsCategoryAttribute(INamedTypeSymbol attributeClass) + { + if (attributeClass.Name == "CategoryAttribute") + return true; + + var displayName = attributeClass.ToDisplayString(); + return displayName == "System.ComponentModel.CategoryAttribute" + || displayName.EndsWith(".CategoryAttribute", StringComparison.Ordinal); + } + + private static string? ParseSummaryFromSourceText(SyntaxReference syntaxRef) + { + var syntaxTree = syntaxRef.SyntaxTree; + if (syntaxTree == null) + return null; + + var sourceText = syntaxTree.GetText(); + var startLine = sourceText.Lines.GetLineFromPosition(syntaxRef.Span.Start).LineNumber; + if (startLine == 0) + return null; + + var collectedLines = new Stack(); + bool foundDocLine = false; + + for (int lineNumber = startLine - 1; lineNumber >= 0; lineNumber--) + { + var line = sourceText.Lines[lineNumber].ToString(); + var trimmed = line.Trim(); + + if (trimmed.Length == 0) + { + if (foundDocLine) + break; + + continue; + } + + if (!trimmed.StartsWith("///", StringComparison.Ordinal)) + { + if (!foundDocLine) + continue; + + break; + } + + foundDocLine = true; + collectedLines.Push(trimmed); + } + + if (!foundDocLine || collectedLines.Count == 0) + return null; + + var raw = string.Join("\n", collectedLines); + return ParseSummaryFromRaw(raw); + } + + private static string? ParseSummaryFromRaw(string? rawText) + { + if (String.IsNullOrWhiteSpace(rawText)) + return null; + + var builder = new StringBuilder(); + var lines = rawText!.Split(new[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries); + foreach (var line in lines) + { + var trimmed = line.Trim(); + if (!trimmed.StartsWith("///", StringComparison.Ordinal)) + continue; + + var content = trimmed.Length > 3 ? trimmed.Substring(3).TrimStart() : String.Empty; + builder.AppendLine(content); + } + + var inner = builder.ToString().Trim(); + if (inner.Length == 0) + return null; + + var wrapped = $"{inner}"; + return ParseSummaryFromXml(wrapped); + } + + private static string? ParseSummaryFromXml(string? documentation) + { + if (String.IsNullOrWhiteSpace(documentation)) + return null; + + try + { + var document = XDocument.Parse(documentation); + var summary = document.Root?.Elements("summary").FirstOrDefault(); + if (summary == null) + return null; + + var text = summary.Value; + if (String.IsNullOrWhiteSpace(text)) + return null; + + var normalized = String.Join(" ", text + .Split(new[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries) + .Select(t => t.Trim()) + .Where(t => t.Length > 0)); + + return normalized.Length == 0 ? null : normalized; + } + catch + { + return null; + } + } + private static bool IsHandlerMethod(IMethodSymbol method, Compilation compilation, bool treatAsHandlerClass) { if (method.DeclaredAccessibility != Accessibility.Public) diff --git a/src/Foundatio.Mediator/HandlerGenerator.cs b/src/Foundatio.Mediator/HandlerGenerator.cs index de7194f..3581352 100644 --- a/src/Foundatio.Mediator/HandlerGenerator.cs +++ b/src/Foundatio.Mediator/HandlerGenerator.cs @@ -1,5 +1,6 @@ using Foundatio.Mediator.Models; using Foundatio.Mediator.Utility; +using System.Linq; namespace Foundatio.Mediator; @@ -92,6 +93,8 @@ namespace Foundatio.Mediator.Generated; GenerateHandleMethod(source, handler, configuration); + GenerateHandleDefaultMethod(source, handler); + GenerateUntypedHandleMethod(source, handler); GenerateInterceptorMethods(source, handler, configuration.InterceptorsEnabled); @@ -381,6 +384,79 @@ private static void GenerateHandleMethod(IndentedStringBuilder source, HandlerIn source.AppendLine(); } + private static void GenerateHandleDefaultMethod(IndentedStringBuilder source, HandlerInfo handler) + { + bool returnsValue = handler.ReturnType.IsTuple || handler.HasReturnValue; + bool needsAsync = handler.IsAsync || handler.ReturnType.IsTuple; + string methodName = GetHandlerDefaultMethodName(handler); + string handlerMethod = GetHandlerMethodName(handler); + string defaultReturnType = GetHandlerDefaultReturnType(handler); + + string methodReturnType; + if (needsAsync) + { + methodReturnType = returnsValue + ? $"System.Threading.Tasks.ValueTask<{defaultReturnType}>" + : "System.Threading.Tasks.ValueTask"; + } + else + { + methodReturnType = returnsValue ? defaultReturnType : "void"; + } + + source.AppendLine($"public static {(needsAsync ? "async " : string.Empty)}{methodReturnType} {methodName}(Foundatio.Mediator.IMediator mediator, System.IServiceProvider serviceProvider, {handler.MessageType.FullName} message, System.Threading.CancellationToken cancellationToken)"); + source.AppendLine("{"); + + source.IncrementIndent(); + + string awaitPrefix = handler.IsAsync ? "await " : string.Empty; + string awaitSuffix = handler.IsAsync ? ".ConfigureAwait(false)" : string.Empty; + if (handler.ReturnType.IsVoid && !handler.ReturnType.IsTuple) + { + source.AppendLine($"{awaitPrefix}{handlerMethod}(serviceProvider, message, cancellationToken){awaitSuffix};"); + source.AppendLine("return;"); + source.DecrementIndent(); + source.AppendLine("}"); + source.AppendLine(); + return; + } + + source.Append("var result = "); + source.AppendLine($"{awaitPrefix}{handlerMethod}(serviceProvider, message, cancellationToken){awaitSuffix};"); + + if (handler.ReturnType.IsTuple) + { + var returnItem = handler.ReturnType.TupleItems.First(); + var publishItems = handler.ReturnType.TupleItems.Skip(1).ToList(); + + foreach (var publishItem in publishItems) + { + string access = $"result.{publishItem.Name}"; + if (publishItem.IsNullable) + { + source.AppendLine($"if ({access} != null)"); + source.AppendLine("{"); + source.AppendLine($" await mediator.PublishAsync({access}, cancellationToken).ConfigureAwait(false);"); + source.AppendLine("}"); + } + else + { + source.AppendLine($"await mediator.PublishAsync({access}, cancellationToken).ConfigureAwait(false);"); + } + } + + source.AppendLine($"return result.{returnItem.Name};"); + } + else if (returnsValue) + { + source.AppendLine("return result;"); + } + + source.DecrementIndent(); + source.AppendLine("}"); + source.AppendLine(); + } + private static void GenerateUntypedHandleMethod(IndentedStringBuilder source, HandlerInfo handler) { // For async handlers that can use fast path and return void, generate a non-async version @@ -930,6 +1006,43 @@ public static string GetHandlerMethodName(HandlerInfo handler) return handler.IsAsync ? "HandleAsync" : "Handle"; } + public static string GetHandlerDefaultMethodName(HandlerInfo handler) + { + return handler.IsAsync || handler.ReturnType.IsTuple ? "DefaultHandleAsync" : "DefaultHandle"; + } + + private static string GetHandlerDefaultReturnType(HandlerInfo handler) + { + if (handler.ReturnType.IsTuple) + { + var tupleItem = handler.ReturnType.TupleItems.First(); + return AppendNullableAnnotation(tupleItem.TypeFullName, tupleItem.IsNullable); + } + + return AppendNullableAnnotation(handler.ReturnType.UnwrappedFullName, handler.ReturnType.IsNullable); + } + + private static string AppendNullableAnnotation(string typeName, bool isNullable) + { + if (!isNullable || string.IsNullOrEmpty(typeName) || typeName.EndsWith("?", StringComparison.Ordinal)) + return typeName; + + if (typeName.Equals("void", StringComparison.Ordinal)) + return typeName; + + return typeName + "?"; + } + + public static bool HandlerDefaultReturnsValue(HandlerInfo handler) + { + return handler.ReturnType.IsTuple || handler.HasReturnValue; + } + + public static bool HandlerDefaultIsAsync(HandlerInfo handler) + { + return handler.IsAsync || handler.ReturnType.IsTuple; + } + private static void Validate(SourceProductionContext context, List handlers) { var processedMiddleware = new HashSet(); diff --git a/src/Foundatio.Mediator/MediatorGenerator.cs b/src/Foundatio.Mediator/MediatorGenerator.cs index 6e5694c..5f656e2 100644 --- a/src/Foundatio.Mediator/MediatorGenerator.cs +++ b/src/Foundatio.Mediator/MediatorGenerator.cs @@ -1,4 +1,4 @@ -using System.Diagnostics; +using System.Diagnostics; using Foundatio.Mediator.Models; using Foundatio.Mediator.Utility; @@ -129,7 +129,7 @@ private static void Execute(ImmutableArray handlers, ImmutableArray { callSitesByMessage.TryGetValue(handler.MessageType, out var handlerCallSites); var applicableMiddleware = GetApplicableMiddlewares(allMiddleware.ToImmutableArray(), handler, compilation); - handlersWithInfo.Add(handler with { CallSites = new(handlerCallSites), Middleware = applicableMiddleware }); + handlersWithInfo.Add(handler with { CallSites = new(handlerCallSites ?? []), Middleware = applicableMiddleware }); } // Collect call sites that need cross-assembly interceptors @@ -188,6 +188,8 @@ private static void Execute(ImmutableArray handlers, ImmutableArray HandlerGenerator.Execute(context, handlersWithInfo, configuration); + EndpointGenerator.Execute(context, handlersWithInfo, compilation); + sw.Stop(); GeneratorDiagnostics.LogExecute( compilation.AssemblyName ?? "Unknown", diff --git a/src/Foundatio.Mediator/Models/HandlerInfo.cs b/src/Foundatio.Mediator/Models/HandlerInfo.cs index 2fdaf3c..962dca8 100644 --- a/src/Foundatio.Mediator/Models/HandlerInfo.cs +++ b/src/Foundatio.Mediator/Models/HandlerInfo.cs @@ -8,6 +8,8 @@ internal readonly record struct HandlerInfo public string FullName { get; init; } public string MethodName { get; init; } public TypeSymbolInfo MessageType { get; init; } + public string? MessageSummary { get; init; } + public string? Category { get; init; } public EquatableArray MessageInterfaces { get; init; } public EquatableArray MessageBaseClasses { get; init; } public bool HasReturnValue => !ReturnType.IsVoid; diff --git a/src/Foundatio.Mediator/Models/TypeSymbolInfo.cs b/src/Foundatio.Mediator/Models/TypeSymbolInfo.cs index 83bf221..3dee2ff 100644 --- a/src/Foundatio.Mediator/Models/TypeSymbolInfo.cs +++ b/src/Foundatio.Mediator/Models/TypeSymbolInfo.cs @@ -9,6 +9,10 @@ internal readonly record struct TypeSymbolInfo /// public string Identifier { get; init; } /// + /// The simple name of the type without namespace qualification. + /// + public string Name { get; init; } + /// /// The full name of the type, including namespace and any generic parameters. /// This may use short names when types are in scope via using directives. /// @@ -88,6 +92,7 @@ public static TypeSymbolInfo Void() return new TypeSymbolInfo { Identifier = "void", + Name = "void", FullName = "void", QualifiedName = "void", UnwrappedFullName = "void", @@ -166,6 +171,7 @@ static string GetTypeArgIdentifier(ITypeSymbol ts) return new TypeSymbolInfo { Identifier = identifier, + Name = typeSymbol.Name, FullName = typeSymbol.ToDisplayString(), QualifiedName = qualifiedName, UnwrappedFullName = unwrappedTypeFullName, diff --git a/tests/Foundatio.Mediator.Tests/BasicHandlerGenerationTests.GeneratesWrapperForSimpleHandler.received.txt b/tests/Foundatio.Mediator.Tests/BasicHandlerGenerationTests.GeneratesWrapperForSimpleHandler.received.txt new file mode 100644 index 0000000..96cdef7 --- /dev/null +++ b/tests/Foundatio.Mediator.Tests/BasicHandlerGenerationTests.GeneratesWrapperForSimpleHandler.received.txt @@ -0,0 +1,180 @@ +{ + Diagnostics: null, + GeneratorDiagnostics: null, + GeneratedTrees: [ + { + HintName: InterceptsLocationAttribute.g.cs, + Source: +// This file was generated by Foundatio.Mediator source generators. +// Changes to this file may be lost when the code is regenerated. +// + +#nullable enable + +using System; + +namespace System.Runtime.CompilerServices; + +/// +/// Indicates that a method is an interceptor and provides the location of the intercepted call. +/// +[global::System.CodeDom.Compiler.GeneratedCode("Foundatio.Mediator", "")] +[global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)] +internal sealed class InterceptsLocationAttribute : global::System.Attribute +{ + /// + /// Initializes a new instance of the class. + /// + /// The version of the location encoding. + /// The encoded location data. + public InterceptsLocationAttribute(int version, string data) + { + Version = version; + Data = data; + } + + /// + /// Gets the version of the location encoding. + /// + public int Version { get; } + + /// + /// Gets the encoded location data. + /// + public string Data { get; } +} + + }, + { + HintName: PingHandler_Ping_Handler.g.cs, + Source: +// This file was generated by Foundatio.Mediator source generators. +// Changes to this file may be lost when the code is regenerated. +// + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +namespace Foundatio.Mediator; + +[global::System.CodeDom.Compiler.GeneratedCode("Foundatio.Mediator", "")] +[ExcludeFromCodeCoverage] +internal static class PingHandler_Ping_Handler +{ + public static async System.Threading.Tasks.Task HandleAsync(System.IServiceProvider serviceProvider, Ping message, System.Threading.CancellationToken cancellationToken) + { + var logger = serviceProvider.GetService()?.CreateLogger("PingHandler"); + logger?.LogDebug("Processing message {MessageType}", "Ping"); + + using var activity = MediatorActivitySource.Instance.StartActivity("Ping"); + activity?.SetTag("messaging.system", "Foundatio.Mediator"); + activity?.SetTag("messaging.message.type", "Ping"); + + string? handlerResult = default; + var handlerInstance = GetOrCreateHandler(serviceProvider); + handlerResult = await handlerInstance.HandleAsync(message, cancellationToken); + + logger?.LogDebug("Completed processing message {MessageType}", "Ping"); + activity?.SetStatus(System.Diagnostics.ActivityStatusCode.Ok); + return handlerResult; + } + + public static async System.Threading.Tasks.ValueTask DefaultHandleAsync(Foundatio.Mediator.IMediator mediator, System.IServiceProvider serviceProvider, Ping message, System.Threading.CancellationToken cancellationToken) + { + var result = await HandleAsync(serviceProvider, message, cancellationToken).ConfigureAwait(false); + return result; + } + + public static async ValueTask UntypedHandleAsync(IMediator mediator, object message, CancellationToken cancellationToken, Type? responseType) + { + using var handlerScope = GetOrCreateScope(mediator, cancellationToken); + var typedMessage = (Ping)message; + var result = await HandleAsync(handlerScope.Services, typedMessage, cancellationToken); + + if (responseType == null) + { + return null; + } + + return result; + } + + [DebuggerStepThrough] + private static HandlerScopeValue GetOrCreateScope(IMediator mediator, CancellationToken cancellationToken) + { + return HandlerScope.GetOrCreate(mediator, cancellationToken); + } + + [DebuggerStepThrough] + private static PingHandler GetOrCreateHandler(IServiceProvider serviceProvider) + { + return serviceProvider.GetRequiredService(); + } +} + + }, + { + HintName: Tests_FoundatioModuleAttribute.g.cs, + Source: +// This file was generated by Foundatio.Mediator source generators. +// Changes to this file may be lost when the code is regenerated. +// + +#nullable enable + + +[assembly: Foundatio.Mediator.FoundatioModule] + + }, + { + HintName: Tests_MediatorHandlers.g.cs, + Source: +// This file was generated by Foundatio.Mediator source generators. +// Changes to this file may be lost when the code is regenerated. +// + +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; + +namespace Foundatio.Mediator; + +[global::System.CodeDom.Compiler.GeneratedCode("Foundatio.Mediator", "")] +[ExcludeFromCodeCoverage] +public static class Tests_MediatorHandlers +{ + public static void AddHandlers(this IServiceCollection services) + { + // Register HandlerRegistration instances keyed by message type name + + services.TryAddSingleton(); + services.AddHandler(new HandlerRegistration( + MessageTypeKey.Get(typeof(Ping)), + "PingHandler_Ping_Handler", + PingHandler_Ping_Handler.UntypedHandleAsync, + null, + true)); + + } +} + + } + ] +} \ No newline at end of file diff --git a/tests/Foundatio.Mediator.Tests/BasicHandlerGenerationTests.GeneratesWrapperForSimpleHandler.verified.txt b/tests/Foundatio.Mediator.Tests/BasicHandlerGenerationTests.GeneratesWrapperForSimpleHandler.verified.txt index 91a7cb3..5228069 100644 --- a/tests/Foundatio.Mediator.Tests/BasicHandlerGenerationTests.GeneratesWrapperForSimpleHandler.verified.txt +++ b/tests/Foundatio.Mediator.Tests/BasicHandlerGenerationTests.GeneratesWrapperForSimpleHandler.verified.txt @@ -296,6 +296,12 @@ public static class PingHandler_Ping_Handler } } + public static async System.Threading.Tasks.ValueTask DefaultHandleAsync(Foundatio.Mediator.IMediator mediator, System.IServiceProvider serviceProvider, Ping message, System.Threading.CancellationToken cancellationToken) + { + var result = await HandleAsync(serviceProvider, message, cancellationToken).ConfigureAwait(false); + return result; + } + public static async ValueTask UntypedHandleAsync(IMediator mediator, object message, CancellationToken cancellationToken, Type? responseType) { var typedMessage = (Ping)message;