11using Microsoft . CodeAnalysis ;
22using Foundatio . Mediator . Utility ;
3+ using System . Text ;
34
45namespace Foundatio . Mediator ;
56
@@ -405,7 +406,7 @@ private static void GenerateAsyncHandleMethod(IndentedStringBuilder source, Hand
405406 {
406407 // Handler returns void
407408 source . AppendLine ( $ "await { stronglyTypedMethodName } (typedMessage, mediator.ServiceProvider, cancellationToken);")
408- . AppendLine ( "return new object() ;" ) ;
409+ . AppendLine ( "return null ;" ) ;
409410 }
410411 }
411412
@@ -495,7 +496,7 @@ private static void GenerateSyncHandleMethod(IndentedStringBuilder source, Handl
495496 {
496497 // Handler returns void
497498 source . AppendLine ( $ "{ stronglyTypedMethodName } (typedMessage, mediator.ServiceProvider, cancellationToken);")
498- . AppendLine ( "return new object() ;" ) ;
499+ . AppendLine ( "return null ;" ) ;
499500 }
500501 }
501502
@@ -581,20 +582,8 @@ private static void GenerateInterceptorMethod(IndentedStringBuilder source, Hand
581582 source . AppendLine ( $ "var result = { ( interceptorIsAsync && wrapperIsAsync ? "await " : "" ) } { stronglyTypedMethodName } (typedMessage, mediator.ServiceProvider, cancellationToken);")
582583 . AppendLine ( ) ;
583584
584- // For tuple types, we need to publish all items except the first (which is the return value)
585- // This is a simplified implementation that assumes Item2 is always published
586- // In a more complete implementation, you would parse the tuple type and generate appropriate logic
587- if ( interceptorIsAsync )
588- {
589- source . AppendLine ( "await mediator.PublishAsync(result.Item2, cancellationToken);" ) ;
590- }
591- else
592- {
593- source . AppendLine ( "mediator.Publish(result.Item2);" ) ;
594- }
595-
596- source . AppendLine ( )
597- . AppendLine ( "return result.Item1;" ) ;
585+ // Generate optimized typed code for tuple handling
586+ GenerateOptimizedTupleHandling ( source , handler , expectedResponseTypeName , interceptorIsAsync ) ;
598587 }
599588 else
600589 {
@@ -763,7 +752,14 @@ private static void GenerateMiddlewareExecutionCore(IndentedStringBuilder source
763752 }
764753
765754 source . AppendLine ( ) ;
766- source . AppendLine ( GetHandlerResultDeclaration ( handler ) ) ;
755+
756+ // Only generate handlerResult if it's actually needed
757+ bool needsHandlerResult = NeedsHandlerResultVariable ( handler , middlewares ) ;
758+ if ( needsHandlerResult )
759+ {
760+ source . AppendLine ( GetHandlerResultDeclaration ( handler ) ) ;
761+ }
762+
767763 source . AppendLine ( GetExceptionDeclaration ( ) ) ;
768764 source . AppendLine ( ) ;
769765 source . AppendLine ( "try" ) ;
@@ -859,7 +855,7 @@ private static void GenerateMiddlewareExecutionCore(IndentedStringBuilder source
859855 {
860856 var methodInfo = middleware . AfterMethod ;
861857 string args = String . Join ( ", " , methodInfo . Parameters . Select ( p =>
862- GenerateMiddlewareParameterExpression ( p , middleware , resultVar ) ) ) ;
858+ GenerateMiddlewareParameterExpression ( p , middleware , resultVar , handler ) ) ) ;
863859 string afterMethodCall = GenerateMiddlewareMethodCall ( middleware , methodInfo , args , middlewareVariableNames [ i ] ) ;
864860 if ( methodInfo . IsAsync )
865861 source . AppendLine ( $ "await { afterMethodCall } ;") ;
@@ -901,7 +897,7 @@ private static void GenerateMiddlewareExecutionCore(IndentedStringBuilder source
901897 {
902898 var methodInfo = middleware . FinallyMethod ;
903899 string args = String . Join ( ", " , methodInfo . Parameters . Select ( p =>
904- GenerateMiddlewareParameterExpression ( p , middleware , resultVar ) ) ) ;
900+ GenerateMiddlewareParameterExpression ( p , middleware , resultVar , handler ) ) ) ;
905901 string finallyMethodCall = GenerateMiddlewareMethodCall ( middleware , methodInfo , args , middlewareVariableNames [ i ] ) ;
906902 if ( methodInfo . IsAsync )
907903 source . AppendLine ( $ "await { finallyMethodCall } ;") ;
@@ -914,6 +910,21 @@ private static void GenerateMiddlewareExecutionCore(IndentedStringBuilder source
914910 source . AppendLine ( "}" ) ;
915911 }
916912
913+ /// <summary>
914+ /// Determines if the handlerResult variable is needed based on handler return type and middleware usage.
915+ /// </summary>
916+ private static bool NeedsHandlerResultVariable ( HandlerInfo handler , List < MiddlewareInfo > middlewares )
917+ {
918+ // For void/Task handlers, never generate handlerResult variable - pass null to middleware instead
919+ if ( IsVoidReturnType ( handler . OriginalReturnTypeName ) )
920+ {
921+ return false ;
922+ }
923+
924+ // For non-void handlers, we always need handlerResult for the return statement
925+ return true ;
926+ }
927+
917928 /// <summary>
918929 /// Gets the proper variable declaration for a handler result with nullable-safe initialization.
919930 /// </summary>
@@ -1052,7 +1063,7 @@ private static bool CanReturnHandlerResult(MiddlewareMethodInfo? method)
10521063 /// Generates the appropriate parameter expression for middleware methods,
10531064 /// including handling tuple field extraction from Before method results.
10541065 /// </summary>
1055- private static string GenerateMiddlewareParameterExpression ( ParameterInfo parameter , MiddlewareInfo middleware , string resultVariableName )
1066+ private static string GenerateMiddlewareParameterExpression ( ParameterInfo parameter , MiddlewareInfo middleware , string resultVariableName , HandlerInfo handler )
10561067 {
10571068 if ( parameter . IsMessage )
10581069 return "message" ;
@@ -1061,7 +1072,12 @@ private static string GenerateMiddlewareParameterExpression(ParameterInfo parame
10611072 if ( parameter . Name == "beforeResult" )
10621073 return resultVariableName ;
10631074 if ( parameter . Name == "handlerResult" )
1075+ {
1076+ // For void/Task handlers, pass null instead of handlerResult variable
1077+ if ( IsVoidReturnType ( handler . OriginalReturnTypeName ) )
1078+ return "null" ;
10641079 return "handlerResult" ;
1080+ }
10651081 if ( parameter . Name == "exception" )
10661082 return "exception" ;
10671083
@@ -1174,6 +1190,160 @@ private static void AddGeneratedFileHeader(IndentedStringBuilder source)
11741190 source . AppendLine ( ) ;
11751191 }
11761192
1193+ private static void GenerateOptimizedTupleHandling ( IndentedStringBuilder source , HandlerInfo handler , string expectedResponseTypeName , bool isAsync )
1194+ {
1195+ // Parse the tuple return type to determine which items to return vs publish
1196+ var tupleFields = ParseTupleReturnType ( handler . ReturnTypeName ) ;
1197+
1198+ if ( tupleFields . Count == 0 )
1199+ {
1200+ // Fallback - shouldn't happen for tuple types
1201+ source . AppendLine ( $ "return default({ expectedResponseTypeName } );") ;
1202+ return ;
1203+ }
1204+
1205+ // Find which tuple item matches the expected response type
1206+ int returnItemIndex = - 1 ;
1207+ var publishItems = new List < int > ( ) ;
1208+
1209+ for ( int i = 0 ; i < tupleFields . Count ; i ++ )
1210+ {
1211+ string fieldType = tupleFields [ i ] ;
1212+
1213+ // Check if this field type matches or is assignable to the expected response type
1214+ if ( IsTypeCompatible ( fieldType , expectedResponseTypeName ) )
1215+ {
1216+ if ( returnItemIndex == - 1 )
1217+ {
1218+ returnItemIndex = i ;
1219+ }
1220+ // If we already found a return item, this becomes a publish item
1221+ else
1222+ {
1223+ publishItems . Add ( i ) ;
1224+ }
1225+ }
1226+ else
1227+ {
1228+ publishItems . Add ( i ) ;
1229+ }
1230+ }
1231+
1232+ // Generate publishing code for non-return items
1233+ source . AppendLine ( "// publish cascading messages" ) ;
1234+ foreach ( int publishIndex in publishItems )
1235+ {
1236+ string itemAccess = $ "result.Item{ publishIndex + 1 } ";
1237+ if ( isAsync )
1238+ {
1239+ source . AppendLine ( $ "await mediator.PublishAsync({ itemAccess } , cancellationToken);") ;
1240+ }
1241+ else
1242+ {
1243+ source . AppendLine ( $ "mediator.PublishAsync({ itemAccess } , CancellationToken.None).GetAwaiter().GetResult();") ;
1244+ }
1245+ }
1246+
1247+ source . AppendLine ( ) ;
1248+ source . AppendLine ( "// return the desired type" ) ;
1249+ if ( returnItemIndex >= 0 )
1250+ {
1251+ source . AppendLine ( $ "return result.Item{ returnItemIndex + 1 } ;") ;
1252+ }
1253+ else
1254+ {
1255+ // No matching item found - return default (this should be caught at compile time)
1256+ source . AppendLine ( $ "return default({ expectedResponseTypeName } )!;") ;
1257+ }
1258+ }
1259+
1260+ private static List < string > ParseTupleReturnType ( string tupleType )
1261+ {
1262+ var fields = new List < string > ( ) ;
1263+
1264+ // Handle ValueTuple syntax: (Type1, Type2, Type3)
1265+ if ( tupleType . StartsWith ( "(" ) && tupleType . EndsWith ( ")" ) )
1266+ {
1267+ string content = tupleType . Substring ( 1 , tupleType . Length - 2 ) . Trim ( ) ;
1268+ fields . AddRange ( SplitTupleFields ( content ) ) ;
1269+ }
1270+ // Handle generic ValueTuple syntax: ValueTuple<Type1, Type2>
1271+ else if ( tupleType . Contains ( "ValueTuple<" ) )
1272+ {
1273+ int startIndex = tupleType . IndexOf ( '<' ) + 1 ;
1274+ int endIndex = tupleType . LastIndexOf ( '>' ) ;
1275+ if ( startIndex > 0 && endIndex > startIndex )
1276+ {
1277+ string typeArgs = tupleType . Substring ( startIndex , endIndex - startIndex ) ;
1278+ fields . AddRange ( SplitTupleFields ( typeArgs ) ) ;
1279+ }
1280+ }
1281+
1282+ return fields ;
1283+ }
1284+
1285+ private static List < string > SplitTupleFields ( string content )
1286+ {
1287+ var fields = new List < string > ( ) ;
1288+ var current = new StringBuilder ( ) ;
1289+ int depth = 0 ;
1290+
1291+ for ( int i = 0 ; i < content . Length ; i ++ )
1292+ {
1293+ char c = content [ i ] ;
1294+
1295+ if ( c == ',' && depth == 0 )
1296+ {
1297+ string field = current . ToString ( ) . Trim ( ) ;
1298+ if ( ! string . IsNullOrEmpty ( field ) )
1299+ {
1300+ // Extract just the type name, ignoring field names
1301+ string [ ] parts = field . Split ( ' ' ) ;
1302+ fields . Add ( parts [ 0 ] ) ; // First part is the type
1303+ }
1304+ current . Clear ( ) ;
1305+ }
1306+ else
1307+ {
1308+ if ( c == '<' || c == '(' ) depth ++ ;
1309+ else if ( c == '>' || c == ')' ) depth -- ;
1310+ current . Append ( c ) ;
1311+ }
1312+ }
1313+
1314+ // Add the last field
1315+ string lastField = current . ToString ( ) . Trim ( ) ;
1316+ if ( ! string . IsNullOrEmpty ( lastField ) )
1317+ {
1318+ string [ ] parts = lastField . Split ( ' ' ) ;
1319+ fields . Add ( parts [ 0 ] ) ; // First part is the type
1320+ }
1321+
1322+ return fields ;
1323+ }
1324+
1325+ private static bool IsTypeCompatible ( string fieldType , string expectedType )
1326+ {
1327+ // Direct match
1328+ if ( fieldType == expectedType )
1329+ return true ;
1330+
1331+ // Handle common namespace variations
1332+ string normalizedFieldType = NormalizeTypeName ( fieldType ) ;
1333+ string normalizedExpectedType = NormalizeTypeName ( expectedType ) ;
1334+
1335+ return normalizedFieldType == normalizedExpectedType ;
1336+ }
1337+
1338+ private static string NormalizeTypeName ( string typeName )
1339+ {
1340+ // Remove common namespace prefixes for comparison
1341+ return typeName
1342+ . Replace ( "System." , "" )
1343+ . Replace ( "global::" , "" )
1344+ . Trim ( ) ;
1345+ }
1346+
11771347 private static void GenerateHandleTupleResult ( IndentedStringBuilder source )
11781348 {
11791349 source . AppendLine ( )
0 commit comments