diff --git a/Foundatio.Mediator.slnx b/Foundatio.Mediator.slnx
index 2b7557c..8787efa 100644
--- a/Foundatio.Mediator.slnx
+++ b/Foundatio.Mediator.slnx
@@ -8,12 +8,13 @@
-
-
+
+
+
diff --git a/samples/ConsoleSample/ConsoleSample.csproj b/samples/ConsoleSample/ConsoleSample.csproj
index f0d5613..41c3305 100644
--- a/samples/ConsoleSample/ConsoleSample.csproj
+++ b/samples/ConsoleSample/ConsoleSample.csproj
@@ -34,6 +34,12 @@
+
+
+
+
+
+
diff --git a/samples/ConsoleSample/Messages/Messages.cs b/samples/ConsoleSample/Messages/Messages.cs
index af686cb..a5e7c35 100644
--- a/samples/ConsoleSample/Messages/Messages.cs
+++ b/samples/ConsoleSample/Messages/Messages.cs
@@ -3,11 +3,14 @@
namespace ConsoleSample.Messages;
#region Simple
+/// Ping endpoint used to test connectivity.
public record Ping(string Text);
+/// Gets a personalized greeting for the specified user.
public record GetGreeting(string Name);
#endregion
// Order CRUD messages
+/// Creates a new order for the specified customer.
public record CreateOrder(
[Required(ErrorMessage = "Customer ID is required")]
[StringLength(50, MinimumLength = 3, ErrorMessage = "Customer ID must be between 3 and 50 characters")]
@@ -20,8 +23,11 @@ public record CreateOrder(
[Required(ErrorMessage = "Description is required")]
[StringLength(200, MinimumLength = 5, ErrorMessage = "Description must be between 5 and 200 characters")]
string Description) : IValidatable;
+/// Gets an existing order by identifier.
public record GetOrder(string OrderId);
+/// Updates mutable fields on an existing order.
public record UpdateOrder(string OrderId, decimal? Amount, string? Description);
+/// Deletes an existing order.
public record DeleteOrder(string OrderId);
// Event messages (for publish pattern)
diff --git a/samples/ConsoleSample/Middleware/TransactionMiddleware.cs b/samples/ConsoleSample/Middleware/TransactionMiddleware.cs
index 5e41712..d229212 100644
--- a/samples/ConsoleSample/Middleware/TransactionMiddleware.cs
+++ b/samples/ConsoleSample/Middleware/TransactionMiddleware.cs
@@ -23,10 +23,9 @@ public void After(CreateOrder cmd, IDbTransaction transaction, ILogger logger)
{
- if (transaction == null)
+ if (transaction is not FakeTransaction tx)
return;
- var tx = (FakeTransaction)transaction;
if (result?.IsSuccess == true)
return;
diff --git a/samples/ConsoleSample/Program.cs b/samples/ConsoleSample/Program.cs
index 7f0f6f1..fa8cd05 100644
--- a/samples/ConsoleSample/Program.cs
+++ b/samples/ConsoleSample/Program.cs
@@ -1,7 +1,15 @@
using ConsoleSample;
using Foundatio.Mediator;
+using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
+using Scalar.AspNetCore;
+
+if (args.Any(a => string.Equals(a, "--web", StringComparison.OrdinalIgnoreCase)))
+{
+ await RunMinimalApiAsync(args);
+ return;
+}
// Create application host
var builder = Host.CreateApplicationBuilder(args);
@@ -16,3 +24,21 @@
var sampleRunner = new SampleRunner(mediator, host.Services);
await sampleRunner.RunAllSamplesAsync();
+
+static async Task RunMinimalApiAsync(string[] args)
+{
+ var builder = WebApplication.CreateBuilder(args);
+ builder.Services.ConfigureServices();
+ builder.Services.AddEndpointsApiExplorer();
+ builder.Services.AddOpenApi();
+
+ var app = builder.Build();
+ app.MapOpenApi();
+ app.MapScalarApiReference(options =>
+ {
+ options.Title = "Foundatio.Mediator Console Sample";
+ });
+ app.MapMediatorEndpoints();
+
+ await app.RunAsync();
+}
diff --git a/src/Foundatio.Mediator.Abstractions/MediatorExtensions.cs b/src/Foundatio.Mediator.Abstractions/MediatorExtensions.cs
index d24759d..2b578f2 100644
--- a/src/Foundatio.Mediator.Abstractions/MediatorExtensions.cs
+++ b/src/Foundatio.Mediator.Abstractions/MediatorExtensions.cs
@@ -16,7 +16,7 @@ public static IServiceCollection AddMediator(this IServiceCollection services, M
if (configuration.Assemblies == null)
{
- configuration.Assemblies = AppDomain.CurrentDomain.GetAssemblies().Where(a => !a.IsDynamic && !a.FullName.StartsWith("System.")).ToList();
+ configuration.Assemblies = AppDomain.CurrentDomain.GetAssemblies().Where(a => a is { IsDynamic: false, FullName: not null } && !a.FullName.StartsWith("System.")).ToList();
}
services.Add(ServiceDescriptor.Describe(typeof(IMediator), typeof(Mediator), configuration.MediatorLifetime));
diff --git a/src/Foundatio.Mediator/EndpointGenerator.cs b/src/Foundatio.Mediator/EndpointGenerator.cs
new file mode 100644
index 0000000..881e809
--- /dev/null
+++ b/src/Foundatio.Mediator/EndpointGenerator.cs
@@ -0,0 +1,345 @@
+using System.Collections.Generic;
+using System.Text;
+using Foundatio.Mediator.Models;
+using Foundatio.Mediator.Utility;
+using Microsoft.CodeAnalysis;
+
+namespace Foundatio.Mediator;
+
+internal static class EndpointGenerator
+{
+ private static readonly string[] SupportedPrefixes = ["Get", "Create", "Update", "Delete"];
+
+ public static void Execute(SourceProductionContext context, List handlers, Compilation compilation)
+ {
+ if (!SupportsMinimalApis(compilation))
+ return;
+
+ var endpointHandlers = handlers
+ .Where(IsEndpointCandidate)
+ .OrderBy(h => h.MessageType.FullName, StringComparer.Ordinal)
+ .ToList();
+
+ if (endpointHandlers.Count == 0)
+ return;
+
+ bool hasAsParameters = compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Mvc.AsParametersAttribute") is not null
+ || compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Http.AsParametersAttribute") is not null;
+ bool hasFromBody = compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Mvc.FromBodyAttribute") is not null;
+ bool hasFromQuery = compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Mvc.FromQueryAttribute") is not null;
+ bool hasFromServices = compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Mvc.FromServicesAttribute") is not null;
+ bool hasOpenApi = compilation.GetTypeByMetadataName("Microsoft.AspNetCore.OpenApi.OpenApiRouteHandlerBuilderExtensions") is not null;
+
+ var source = BuildSource(endpointHandlers, hasAsParameters, hasFromBody, hasFromQuery, hasFromServices, hasOpenApi);
+ context.AddSource("MediatorMinimalApiEndpoints.g.cs", source);
+ }
+
+ private static string BuildSource(IReadOnlyList handlers, bool hasAsParameters, bool hasFromBody, bool hasFromQuery, bool hasFromServices, bool hasOpenApi)
+ {
+ var source = new IndentedStringBuilder();
+ source.AddGeneratedFileHeader();
+
+ source.AppendLines("""
+ using System;
+ using System.Collections.Generic;
+ using System.Linq;
+ using System.Runtime.CompilerServices;
+ using System.Threading;
+ using System.Threading.Tasks;
+ using Microsoft.AspNetCore.Builder;
+ using Microsoft.AspNetCore.Http;
+ """);
+
+ if (hasAsParameters || hasFromBody || hasFromQuery || hasFromServices)
+ source.AppendLine("using Microsoft.AspNetCore.Mvc;");
+
+ source.AppendLines("""
+ using Microsoft.AspNetCore.Routing;
+ using Microsoft.Extensions.DependencyInjection;
+
+ namespace Foundatio.Mediator;
+
+ public static class MediatorEndpointExtensions
+ {
+ private const string DefaultBasePath = "/api/messages";
+
+ public static IEndpointRouteBuilder MapMediatorEndpoints(this IEndpointRouteBuilder app, string? basePath = null)
+ {
+ ArgumentNullException.ThrowIfNull(app);
+ var group = app.MapGroup(string.IsNullOrWhiteSpace(basePath) ? DefaultBasePath : basePath!);
+ RegisterMediatorEndpoints(group);
+ return app;
+ }
+
+ private static void RegisterMediatorEndpoints(RouteGroupBuilder group)
+ {
+ ArgumentNullException.ThrowIfNull(group);
+ """);
+
+ source.IncrementIndent();
+ source.IncrementIndent();
+
+ var categoryOrder = new List();
+ var handlersByCategory = new Dictionary>(StringComparer.Ordinal);
+ var categoryLabels = new Dictionary(StringComparer.Ordinal);
+
+ foreach (var handler in handlers)
+ {
+ var categoryKey = handler.Category?.Trim() ?? string.Empty;
+ if (!handlersByCategory.TryGetValue(categoryKey, out var list))
+ {
+ list = new List();
+ handlersByCategory[categoryKey] = list;
+ categoryOrder.Add(categoryKey);
+ categoryLabels[categoryKey] = handler.Category?.Trim();
+ }
+
+ list.Add(handler);
+ }
+
+ int endpointIndex = 0;
+ int categoryIndex = 0;
+
+ for (int categoryOrderIndex = 0; categoryOrderIndex < categoryOrder.Count; categoryOrderIndex++)
+ {
+ var categoryKey = categoryOrder[categoryOrderIndex];
+ var categoryHandlers = handlersByCategory[categoryKey];
+ string? categoryLabel = categoryLabels[categoryKey];
+ string groupAccessor = "group";
+
+ if (!string.IsNullOrEmpty(categoryKey) && !string.IsNullOrWhiteSpace(categoryLabel))
+ {
+ string categoryVar = $"categoryGroup{categoryIndex++}";
+ string routeSegment = ToCategoryRouteSegmentLiteral(categoryLabel!);
+ source.AppendLine($"var {categoryVar} = group.MapGroup(\"/{routeSegment}\");");
+ source.AppendLine($"{categoryVar}.WithGroupName(\"{EscapeString(categoryLabel!)}\");");
+ source.AppendLine();
+ groupAccessor = categoryVar;
+ }
+
+ foreach (var handler in categoryHandlers)
+ {
+ string handlerClassName = HandlerGenerator.GetHandlerClassName(handler);
+ string methodName = HandlerGenerator.GetHandlerDefaultMethodName(handler);
+ string mapMethod = GetMapMethod(handler);
+ string routeSegment = ToRouteSegment(handler.MessageType.Name);
+ string requestParameter = $"{handler.MessageType.FullName} request";
+ bool handlerDefaultReturnsValue = HandlerGenerator.HandlerDefaultReturnsValue(handler);
+ bool handlerDefaultIsAsync = HandlerGenerator.HandlerDefaultIsAsync(handler);
+
+ bool requiresQueryBinding = RequiresQueryBinding(mapMethod);
+ bool requiresBodyBinding = RequiresBodyBinding(mapMethod);
+
+ if (requiresQueryBinding)
+ {
+ if (hasAsParameters)
+ {
+ requestParameter = "[AsParameters] " + requestParameter;
+ }
+ else if (hasFromQuery)
+ {
+ requestParameter = "[FromQuery] " + requestParameter;
+ }
+ }
+ else if (requiresBodyBinding && hasFromBody)
+ {
+ requestParameter = "[FromBody] " + requestParameter;
+ }
+
+ string servicesParameter = hasFromServices ? "[FromServices] IServiceProvider services" : "IServiceProvider services";
+ string mediatorParameter = hasFromServices ? "[FromServices] Foundatio.Mediator.IMediator mediator" : "Foundatio.Mediator.IMediator mediator";
+
+ string invocation = handlerDefaultIsAsync
+ ? $"await {handlerClassName}.{methodName}(mediator, services, request, cancellationToken).ConfigureAwait(false)"
+ : $"{handlerClassName}.{methodName}(mediator, services, request, cancellationToken)";
+
+ string endpointVar = $"endpoint{endpointIndex++}";
+
+ source.AppendLine($"var {endpointVar} = {groupAccessor}.{mapMethod}(\"/{routeSegment}\", async ({requestParameter}, {servicesParameter}, {mediatorParameter}, CancellationToken cancellationToken) =>");
+ source.AppendLine("{");
+ source.IncrementIndent();
+ if (handlerDefaultReturnsValue)
+ {
+ source.AppendLine($"var handlerResult = {invocation};");
+ source.AppendLine("return MediatorEndpointResultMapper.ToHttpResult(handlerResult);");
+ }
+ else
+ {
+ source.AppendLine($"{invocation};");
+ source.AppendLine("return MediatorEndpointResultMapper.ToHttpResult(null);");
+ }
+ source.DecrementIndent();
+ source.AppendLine("});");
+ source.AppendLine($"{endpointVar}.WithName(\"{handler.MessageType.Name}\");");
+ if (!string.IsNullOrWhiteSpace(handler.MessageSummary))
+ {
+ var escapedSummary = EscapeString(handler.MessageSummary!);
+ source.AppendLine($"{endpointVar}.WithSummary(\"{escapedSummary}\");");
+ source.AppendLine($"{endpointVar}.WithDescription(\"{escapedSummary}\");");
+ }
+ if (hasOpenApi)
+ source.AppendLine($"{endpointVar}.WithOpenApi();");
+ source.AppendLine();
+ }
+
+ if (categoryOrderIndex < categoryOrder.Count - 1)
+ source.AppendLine();
+ }
+
+ source.DecrementIndent();
+ source.DecrementIndent();
+
+ source.AppendLines("""
+ }
+
+ private static bool RequiresQueryBinding(string mapMethod) => mapMethod is "MapGet" or "MapDelete";
+ private static bool RequiresBodyBinding(string mapMethod) => mapMethod is "MapPost" or "MapPut";
+
+ private static class MediatorEndpointResultMapper
+ {
+ public static global::Microsoft.AspNetCore.Http.IResult ToHttpResult(object? value)
+ {
+ if (value is null)
+ return Results.NoContent();
+
+ if (value is Foundatio.Mediator.IResult mediatorResult)
+ return MapMediatorResult(mediatorResult);
+
+ return Results.Ok(value);
+ }
+
+ private static global::Microsoft.AspNetCore.Http.IResult MapMediatorResult(Foundatio.Mediator.IResult result)
+ {
+ return result.Status switch
+ {
+ ResultStatus.Success => result.GetValue() is { } value ? Results.Ok(value) : Results.Ok(),
+ ResultStatus.Created => Results.Created(string.IsNullOrWhiteSpace(result.Location) ? DefaultBasePath : result.Location, result.GetValue()),
+ ResultStatus.NoContent => Results.NoContent(),
+ ResultStatus.BadRequest => Results.BadRequest(result.Message),
+ ResultStatus.Invalid => Results.ValidationProblem(ToValidationState(result.ValidationErrors), detail: result.Message),
+ ResultStatus.NotFound => string.IsNullOrWhiteSpace(result.Message) ? Results.NotFound() : Results.NotFound(result.Message),
+ ResultStatus.Unauthorized => Results.Unauthorized(),
+ ResultStatus.Forbidden => Results.Forbid(),
+ ResultStatus.Conflict => Results.Conflict(result.Message),
+ ResultStatus.Error => Results.Problem(detail: result.Message),
+ ResultStatus.CriticalError => Results.Problem(detail: result.Message, statusCode: StatusCodes.Status500InternalServerError),
+ ResultStatus.Unavailable => Results.StatusCode(StatusCodes.Status503ServiceUnavailable),
+ _ => Results.Ok(result.GetValue())
+ };
+ }
+
+ private static IDictionary ToValidationState(IEnumerable errors)
+ {
+ var dictionary = new Dictionary>(StringComparer.OrdinalIgnoreCase);
+
+ foreach (var error in errors)
+ {
+ var key = string.IsNullOrWhiteSpace(error.Identifier) ? string.Empty : error.Identifier;
+ if (!dictionary.TryGetValue(key, out var list))
+ {
+ list = new List();
+ dictionary[key] = list;
+ }
+
+ if (!string.IsNullOrWhiteSpace(error.ErrorMessage))
+ list.Add(error.ErrorMessage);
+ }
+
+ return dictionary.ToDictionary(kvp => kvp.Key, kvp => kvp.Value.ToArray());
+ }
+ }
+ }
+ """);
+
+ return source.ToString();
+ }
+
+ private static bool SupportsMinimalApis(Compilation compilation)
+ {
+ return compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Routing.IEndpointRouteBuilder") is not null
+ && compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Http.Results") is not null;
+ }
+
+ private static bool IsEndpointCandidate(HandlerInfo handler)
+ {
+ foreach (var prefix in SupportedPrefixes)
+ {
+ if (handler.MessageType.Name.StartsWith(prefix, StringComparison.OrdinalIgnoreCase))
+ return true;
+ }
+
+ return false;
+ }
+
+ private static string GetMapMethod(HandlerInfo handler)
+ {
+ var name = handler.MessageType.Name;
+ if (name.StartsWith("Get", StringComparison.OrdinalIgnoreCase))
+ return "MapGet";
+ if (name.StartsWith("Update", StringComparison.OrdinalIgnoreCase))
+ return "MapPut";
+ if (name.StartsWith("Create", StringComparison.OrdinalIgnoreCase))
+ return "MapPost";
+ if (name.StartsWith("Delete", StringComparison.OrdinalIgnoreCase))
+ return "MapDelete";
+
+ return "MapPost";
+ }
+
+ private static string ToRouteSegment(string name)
+ {
+ if (string.IsNullOrEmpty(name))
+ return name;
+
+ var builder = new StringBuilder();
+ for (int i = 0; i < name.Length; i++)
+ {
+ char c = name[i];
+ if (char.IsUpper(c) && i > 0)
+ builder.Append('-');
+
+ builder.Append(char.ToLowerInvariant(c));
+ }
+
+ return builder.ToString();
+ }
+
+ private static string EscapeString(string value)
+ {
+ if (string.IsNullOrEmpty(value))
+ return value;
+
+ return value
+ .Replace("\\", "\\\\")
+ .Replace("\"", "\\\"")
+ .Replace('\r', ' ')
+ .Replace('\n', ' ');
+ }
+
+ private static string ToCategoryRouteSegmentLiteral(string category)
+ {
+ if (string.IsNullOrWhiteSpace(category))
+ return "uncategorized";
+
+ var builder = new StringBuilder();
+ foreach (var c in category)
+ {
+ if (char.IsLetterOrDigit(c))
+ {
+ builder.Append(char.ToLowerInvariant(c));
+ continue;
+ }
+
+ if (builder.Length == 0 || builder[builder.Length - 1] == '-')
+ continue;
+
+ builder.Append('-');
+ }
+
+ var result = builder.ToString().Trim('-');
+ return string.IsNullOrEmpty(result) ? "uncategorized" : result;
+ }
+
+ private static bool RequiresQueryBinding(string mapMethod) => mapMethod is "MapGet" or "MapDelete";
+ private static bool RequiresBodyBinding(string mapMethod) => mapMethod is "MapPost" or "MapPut";
+}
diff --git a/src/Foundatio.Mediator/Foundatio.Mediator.csproj b/src/Foundatio.Mediator/Foundatio.Mediator.csproj
index 0fc725f..9e630d8 100644
--- a/src/Foundatio.Mediator/Foundatio.Mediator.csproj
+++ b/src/Foundatio.Mediator/Foundatio.Mediator.csproj
@@ -1,4 +1,4 @@
-
+
@@ -14,7 +14,6 @@
-
diff --git a/src/Foundatio.Mediator/HandlerAnalyzer.cs b/src/Foundatio.Mediator/HandlerAnalyzer.cs
index 76631fd..539609d 100644
--- a/src/Foundatio.Mediator/HandlerAnalyzer.cs
+++ b/src/Foundatio.Mediator/HandlerAnalyzer.cs
@@ -1,3 +1,5 @@
+using System.Text;
+using System.Xml.Linq;
using Foundatio.Mediator.Models;
using Foundatio.Mediator.Utility;
@@ -208,12 +210,17 @@ public static List GetHandlers(GeneratorSyntaxContext context)
bool hasConstructorParameters = !handlerMethod.IsStatic &&
classSymbol.InstanceConstructors.Any(c => c.Parameters.Length > 0);
+ string? messageSummary = GetSummaryComment(messageType);
+ string? category = GetCategory(messageType) ?? GetCategory(classSymbol);
+
handlers.Add(new HandlerInfo
{
Identifier = classSymbol.Name.ToIdentifier(),
FullName = classSymbol.ToDisplayString(),
MethodName = handlerMethod.Name,
MessageType = TypeSymbolInfo.From(messageType, context.SemanticModel.Compilation),
+ MessageSummary = messageSummary,
+ Category = category,
MessageInterfaces = new(messageInterfaces.ToArray()),
MessageBaseClasses = new(messageBaseClasses.ToArray()),
ReturnType = TypeSymbolInfo.From(handlerMethod.ReturnType, context.SemanticModel.Compilation),
@@ -237,6 +244,204 @@ public static List GetHandlers(GeneratorSyntaxContext context)
return handlers;
}
+ private static string? GetSummaryComment(ISymbol symbol)
+ {
+ var summary = ParseSummaryFromXml(symbol.GetDocumentationCommentXml(expandIncludes: true));
+ if (!String.IsNullOrWhiteSpace(summary))
+ return summary;
+
+ foreach (var syntaxRef in symbol.DeclaringSyntaxReferences)
+ {
+ if (syntaxRef.GetSyntax() is not CSharpSyntaxNode syntaxNode)
+ continue;
+
+ var docTrivia = syntaxNode.GetLeadingTrivia()
+ .Select(t => t.GetStructure())
+ .OfType()
+ .FirstOrDefault();
+
+ if (docTrivia == null)
+ {
+ var summaryFromSourceText = ParseSummaryFromSourceText(syntaxRef);
+ if (!String.IsNullOrWhiteSpace(summaryFromSourceText))
+ return summaryFromSourceText;
+
+ continue;
+ }
+
+ var summaryElement = docTrivia.Content
+ .OfType()
+ .FirstOrDefault(e => String.Equals(e.StartTag?.Name.ToString(), "summary", StringComparison.Ordinal));
+
+ if (summaryElement != null)
+ {
+ var summaryText = String.Join(" ", summaryElement.Content
+ .OfType()
+ .SelectMany(t => t.TextTokens)
+ .Select(tt => tt.Text.Trim())
+ .Where(t => t.Length > 0));
+
+ if (!String.IsNullOrWhiteSpace(summaryText))
+ return summaryText;
+ }
+
+ var summaryFromRaw = ParseSummaryFromRaw(docTrivia.ToFullString());
+ if (!String.IsNullOrWhiteSpace(summaryFromRaw))
+ return summaryFromRaw;
+
+ var summaryFromSource = ParseSummaryFromSourceText(syntaxRef);
+ if (!String.IsNullOrWhiteSpace(summaryFromSource))
+ return summaryFromSource;
+ }
+
+ return null;
+ }
+
+ private static string? GetCategory(ISymbol? symbol)
+ {
+ if (symbol == null)
+ return null;
+
+ foreach (var attribute in symbol.GetAttributes())
+ {
+ if (attribute.AttributeClass is not { } attributeClass)
+ continue;
+
+ if (!IsCategoryAttribute(attributeClass))
+ continue;
+
+ if (attribute.ConstructorArguments.Length > 0)
+ {
+ foreach (var arg in attribute.ConstructorArguments)
+ {
+ if (arg.Value is string category && !String.IsNullOrWhiteSpace(category))
+ return category;
+ }
+ }
+
+ if (attribute.NamedArguments.Length > 0)
+ {
+ foreach (var arg in attribute.NamedArguments)
+ {
+ if (arg.Value.Value is string namedCategory && !String.IsNullOrWhiteSpace(namedCategory))
+ return namedCategory;
+ }
+ }
+ }
+
+ return null;
+ }
+
+ private static bool IsCategoryAttribute(INamedTypeSymbol attributeClass)
+ {
+ if (attributeClass.Name == "CategoryAttribute")
+ return true;
+
+ var displayName = attributeClass.ToDisplayString();
+ return displayName == "System.ComponentModel.CategoryAttribute"
+ || displayName.EndsWith(".CategoryAttribute", StringComparison.Ordinal);
+ }
+
+ private static string? ParseSummaryFromSourceText(SyntaxReference syntaxRef)
+ {
+ var syntaxTree = syntaxRef.SyntaxTree;
+ if (syntaxTree == null)
+ return null;
+
+ var sourceText = syntaxTree.GetText();
+ var startLine = sourceText.Lines.GetLineFromPosition(syntaxRef.Span.Start).LineNumber;
+ if (startLine == 0)
+ return null;
+
+ var collectedLines = new Stack();
+ bool foundDocLine = false;
+
+ for (int lineNumber = startLine - 1; lineNumber >= 0; lineNumber--)
+ {
+ var line = sourceText.Lines[lineNumber].ToString();
+ var trimmed = line.Trim();
+
+ if (trimmed.Length == 0)
+ {
+ if (foundDocLine)
+ break;
+
+ continue;
+ }
+
+ if (!trimmed.StartsWith("///", StringComparison.Ordinal))
+ {
+ if (!foundDocLine)
+ continue;
+
+ break;
+ }
+
+ foundDocLine = true;
+ collectedLines.Push(trimmed);
+ }
+
+ if (!foundDocLine || collectedLines.Count == 0)
+ return null;
+
+ var raw = string.Join("\n", collectedLines);
+ return ParseSummaryFromRaw(raw);
+ }
+
+ private static string? ParseSummaryFromRaw(string? rawText)
+ {
+ if (String.IsNullOrWhiteSpace(rawText))
+ return null;
+
+ var builder = new StringBuilder();
+ var lines = rawText!.Split(new[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries);
+ foreach (var line in lines)
+ {
+ var trimmed = line.Trim();
+ if (!trimmed.StartsWith("///", StringComparison.Ordinal))
+ continue;
+
+ var content = trimmed.Length > 3 ? trimmed.Substring(3).TrimStart() : String.Empty;
+ builder.AppendLine(content);
+ }
+
+ var inner = builder.ToString().Trim();
+ if (inner.Length == 0)
+ return null;
+
+ var wrapped = $"{inner}";
+ return ParseSummaryFromXml(wrapped);
+ }
+
+ private static string? ParseSummaryFromXml(string? documentation)
+ {
+ if (String.IsNullOrWhiteSpace(documentation))
+ return null;
+
+ try
+ {
+ var document = XDocument.Parse(documentation);
+ var summary = document.Root?.Elements("summary").FirstOrDefault();
+ if (summary == null)
+ return null;
+
+ var text = summary.Value;
+ if (String.IsNullOrWhiteSpace(text))
+ return null;
+
+ var normalized = String.Join(" ", text
+ .Split(new[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries)
+ .Select(t => t.Trim())
+ .Where(t => t.Length > 0));
+
+ return normalized.Length == 0 ? null : normalized;
+ }
+ catch
+ {
+ return null;
+ }
+ }
+
private static bool IsHandlerMethod(IMethodSymbol method, Compilation compilation, bool treatAsHandlerClass)
{
if (method.DeclaredAccessibility != Accessibility.Public)
diff --git a/src/Foundatio.Mediator/HandlerGenerator.cs b/src/Foundatio.Mediator/HandlerGenerator.cs
index de7194f..3581352 100644
--- a/src/Foundatio.Mediator/HandlerGenerator.cs
+++ b/src/Foundatio.Mediator/HandlerGenerator.cs
@@ -1,5 +1,6 @@
using Foundatio.Mediator.Models;
using Foundatio.Mediator.Utility;
+using System.Linq;
namespace Foundatio.Mediator;
@@ -92,6 +93,8 @@ namespace Foundatio.Mediator.Generated;
GenerateHandleMethod(source, handler, configuration);
+ GenerateHandleDefaultMethod(source, handler);
+
GenerateUntypedHandleMethod(source, handler);
GenerateInterceptorMethods(source, handler, configuration.InterceptorsEnabled);
@@ -381,6 +384,79 @@ private static void GenerateHandleMethod(IndentedStringBuilder source, HandlerIn
source.AppendLine();
}
+ private static void GenerateHandleDefaultMethod(IndentedStringBuilder source, HandlerInfo handler)
+ {
+ bool returnsValue = handler.ReturnType.IsTuple || handler.HasReturnValue;
+ bool needsAsync = handler.IsAsync || handler.ReturnType.IsTuple;
+ string methodName = GetHandlerDefaultMethodName(handler);
+ string handlerMethod = GetHandlerMethodName(handler);
+ string defaultReturnType = GetHandlerDefaultReturnType(handler);
+
+ string methodReturnType;
+ if (needsAsync)
+ {
+ methodReturnType = returnsValue
+ ? $"System.Threading.Tasks.ValueTask<{defaultReturnType}>"
+ : "System.Threading.Tasks.ValueTask";
+ }
+ else
+ {
+ methodReturnType = returnsValue ? defaultReturnType : "void";
+ }
+
+ source.AppendLine($"public static {(needsAsync ? "async " : string.Empty)}{methodReturnType} {methodName}(Foundatio.Mediator.IMediator mediator, System.IServiceProvider serviceProvider, {handler.MessageType.FullName} message, System.Threading.CancellationToken cancellationToken)");
+ source.AppendLine("{");
+
+ source.IncrementIndent();
+
+ string awaitPrefix = handler.IsAsync ? "await " : string.Empty;
+ string awaitSuffix = handler.IsAsync ? ".ConfigureAwait(false)" : string.Empty;
+ if (handler.ReturnType.IsVoid && !handler.ReturnType.IsTuple)
+ {
+ source.AppendLine($"{awaitPrefix}{handlerMethod}(serviceProvider, message, cancellationToken){awaitSuffix};");
+ source.AppendLine("return;");
+ source.DecrementIndent();
+ source.AppendLine("}");
+ source.AppendLine();
+ return;
+ }
+
+ source.Append("var result = ");
+ source.AppendLine($"{awaitPrefix}{handlerMethod}(serviceProvider, message, cancellationToken){awaitSuffix};");
+
+ if (handler.ReturnType.IsTuple)
+ {
+ var returnItem = handler.ReturnType.TupleItems.First();
+ var publishItems = handler.ReturnType.TupleItems.Skip(1).ToList();
+
+ foreach (var publishItem in publishItems)
+ {
+ string access = $"result.{publishItem.Name}";
+ if (publishItem.IsNullable)
+ {
+ source.AppendLine($"if ({access} != null)");
+ source.AppendLine("{");
+ source.AppendLine($" await mediator.PublishAsync({access}, cancellationToken).ConfigureAwait(false);");
+ source.AppendLine("}");
+ }
+ else
+ {
+ source.AppendLine($"await mediator.PublishAsync({access}, cancellationToken).ConfigureAwait(false);");
+ }
+ }
+
+ source.AppendLine($"return result.{returnItem.Name};");
+ }
+ else if (returnsValue)
+ {
+ source.AppendLine("return result;");
+ }
+
+ source.DecrementIndent();
+ source.AppendLine("}");
+ source.AppendLine();
+ }
+
private static void GenerateUntypedHandleMethod(IndentedStringBuilder source, HandlerInfo handler)
{
// For async handlers that can use fast path and return void, generate a non-async version
@@ -930,6 +1006,43 @@ public static string GetHandlerMethodName(HandlerInfo handler)
return handler.IsAsync ? "HandleAsync" : "Handle";
}
+ public static string GetHandlerDefaultMethodName(HandlerInfo handler)
+ {
+ return handler.IsAsync || handler.ReturnType.IsTuple ? "DefaultHandleAsync" : "DefaultHandle";
+ }
+
+ private static string GetHandlerDefaultReturnType(HandlerInfo handler)
+ {
+ if (handler.ReturnType.IsTuple)
+ {
+ var tupleItem = handler.ReturnType.TupleItems.First();
+ return AppendNullableAnnotation(tupleItem.TypeFullName, tupleItem.IsNullable);
+ }
+
+ return AppendNullableAnnotation(handler.ReturnType.UnwrappedFullName, handler.ReturnType.IsNullable);
+ }
+
+ private static string AppendNullableAnnotation(string typeName, bool isNullable)
+ {
+ if (!isNullable || string.IsNullOrEmpty(typeName) || typeName.EndsWith("?", StringComparison.Ordinal))
+ return typeName;
+
+ if (typeName.Equals("void", StringComparison.Ordinal))
+ return typeName;
+
+ return typeName + "?";
+ }
+
+ public static bool HandlerDefaultReturnsValue(HandlerInfo handler)
+ {
+ return handler.ReturnType.IsTuple || handler.HasReturnValue;
+ }
+
+ public static bool HandlerDefaultIsAsync(HandlerInfo handler)
+ {
+ return handler.IsAsync || handler.ReturnType.IsTuple;
+ }
+
private static void Validate(SourceProductionContext context, List handlers)
{
var processedMiddleware = new HashSet();
diff --git a/src/Foundatio.Mediator/MediatorGenerator.cs b/src/Foundatio.Mediator/MediatorGenerator.cs
index 6e5694c..5f656e2 100644
--- a/src/Foundatio.Mediator/MediatorGenerator.cs
+++ b/src/Foundatio.Mediator/MediatorGenerator.cs
@@ -1,4 +1,4 @@
-using System.Diagnostics;
+using System.Diagnostics;
using Foundatio.Mediator.Models;
using Foundatio.Mediator.Utility;
@@ -129,7 +129,7 @@ private static void Execute(ImmutableArray handlers, ImmutableArray
{
callSitesByMessage.TryGetValue(handler.MessageType, out var handlerCallSites);
var applicableMiddleware = GetApplicableMiddlewares(allMiddleware.ToImmutableArray(), handler, compilation);
- handlersWithInfo.Add(handler with { CallSites = new(handlerCallSites), Middleware = applicableMiddleware });
+ handlersWithInfo.Add(handler with { CallSites = new(handlerCallSites ?? []), Middleware = applicableMiddleware });
}
// Collect call sites that need cross-assembly interceptors
@@ -188,6 +188,8 @@ private static void Execute(ImmutableArray handlers, ImmutableArray
HandlerGenerator.Execute(context, handlersWithInfo, configuration);
+ EndpointGenerator.Execute(context, handlersWithInfo, compilation);
+
sw.Stop();
GeneratorDiagnostics.LogExecute(
compilation.AssemblyName ?? "Unknown",
diff --git a/src/Foundatio.Mediator/Models/HandlerInfo.cs b/src/Foundatio.Mediator/Models/HandlerInfo.cs
index 2fdaf3c..962dca8 100644
--- a/src/Foundatio.Mediator/Models/HandlerInfo.cs
+++ b/src/Foundatio.Mediator/Models/HandlerInfo.cs
@@ -8,6 +8,8 @@ internal readonly record struct HandlerInfo
public string FullName { get; init; }
public string MethodName { get; init; }
public TypeSymbolInfo MessageType { get; init; }
+ public string? MessageSummary { get; init; }
+ public string? Category { get; init; }
public EquatableArray MessageInterfaces { get; init; }
public EquatableArray MessageBaseClasses { get; init; }
public bool HasReturnValue => !ReturnType.IsVoid;
diff --git a/src/Foundatio.Mediator/Models/TypeSymbolInfo.cs b/src/Foundatio.Mediator/Models/TypeSymbolInfo.cs
index 83bf221..3dee2ff 100644
--- a/src/Foundatio.Mediator/Models/TypeSymbolInfo.cs
+++ b/src/Foundatio.Mediator/Models/TypeSymbolInfo.cs
@@ -9,6 +9,10 @@ internal readonly record struct TypeSymbolInfo
///
public string Identifier { get; init; }
///
+ /// The simple name of the type without namespace qualification.
+ ///
+ public string Name { get; init; }
+ ///
/// The full name of the type, including namespace and any generic parameters.
/// This may use short names when types are in scope via using directives.
///
@@ -88,6 +92,7 @@ public static TypeSymbolInfo Void()
return new TypeSymbolInfo
{
Identifier = "void",
+ Name = "void",
FullName = "void",
QualifiedName = "void",
UnwrappedFullName = "void",
@@ -166,6 +171,7 @@ static string GetTypeArgIdentifier(ITypeSymbol ts)
return new TypeSymbolInfo
{
Identifier = identifier,
+ Name = typeSymbol.Name,
FullName = typeSymbol.ToDisplayString(),
QualifiedName = qualifiedName,
UnwrappedFullName = unwrappedTypeFullName,
diff --git a/tests/Foundatio.Mediator.Tests/BasicHandlerGenerationTests.GeneratesWrapperForSimpleHandler.received.txt b/tests/Foundatio.Mediator.Tests/BasicHandlerGenerationTests.GeneratesWrapperForSimpleHandler.received.txt
new file mode 100644
index 0000000..96cdef7
--- /dev/null
+++ b/tests/Foundatio.Mediator.Tests/BasicHandlerGenerationTests.GeneratesWrapperForSimpleHandler.received.txt
@@ -0,0 +1,180 @@
+{
+ Diagnostics: null,
+ GeneratorDiagnostics: null,
+ GeneratedTrees: [
+ {
+ HintName: InterceptsLocationAttribute.g.cs,
+ Source:
+// This file was generated by Foundatio.Mediator source generators.
+// Changes to this file may be lost when the code is regenerated.
+//
+
+#nullable enable
+
+using System;
+
+namespace System.Runtime.CompilerServices;
+
+///
+/// Indicates that a method is an interceptor and provides the location of the intercepted call.
+///
+[global::System.CodeDom.Compiler.GeneratedCode("Foundatio.Mediator", "")]
+[global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)]
+internal sealed class InterceptsLocationAttribute : global::System.Attribute
+{
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The version of the location encoding.
+ /// The encoded location data.
+ public InterceptsLocationAttribute(int version, string data)
+ {
+ Version = version;
+ Data = data;
+ }
+
+ ///
+ /// Gets the version of the location encoding.
+ ///
+ public int Version { get; }
+
+ ///
+ /// Gets the encoded location data.
+ ///
+ public string Data { get; }
+}
+
+ },
+ {
+ HintName: PingHandler_Ping_Handler.g.cs,
+ Source:
+// This file was generated by Foundatio.Mediator source generators.
+// Changes to this file may be lost when the code is regenerated.
+//
+
+#nullable enable
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Linq;
+using System.Reflection;
+using System.Runtime.CompilerServices;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Logging;
+
+namespace Foundatio.Mediator;
+
+[global::System.CodeDom.Compiler.GeneratedCode("Foundatio.Mediator", "")]
+[ExcludeFromCodeCoverage]
+internal static class PingHandler_Ping_Handler
+{
+ public static async System.Threading.Tasks.Task HandleAsync(System.IServiceProvider serviceProvider, Ping message, System.Threading.CancellationToken cancellationToken)
+ {
+ var logger = serviceProvider.GetService()?.CreateLogger("PingHandler");
+ logger?.LogDebug("Processing message {MessageType}", "Ping");
+
+ using var activity = MediatorActivitySource.Instance.StartActivity("Ping");
+ activity?.SetTag("messaging.system", "Foundatio.Mediator");
+ activity?.SetTag("messaging.message.type", "Ping");
+
+ string? handlerResult = default;
+ var handlerInstance = GetOrCreateHandler(serviceProvider);
+ handlerResult = await handlerInstance.HandleAsync(message, cancellationToken);
+
+ logger?.LogDebug("Completed processing message {MessageType}", "Ping");
+ activity?.SetStatus(System.Diagnostics.ActivityStatusCode.Ok);
+ return handlerResult;
+ }
+
+ public static async System.Threading.Tasks.ValueTask DefaultHandleAsync(Foundatio.Mediator.IMediator mediator, System.IServiceProvider serviceProvider, Ping message, System.Threading.CancellationToken cancellationToken)
+ {
+ var result = await HandleAsync(serviceProvider, message, cancellationToken).ConfigureAwait(false);
+ return result;
+ }
+
+ public static async ValueTask