@@ -5,23 +5,11 @@ namespace Foundatio.Mediator;
55
66internal static class HandlerWrapperGenerator
77{
8- private static void AddGeneratedFileHeader ( StringBuilder source )
9- {
10- source . AppendLine ( "// <auto-generated>" ) ;
11- source . AppendLine ( "// This file was generated by Foundatio.Mediator source generators." ) ;
12- source . AppendLine ( "// Changes to this file may be lost when the code is regenerated." ) ;
13- source . AppendLine ( "// </auto-generated>" ) ;
14- source . AppendLine ( ) ;
15- source . AppendLine ( "using System.Diagnostics;" ) ;
16- source . AppendLine ( "using System.Diagnostics.CodeAnalysis;" ) ;
17- source . AppendLine ( ) ;
18- }
19-
208 public static void GenerateHandlerWrappers ( List < HandlerInfo > handlers , List < MiddlewareInfo > middlewares , List < CallSiteInfo > callSites , bool interceptorsEnabled , SourceProductionContext context )
219 {
2210 // Group call sites by message type for easier lookup
2311 var callSitesByMessage = callSites
24- . Where ( cs => ! cs . MethodName . StartsWith ( "Publish" ) ) // Only process Invoke calls for interceptors
12+ . Where ( cs => ! cs . IsPublish )
2513 . GroupBy ( cs => cs . MessageTypeName )
2614 . ToDictionary ( g => g . Key , g => g . ToList ( ) ) ;
2715
@@ -30,19 +18,22 @@ public static void GenerateHandlerWrappers(List<HandlerInfo> handlers, List<Midd
3018 string wrapperClassName = GetWrapperClassName ( handler ) ;
3119
3220 callSitesByMessage . TryGetValue ( handler . MessageTypeName , out var handlerCallSites ) ;
33- handlerCallSites ??= new List < CallSiteInfo > ( ) ;
21+ handlerCallSites ??= [ ] ;
3422
35- string source = GenerateStaticHandlerWrapper ( handler , wrapperClassName , middlewares , handlerCallSites , interceptorsEnabled ) ;
23+ var applicableMiddlewares = GetApplicableMiddlewares ( middlewares , handler ) ;
24+
25+ string source = GenerateHandlerWrapper ( handler , wrapperClassName , applicableMiddlewares , handlerCallSites , interceptorsEnabled ) ;
3626 string fileName = $ "{ wrapperClassName } .g.cs";
3727 context . AddSource ( fileName , source ) ;
3828 }
3929 }
4030
41- public static string GenerateStaticHandlerWrapper ( HandlerInfo handler , string wrapperClassName , List < MiddlewareInfo > middlewares , List < CallSiteInfo > callSites , bool interceptorsEnabled )
31+ public static string GenerateHandlerWrapper ( HandlerInfo handler , string wrapperClassName , List < MiddlewareInfo > middlewares , List < CallSiteInfo > callSites , bool interceptorsEnabled )
4232 {
4333 var source = new StringBuilder ( ) ;
4434
4535 AddGeneratedFileHeader ( source ) ;
36+
4637 source . AppendLine ( "#nullable enable" ) ;
4738 source . AppendLine ( "using System;" ) ;
4839 source . AppendLine ( "using System.Threading;" ) ;
@@ -59,10 +50,7 @@ public static string GenerateStaticHandlerWrapper(HandlerInfo handler, string wr
5950 // Generate strongly typed method that matches handler signature
6051 GenerateStronglyTypedMethod ( source , handler , middlewares ) ;
6152
62- // Determine if we need async handle method based on handler or middleware
63- var applicableMiddlewares = GetApplicableMiddlewares ( middlewares , handler ) ;
64- bool hasAsyncMiddleware = applicableMiddlewares . Any ( m => m . IsAsync ) ;
65-
53+ bool hasAsyncMiddleware = middlewares . Any ( m => m . IsAsync ) ;
6654 bool needsAsyncHandleMethod = handler . IsAsync || hasAsyncMiddleware ;
6755
6856 // Generate single generic method based on effective async status
@@ -149,13 +137,9 @@ private static void GenerateStronglyTypedMethod(StringBuilder source, HandlerInf
149137 {
150138 string stronglyTypedMethodName = GetStronglyTypedMethodName ( handler ) ;
151139
152- // Get applicable middlewares for this handler
153- var applicableMiddlewares = GetApplicableMiddlewares ( middlewares , handler ) ;
154-
155140 // For the strongly typed method, we need to preserve the original method signature
156141 // but make it async if we have async middleware or the handler is async
157- bool hasAsyncMiddleware = applicableMiddlewares . Any ( m => m . IsAsync ) ;
158-
142+ bool hasAsyncMiddleware = middlewares . Any ( m => m . IsAsync ) ;
159143
160144 string returnType = ReconstructOriginalReturnType ( handler , hasAsyncMiddleware ) ;
161145 bool isAsync = handler . IsAsync || hasAsyncMiddleware ;
@@ -165,10 +149,10 @@ private static void GenerateStronglyTypedMethod(StringBuilder source, HandlerInf
165149 source . AppendLine ( $ " public static { asyncModifier } { returnType } { stronglyTypedMethodName } ({ handler . MessageTypeName } message, IServiceProvider serviceProvider, CancellationToken cancellationToken)") ;
166150 source . AppendLine ( " {" ) ;
167151
168- if ( applicableMiddlewares . Any ( ) )
152+ if ( middlewares . Any ( ) )
169153 {
170154 // Generate middleware-aware execution
171- GenerateMiddlewareAwareExecution ( source , handler , applicableMiddlewares , stronglyTypedMethodName ) ;
155+ GenerateMiddlewareAwareExecution ( source , handler , middlewares , stronglyTypedMethodName ) ;
172156 }
173157 else
174158 {
@@ -448,18 +432,17 @@ private static void GenerateInterceptorMethods(StringBuilder source, HandlerInfo
448432 var key = group . Key ;
449433 var groupCallSites = group . ToList ( ) ;
450434
451- GenerateInterceptorMethod ( source , handler , key . MethodName , key . MessageTypeName , key . ExpectedResponseTypeName , groupCallSites , methodCounter ++ , middlewares ) ;
435+ GenerateInterceptorMethod ( source , handler , key . MethodName , key . ExpectedResponseTypeName , groupCallSites , methodCounter ++ , middlewares ) ;
452436 }
453437 }
454438
455- private static void GenerateInterceptorMethod ( StringBuilder source , HandlerInfo handler , string methodName , string messageTypeName , string expectedResponseTypeName , List < CallSiteInfo > callSites , int methodIndex , List < MiddlewareInfo > middlewares )
439+ private static void GenerateInterceptorMethod ( StringBuilder source , HandlerInfo handler , string methodName , string expectedResponseTypeName , List < CallSiteInfo > callSites , int methodIndex , List < MiddlewareInfo > middlewares )
456440 {
457441 // Generate unique method name for the interceptor
458442 string interceptorMethodName = $ "Intercept{ methodName } { methodIndex } ";
459443
460444 // Determine if the wrapper method is async (either because the handler is async OR because there are async middleware)
461- var applicableMiddlewares = GetApplicableMiddlewares ( middlewares , handler ) ;
462- bool hasAsyncMiddleware = applicableMiddlewares . Any ( m =>
445+ bool hasAsyncMiddleware = middlewares . Any ( m =>
463446 ( m . BeforeMethod ? . IsAsync == true ) ||
464447 ( m . AfterMethod ? . IsAsync == true ) ||
465448 ( m . FinallyMethod ? . IsAsync == true ) ) ;
@@ -480,7 +463,7 @@ private static void GenerateInterceptorMethod(StringBuilder source, HandlerInfo
480463
481464 // Generate interceptor attributes for all call sites
482465 var interceptorAttributes = callSites
483- . Select ( cs => GenerateInterceptorAttribute ( cs ) )
466+ . Select ( GenerateInterceptorAttribute )
484467 . Where ( attr => ! String . IsNullOrEmpty ( attr ) )
485468 . ToList ( ) ;
486469
@@ -1528,4 +1511,16 @@ private static string GenerateSingleMiddlewareParameterExpression(ParameterInfo
15281511 // Fall back to DI resolution
15291512 return $ "serviceProvider.GetRequiredService<{ parameter . TypeName } >()";
15301513 }
1514+
1515+ private static void AddGeneratedFileHeader ( StringBuilder source )
1516+ {
1517+ source . AppendLine ( "// <auto-generated>" ) ;
1518+ source . AppendLine ( "// This file was generated by Foundatio.Mediator source generators." ) ;
1519+ source . AppendLine ( "// Changes to this file may be lost when the code is regenerated." ) ;
1520+ source . AppendLine ( "// </auto-generated>" ) ;
1521+ source . AppendLine ( ) ;
1522+ source . AppendLine ( "using System.Diagnostics;" ) ;
1523+ source . AppendLine ( "using System.Diagnostics.CodeAnalysis;" ) ;
1524+ source . AppendLine ( ) ;
1525+ }
15311526}
0 commit comments