Skip to content

Commit 3be33bc

Browse files
committed
More code gen
1 parent 42cd144 commit 3be33bc

File tree

2 files changed

+191
-21
lines changed

2 files changed

+191
-21
lines changed

benchmarks/Foundatio.Mediator.Benchmarks/Foundatio.Mediator.Benchmarks.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
<!-- Emit compiler generated files for debugging purposes -->
1111
<CompilerGeneratedFilesOutputPath>Generated</CompilerGeneratedFilesOutputPath>
1212
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
13-
<DisableMediatorInterceptors>false</DisableMediatorInterceptors>
13+
<DisableMediatorInterceptors>true</DisableMediatorInterceptors>
1414
<InterceptorsNamespaces>$(InterceptorsNamespaces);Foundatio.Mediator</InterceptorsNamespaces>
1515
</PropertyGroup>
1616

src/Foundatio.Mediator.SourceGenerator/HandlerWrapperGenerator.cs

Lines changed: 190 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Microsoft.CodeAnalysis;
22
using Foundatio.Mediator.Utility;
3+
using System.Text;
34

45
namespace Foundatio.Mediator;
56

@@ -405,7 +406,7 @@ private static void GenerateAsyncHandleMethod(IndentedStringBuilder source, Hand
405406
{
406407
// Handler returns void
407408
source.AppendLine($"await {stronglyTypedMethodName}(typedMessage, mediator.ServiceProvider, cancellationToken);")
408-
.AppendLine("return new object();");
409+
.AppendLine("return null;");
409410
}
410411
}
411412

@@ -495,7 +496,7 @@ private static void GenerateSyncHandleMethod(IndentedStringBuilder source, Handl
495496
{
496497
// Handler returns void
497498
source.AppendLine($"{stronglyTypedMethodName}(typedMessage, mediator.ServiceProvider, cancellationToken);")
498-
.AppendLine("return new object();");
499+
.AppendLine("return null;");
499500
}
500501
}
501502

@@ -581,20 +582,8 @@ private static void GenerateInterceptorMethod(IndentedStringBuilder source, Hand
581582
source.AppendLine($"var result = {(interceptorIsAsync && wrapperIsAsync ? "await " : "")}{stronglyTypedMethodName}(typedMessage, mediator.ServiceProvider, cancellationToken);")
582583
.AppendLine();
583584

584-
// For tuple types, we need to publish all items except the first (which is the return value)
585-
// This is a simplified implementation that assumes Item2 is always published
586-
// In a more complete implementation, you would parse the tuple type and generate appropriate logic
587-
if (interceptorIsAsync)
588-
{
589-
source.AppendLine("await mediator.PublishAsync(result.Item2, cancellationToken);");
590-
}
591-
else
592-
{
593-
source.AppendLine("mediator.Publish(result.Item2);");
594-
}
595-
596-
source.AppendLine()
597-
.AppendLine("return result.Item1;");
585+
// Generate optimized typed code for tuple handling
586+
GenerateOptimizedTupleHandling(source, handler, expectedResponseTypeName, interceptorIsAsync);
598587
}
599588
else
600589
{
@@ -763,7 +752,14 @@ private static void GenerateMiddlewareExecutionCore(IndentedStringBuilder source
763752
}
764753

765754
source.AppendLine();
766-
source.AppendLine(GetHandlerResultDeclaration(handler));
755+
756+
// Only generate handlerResult if it's actually needed
757+
bool needsHandlerResult = NeedsHandlerResultVariable(handler, middlewares);
758+
if (needsHandlerResult)
759+
{
760+
source.AppendLine(GetHandlerResultDeclaration(handler));
761+
}
762+
767763
source.AppendLine(GetExceptionDeclaration());
768764
source.AppendLine();
769765
source.AppendLine("try");
@@ -859,7 +855,7 @@ private static void GenerateMiddlewareExecutionCore(IndentedStringBuilder source
859855
{
860856
var methodInfo = middleware.AfterMethod;
861857
string args = String.Join(", ", methodInfo.Parameters.Select(p =>
862-
GenerateMiddlewareParameterExpression(p, middleware, resultVar)));
858+
GenerateMiddlewareParameterExpression(p, middleware, resultVar, handler)));
863859
string afterMethodCall = GenerateMiddlewareMethodCall(middleware, methodInfo, args, middlewareVariableNames[i]);
864860
if (methodInfo.IsAsync)
865861
source.AppendLine($"await {afterMethodCall};");
@@ -901,7 +897,7 @@ private static void GenerateMiddlewareExecutionCore(IndentedStringBuilder source
901897
{
902898
var methodInfo = middleware.FinallyMethod;
903899
string args = String.Join(", ", methodInfo.Parameters.Select(p =>
904-
GenerateMiddlewareParameterExpression(p, middleware, resultVar)));
900+
GenerateMiddlewareParameterExpression(p, middleware, resultVar, handler)));
905901
string finallyMethodCall = GenerateMiddlewareMethodCall(middleware, methodInfo, args, middlewareVariableNames[i]);
906902
if (methodInfo.IsAsync)
907903
source.AppendLine($"await {finallyMethodCall};");
@@ -914,6 +910,21 @@ private static void GenerateMiddlewareExecutionCore(IndentedStringBuilder source
914910
source.AppendLine("}");
915911
}
916912

913+
/// <summary>
914+
/// Determines if the handlerResult variable is needed based on handler return type and middleware usage.
915+
/// </summary>
916+
private static bool NeedsHandlerResultVariable(HandlerInfo handler, List<MiddlewareInfo> middlewares)
917+
{
918+
// For void/Task handlers, never generate handlerResult variable - pass null to middleware instead
919+
if (IsVoidReturnType(handler.OriginalReturnTypeName))
920+
{
921+
return false;
922+
}
923+
924+
// For non-void handlers, we always need handlerResult for the return statement
925+
return true;
926+
}
927+
917928
/// <summary>
918929
/// Gets the proper variable declaration for a handler result with nullable-safe initialization.
919930
/// </summary>
@@ -1052,7 +1063,7 @@ private static bool CanReturnHandlerResult(MiddlewareMethodInfo? method)
10521063
/// Generates the appropriate parameter expression for middleware methods,
10531064
/// including handling tuple field extraction from Before method results.
10541065
/// </summary>
1055-
private static string GenerateMiddlewareParameterExpression(ParameterInfo parameter, MiddlewareInfo middleware, string resultVariableName)
1066+
private static string GenerateMiddlewareParameterExpression(ParameterInfo parameter, MiddlewareInfo middleware, string resultVariableName, HandlerInfo handler)
10561067
{
10571068
if (parameter.IsMessage)
10581069
return "message";
@@ -1061,7 +1072,12 @@ private static string GenerateMiddlewareParameterExpression(ParameterInfo parame
10611072
if (parameter.Name == "beforeResult")
10621073
return resultVariableName;
10631074
if (parameter.Name == "handlerResult")
1075+
{
1076+
// For void/Task handlers, pass null instead of handlerResult variable
1077+
if (IsVoidReturnType(handler.OriginalReturnTypeName))
1078+
return "null";
10641079
return "handlerResult";
1080+
}
10651081
if (parameter.Name == "exception")
10661082
return "exception";
10671083

@@ -1174,6 +1190,160 @@ private static void AddGeneratedFileHeader(IndentedStringBuilder source)
11741190
source.AppendLine();
11751191
}
11761192

1193+
private static void GenerateOptimizedTupleHandling(IndentedStringBuilder source, HandlerInfo handler, string expectedResponseTypeName, bool isAsync)
1194+
{
1195+
// Parse the tuple return type to determine which items to return vs publish
1196+
var tupleFields = ParseTupleReturnType(handler.ReturnTypeName);
1197+
1198+
if (tupleFields.Count == 0)
1199+
{
1200+
// Fallback - shouldn't happen for tuple types
1201+
source.AppendLine($"return default({expectedResponseTypeName});");
1202+
return;
1203+
}
1204+
1205+
// Find which tuple item matches the expected response type
1206+
int returnItemIndex = -1;
1207+
var publishItems = new List<int>();
1208+
1209+
for (int i = 0; i < tupleFields.Count; i++)
1210+
{
1211+
string fieldType = tupleFields[i];
1212+
1213+
// Check if this field type matches or is assignable to the expected response type
1214+
if (IsTypeCompatible(fieldType, expectedResponseTypeName))
1215+
{
1216+
if (returnItemIndex == -1)
1217+
{
1218+
returnItemIndex = i;
1219+
}
1220+
// If we already found a return item, this becomes a publish item
1221+
else
1222+
{
1223+
publishItems.Add(i);
1224+
}
1225+
}
1226+
else
1227+
{
1228+
publishItems.Add(i);
1229+
}
1230+
}
1231+
1232+
// Generate publishing code for non-return items
1233+
source.AppendLine("// publish cascading messages");
1234+
foreach (int publishIndex in publishItems)
1235+
{
1236+
string itemAccess = $"result.Item{publishIndex + 1}";
1237+
if (isAsync)
1238+
{
1239+
source.AppendLine($"await mediator.PublishAsync({itemAccess}, cancellationToken);");
1240+
}
1241+
else
1242+
{
1243+
source.AppendLine($"mediator.PublishAsync({itemAccess}, CancellationToken.None).GetAwaiter().GetResult();");
1244+
}
1245+
}
1246+
1247+
source.AppendLine();
1248+
source.AppendLine("// return the desired type");
1249+
if (returnItemIndex >= 0)
1250+
{
1251+
source.AppendLine($"return result.Item{returnItemIndex + 1};");
1252+
}
1253+
else
1254+
{
1255+
// No matching item found - return default (this should be caught at compile time)
1256+
source.AppendLine($"return default({expectedResponseTypeName})!;");
1257+
}
1258+
}
1259+
1260+
private static List<string> ParseTupleReturnType(string tupleType)
1261+
{
1262+
var fields = new List<string>();
1263+
1264+
// Handle ValueTuple syntax: (Type1, Type2, Type3)
1265+
if (tupleType.StartsWith("(") && tupleType.EndsWith(")"))
1266+
{
1267+
string content = tupleType.Substring(1, tupleType.Length - 2).Trim();
1268+
fields.AddRange(SplitTupleFields(content));
1269+
}
1270+
// Handle generic ValueTuple syntax: ValueTuple<Type1, Type2>
1271+
else if (tupleType.Contains("ValueTuple<"))
1272+
{
1273+
int startIndex = tupleType.IndexOf('<') + 1;
1274+
int endIndex = tupleType.LastIndexOf('>');
1275+
if (startIndex > 0 && endIndex > startIndex)
1276+
{
1277+
string typeArgs = tupleType.Substring(startIndex, endIndex - startIndex);
1278+
fields.AddRange(SplitTupleFields(typeArgs));
1279+
}
1280+
}
1281+
1282+
return fields;
1283+
}
1284+
1285+
private static List<string> SplitTupleFields(string content)
1286+
{
1287+
var fields = new List<string>();
1288+
var current = new StringBuilder();
1289+
int depth = 0;
1290+
1291+
for (int i = 0; i < content.Length; i++)
1292+
{
1293+
char c = content[i];
1294+
1295+
if (c == ',' && depth == 0)
1296+
{
1297+
string field = current.ToString().Trim();
1298+
if (!string.IsNullOrEmpty(field))
1299+
{
1300+
// Extract just the type name, ignoring field names
1301+
string[] parts = field.Split(' ');
1302+
fields.Add(parts[0]); // First part is the type
1303+
}
1304+
current.Clear();
1305+
}
1306+
else
1307+
{
1308+
if (c == '<' || c == '(') depth++;
1309+
else if (c == '>' || c == ')') depth--;
1310+
current.Append(c);
1311+
}
1312+
}
1313+
1314+
// Add the last field
1315+
string lastField = current.ToString().Trim();
1316+
if (!string.IsNullOrEmpty(lastField))
1317+
{
1318+
string[] parts = lastField.Split(' ');
1319+
fields.Add(parts[0]); // First part is the type
1320+
}
1321+
1322+
return fields;
1323+
}
1324+
1325+
private static bool IsTypeCompatible(string fieldType, string expectedType)
1326+
{
1327+
// Direct match
1328+
if (fieldType == expectedType)
1329+
return true;
1330+
1331+
// Handle common namespace variations
1332+
string normalizedFieldType = NormalizeTypeName(fieldType);
1333+
string normalizedExpectedType = NormalizeTypeName(expectedType);
1334+
1335+
return normalizedFieldType == normalizedExpectedType;
1336+
}
1337+
1338+
private static string NormalizeTypeName(string typeName)
1339+
{
1340+
// Remove common namespace prefixes for comparison
1341+
return typeName
1342+
.Replace("System.", "")
1343+
.Replace("global::", "")
1344+
.Trim();
1345+
}
1346+
11771347
private static void GenerateHandleTupleResult(IndentedStringBuilder source)
11781348
{
11791349
source.AppendLine()

0 commit comments

Comments
 (0)