Skip to content

Commit 402c2c2

Browse files
committed
Minor
1 parent 82319ad commit 402c2c2

File tree

5 files changed

+53
-99
lines changed

5 files changed

+53
-99
lines changed

samples/ConsoleSample/Middleware/CommandMiddleware.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@ namespace ConsoleSample.Middleware;
44

55
public class CommandMiddleware
66
{
7-
public Task<HandlerResult> BeforeAsync(ICommand command, CancellationToken cancellationToken)
7+
public void Before(ICommand command)
88
{
99
Console.WriteLine($"📋 [CommandMiddleware] Before: Processing command of type {command.GetType().Name}");
10-
return Task.FromResult(HandlerResult.Continue());
1110
}
1211
}

samples/ConsoleSample/Middleware/GlobalMiddleware.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ public class GlobalMiddleware
88
return (DateTime.UtcNow, DateTime.UtcNow.TimeOfDay);
99
}
1010

11-
public void Finally(object message, Exception? exception, CancellationToken cancellationToken)
11+
public void Finally(object message, Exception? exception, DateTime date)
1212
{
1313
if (exception != null)
1414
{
15-
Console.WriteLine($"🌍 [GlobalMiddleware] Finally: Error processing {message.GetType().Name}: {exception.Message}");
15+
Console.WriteLine($"🌍 [GlobalMiddleware] Finally: Error processing {message.GetType().Name} at {date}: {exception.Message}");
1616
}
1717
else
1818
{
19-
Console.WriteLine($"🌍 [GlobalMiddleware] Finally: Successfully completed {message.GetType().Name}");
19+
Console.WriteLine($"🌍 [GlobalMiddleware] Finally: Successfully completed {message.GetType().Name} at {date}");
2020
}
2121
}
2222
}

src/Foundatio.Mediator.SourceGenerator/HandlerWrapperGenerator.cs

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,11 @@ namespace Foundatio.Mediator;
55

66
internal static class HandlerWrapperGenerator
77
{
8-
private static void AddGeneratedFileHeader(StringBuilder source)
9-
{
10-
source.AppendLine("// <auto-generated>");
11-
source.AppendLine("// This file was generated by Foundatio.Mediator source generators.");
12-
source.AppendLine("// Changes to this file may be lost when the code is regenerated.");
13-
source.AppendLine("// </auto-generated>");
14-
source.AppendLine();
15-
source.AppendLine("using System.Diagnostics;");
16-
source.AppendLine("using System.Diagnostics.CodeAnalysis;");
17-
source.AppendLine();
18-
}
19-
208
public static void GenerateHandlerWrappers(List<HandlerInfo> handlers, List<MiddlewareInfo> middlewares, List<CallSiteInfo> callSites, bool interceptorsEnabled, SourceProductionContext context)
219
{
2210
// Group call sites by message type for easier lookup
2311
var callSitesByMessage = callSites
24-
.Where(cs => !cs.MethodName.StartsWith("Publish")) // Only process Invoke calls for interceptors
12+
.Where(cs => !cs.IsPublish)
2513
.GroupBy(cs => cs.MessageTypeName)
2614
.ToDictionary(g => g.Key, g => g.ToList());
2715

@@ -30,19 +18,22 @@ public static void GenerateHandlerWrappers(List<HandlerInfo> handlers, List<Midd
3018
string wrapperClassName = GetWrapperClassName(handler);
3119

3220
callSitesByMessage.TryGetValue(handler.MessageTypeName, out var handlerCallSites);
33-
handlerCallSites ??= new List<CallSiteInfo>();
21+
handlerCallSites ??= [];
3422

35-
string source = GenerateStaticHandlerWrapper(handler, wrapperClassName, middlewares, handlerCallSites, interceptorsEnabled);
23+
var applicableMiddlewares = GetApplicableMiddlewares(middlewares, handler);
24+
25+
string source = GenerateHandlerWrapper(handler, wrapperClassName, applicableMiddlewares, handlerCallSites, interceptorsEnabled);
3626
string fileName = $"{wrapperClassName}.g.cs";
3727
context.AddSource(fileName, source);
3828
}
3929
}
4030

41-
public static string GenerateStaticHandlerWrapper(HandlerInfo handler, string wrapperClassName, List<MiddlewareInfo> middlewares, List<CallSiteInfo> callSites, bool interceptorsEnabled)
31+
public static string GenerateHandlerWrapper(HandlerInfo handler, string wrapperClassName, List<MiddlewareInfo> middlewares, List<CallSiteInfo> callSites, bool interceptorsEnabled)
4232
{
4333
var source = new StringBuilder();
4434

4535
AddGeneratedFileHeader(source);
36+
4637
source.AppendLine("#nullable enable");
4738
source.AppendLine("using System;");
4839
source.AppendLine("using System.Threading;");
@@ -59,10 +50,7 @@ public static string GenerateStaticHandlerWrapper(HandlerInfo handler, string wr
5950
// Generate strongly typed method that matches handler signature
6051
GenerateStronglyTypedMethod(source, handler, middlewares);
6152

62-
// Determine if we need async handle method based on handler or middleware
63-
var applicableMiddlewares = GetApplicableMiddlewares(middlewares, handler);
64-
bool hasAsyncMiddleware = applicableMiddlewares.Any(m => m.IsAsync);
65-
53+
bool hasAsyncMiddleware = middlewares.Any(m => m.IsAsync);
6654
bool needsAsyncHandleMethod = handler.IsAsync || hasAsyncMiddleware;
6755

6856
// Generate single generic method based on effective async status
@@ -149,13 +137,9 @@ private static void GenerateStronglyTypedMethod(StringBuilder source, HandlerInf
149137
{
150138
string stronglyTypedMethodName = GetStronglyTypedMethodName(handler);
151139

152-
// Get applicable middlewares for this handler
153-
var applicableMiddlewares = GetApplicableMiddlewares(middlewares, handler);
154-
155140
// For the strongly typed method, we need to preserve the original method signature
156141
// but make it async if we have async middleware or the handler is async
157-
bool hasAsyncMiddleware = applicableMiddlewares.Any(m => m.IsAsync);
158-
142+
bool hasAsyncMiddleware = middlewares.Any(m => m.IsAsync);
159143

160144
string returnType = ReconstructOriginalReturnType(handler, hasAsyncMiddleware);
161145
bool isAsync = handler.IsAsync || hasAsyncMiddleware;
@@ -165,10 +149,10 @@ private static void GenerateStronglyTypedMethod(StringBuilder source, HandlerInf
165149
source.AppendLine($" public static {asyncModifier}{returnType} {stronglyTypedMethodName}({handler.MessageTypeName} message, IServiceProvider serviceProvider, CancellationToken cancellationToken)");
166150
source.AppendLine(" {");
167151

168-
if (applicableMiddlewares.Any())
152+
if (middlewares.Any())
169153
{
170154
// Generate middleware-aware execution
171-
GenerateMiddlewareAwareExecution(source, handler, applicableMiddlewares, stronglyTypedMethodName);
155+
GenerateMiddlewareAwareExecution(source, handler, middlewares, stronglyTypedMethodName);
172156
}
173157
else
174158
{
@@ -448,18 +432,17 @@ private static void GenerateInterceptorMethods(StringBuilder source, HandlerInfo
448432
var key = group.Key;
449433
var groupCallSites = group.ToList();
450434

451-
GenerateInterceptorMethod(source, handler, key.MethodName, key.MessageTypeName, key.ExpectedResponseTypeName, groupCallSites, methodCounter++, middlewares);
435+
GenerateInterceptorMethod(source, handler, key.MethodName, key.ExpectedResponseTypeName, groupCallSites, methodCounter++, middlewares);
452436
}
453437
}
454438

455-
private static void GenerateInterceptorMethod(StringBuilder source, HandlerInfo handler, string methodName, string messageTypeName, string expectedResponseTypeName, List<CallSiteInfo> callSites, int methodIndex, List<MiddlewareInfo> middlewares)
439+
private static void GenerateInterceptorMethod(StringBuilder source, HandlerInfo handler, string methodName, string expectedResponseTypeName, List<CallSiteInfo> callSites, int methodIndex, List<MiddlewareInfo> middlewares)
456440
{
457441
// Generate unique method name for the interceptor
458442
string interceptorMethodName = $"Intercept{methodName}{methodIndex}";
459443

460444
// Determine if the wrapper method is async (either because the handler is async OR because there are async middleware)
461-
var applicableMiddlewares = GetApplicableMiddlewares(middlewares, handler);
462-
bool hasAsyncMiddleware = applicableMiddlewares.Any(m =>
445+
bool hasAsyncMiddleware = middlewares.Any(m =>
463446
(m.BeforeMethod?.IsAsync == true) ||
464447
(m.AfterMethod?.IsAsync == true) ||
465448
(m.FinallyMethod?.IsAsync == true));
@@ -480,7 +463,7 @@ private static void GenerateInterceptorMethod(StringBuilder source, HandlerInfo
480463

481464
// Generate interceptor attributes for all call sites
482465
var interceptorAttributes = callSites
483-
.Select(cs => GenerateInterceptorAttribute(cs))
466+
.Select(GenerateInterceptorAttribute)
484467
.Where(attr => !String.IsNullOrEmpty(attr))
485468
.ToList();
486469

@@ -1528,4 +1511,16 @@ private static string GenerateSingleMiddlewareParameterExpression(ParameterInfo
15281511
// Fall back to DI resolution
15291512
return $"serviceProvider.GetRequiredService<{parameter.TypeName}>()";
15301513
}
1514+
1515+
private static void AddGeneratedFileHeader(StringBuilder source)
1516+
{
1517+
source.AppendLine("// <auto-generated>");
1518+
source.AppendLine("// This file was generated by Foundatio.Mediator source generators.");
1519+
source.AppendLine("// Changes to this file may be lost when the code is regenerated.");
1520+
source.AppendLine("// </auto-generated>");
1521+
source.AppendLine();
1522+
source.AppendLine("using System.Diagnostics;");
1523+
source.AppendLine("using System.Diagnostics.CodeAnalysis;");
1524+
source.AppendLine();
1525+
}
15311526
}

src/Foundatio.Mediator.SourceGenerator/MediatorGenerator.cs

Lines changed: 19 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,6 @@ public sealed class MediatorGenerator : IIncrementalGenerator
1212
{
1313
public void Initialize(IncrementalGeneratorInitializationContext context)
1414
{
15-
// Check if interceptors are enabled
16-
var interceptorsEnabled = context.AnalyzerConfigOptionsProvider
17-
.Select(static (provider, _) => IsInterceptorsEnabled(provider));
18-
1915
var interceptionEnabledSetting = context.AnalyzerConfigOptionsProvider
2016
.Select((x, _) =>
2117
x.GlobalOptions.TryGetValue($"build_property.{Constants.EnabledPropertyName}", out string? enableSwitch)
@@ -28,36 +24,38 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
2824
.Combine(csharpSufficient)
2925
.WithTrackingName(TrackingNames.Settings);
3026

31-
// Find all handler classes and their methods
32-
var handlerDeclarations = context.SyntaxProvider
27+
var interceptionEnabled = settings
28+
.Select((x, _) => x is { Left: true, Right: true });
29+
30+
var callSites = context.SyntaxProvider
3331
.CreateSyntaxProvider(
34-
predicate: static (s, _) => IsPotentialHandlerClass(s),
35-
transform: static (ctx, _) => GetHandlersForGeneration(ctx))
36-
.Where(static m => m is not null && m.Count > 0)
37-
.SelectMany(static (handlers, _) => handlers ?? []); // Flatten the collections
32+
predicate: static (s, _) => CallSiteAnalyzer.IsPotentialMediatorCall(s),
33+
transform: static (ctx, _) => CallSiteAnalyzer.GetCallSiteForAnalysis(ctx))
34+
.Where(static cs => cs.HasValue)
35+
.Select(static (cs, _) => cs!.Value)
36+
.WithTrackingName(TrackingNames.CallSites);
3837

39-
// Find all middleware classes and their methods
4038
var middlewareDeclarations = context.SyntaxProvider
4139
.CreateSyntaxProvider(
4240
predicate: static (s, _) => MiddlewareGenerator.IsPotentialMiddlewareClass(s),
4341
transform: static (ctx, _) => MiddlewareGenerator.GetMiddlewareForGeneration(ctx))
4442
.Where(static m => m is not null && m.Count > 0)
45-
.SelectMany(static (middlewares, _) => middlewares ?? []); // Flatten the collections
43+
.SelectMany(static (middlewares, _) => middlewares ?? [])
44+
.WithTrackingName(TrackingNames.Middleware);
4645

47-
// Find all mediator call sites
48-
var callSites = context.SyntaxProvider
46+
var handlerDeclarations = context.SyntaxProvider
4947
.CreateSyntaxProvider(
50-
predicate: static (s, _) => CallSiteAnalyzer.IsPotentialMediatorCall(s),
51-
transform: static (ctx, _) => CallSiteAnalyzer.GetCallSiteForAnalysis(ctx))
52-
.Where(static cs => cs.HasValue)
53-
.Select(static (cs, _) => cs!.Value);
48+
predicate: static (s, _) => IsPotentialHandlerClass(s),
49+
transform: static (ctx, _) => GetHandlersForGeneration(ctx))
50+
.Where(static m => m is not null && m.Count > 0)
51+
.SelectMany(static (handlers, _) => handlers ?? [])
52+
.WithTrackingName(TrackingNames.Handlers);
5453

55-
// Combine handlers, middleware, call sites, interceptor availability and generate everything
5654
var compilationAndData = context.CompilationProvider
5755
.Combine(handlerDeclarations.Collect())
5856
.Combine(middlewareDeclarations.Collect())
5957
.Combine(callSites.Collect())
60-
.Combine(interceptorsEnabled);
58+
.Combine(interceptionEnabled);
6159

6260
context.RegisterSourceOutput(compilationAndData,
6361
static (spc, source) => Execute(source.Left.Left.Left.Left, source.Left.Left.Left.Right, source.Left.Left.Right, source.Left.Right, source.Right, spc));
@@ -69,38 +67,6 @@ private static bool IsPotentialHandlerClass(SyntaxNode node)
6967
&& (name.EndsWith("Handler") || name.EndsWith("Consumer"));
7068
}
7169

72-
private static bool IsInterceptorsEnabled(AnalyzerConfigOptionsProvider provider)
73-
{
74-
// First, check for explicit InterceptorsNamespaces property (preferred approach)
75-
string[] propertyNames =
76-
[
77-
"build_property.InterceptorsNamespaces",
78-
"build_property.InterceptorsPreviewNamespaces"
79-
];
80-
81-
foreach (string propertyName in propertyNames)
82-
{
83-
if (provider.GlobalOptions.TryGetValue(propertyName, out string? value) &&
84-
!String.IsNullOrEmpty(value))
85-
{
86-
return true;
87-
}
88-
}
89-
90-
// For .NET 9+, interceptors are stable - enable if explicitly configured or targeting net9.0+
91-
if (provider.GlobalOptions.TryGetValue("build_property.TargetFramework", out string? targetFramework))
92-
{
93-
// Enable for .NET 9+ as interceptors are stable there
94-
if (targetFramework?.StartsWith("net9") == true ||
95-
targetFramework?.StartsWith("net1") == true) // net10+
96-
{
97-
return true;
98-
}
99-
}
100-
101-
return false;
102-
}
103-
10470
private static List<HandlerInfo>? GetHandlersForGeneration(GeneratorSyntaxContext context)
10571
{
10672
var classDeclaration = (ClassDeclarationSyntax)context.Node;
@@ -276,7 +242,7 @@ private static void Execute(Compilation compilation, ImmutableArray<HandlerInfo>
276242
return;
277243

278244
var validHandlers = handlers.ToList();
279-
var validMiddlewares = middlewares.IsDefaultOrEmpty ? new List<MiddlewareInfo>() : middlewares.ToList();
245+
var validMiddlewares = middlewares.IsDefaultOrEmpty ? [] : middlewares.ToList();
280246

281247
if (validHandlers.Count == 0)
282248
return;

src/Foundatio.Mediator.SourceGenerator/Utility/TrackingNames.cs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,8 @@
22

33
public class TrackingNames
44
{
5-
public const string InitialExtraction = nameof(InitialExtraction);
6-
public const string InitialExternalExtraction = nameof(InitialExternalExtraction);
7-
public const string InitialInterceptable = nameof(InitialInterceptable);
8-
public const string InitialInterceptableOnly = nameof(InitialInterceptableOnly);
9-
public const string RemovingNulls = nameof(RemovingNulls);
10-
public const string InterceptedLocations = nameof(InterceptedLocations);
5+
public const string Middleware = nameof(Middleware);
6+
public const string CallSites = nameof(CallSites);
7+
public const string Handlers = nameof(Handlers);
118
public const string Settings = nameof(Settings);
12-
public const string EnumInterceptions = nameof(EnumInterceptions);
13-
public const string ExternalInterceptions = nameof(ExternalInterceptions);
14-
public const string AdditionalInterceptions = nameof(AdditionalInterceptions);
159
}

0 commit comments

Comments
 (0)