Skip to content

Commit 86cde50

Browse files
committed
More progress
1 parent 8d44e0f commit 86cde50

File tree

1 file changed

+43
-38
lines changed

1 file changed

+43
-38
lines changed

src/Foundatio.Mediator.SourceGenerator/HandlerGenerator.cs

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ internal static class {{wrapperClassName}}
6767

6868
GenerateHandleMethod(source, handler);
6969

70-
// generate untyped handle method
70+
GenerateUntypedHandleMethod(source, handler);
7171

7272
GenerateInterceptorMethods(source, handler);
7373

@@ -93,6 +93,7 @@ private static void GenerateHandleMethod(IndentedStringBuilder source, HandlerIn
9393

9494
string asyncModifier = handler.IsAsync ? "async " : "";
9595
string result, accessor, parameters, defaultValue;
96+
bool allowNull = false;
9697
string returnType = handler.ReturnType.FullName;
9798

9899
var variables = new Dictionary<string, string>();
@@ -107,45 +108,44 @@ private static void GenerateHandleMethod(IndentedStringBuilder source, HandlerIn
107108
if (handler.ReturnType.IsTask == false && handler.IsAsync)
108109
{
109110
if (handler.ReturnType.IsVoid)
110-
returnType = "global::System.Threading.Tasks.ValueTask";
111+
returnType = "System.Threading.Tasks.ValueTask";
111112
else
112-
returnType = $"global::System.Threading.Tasks.ValueTask<global::{returnType}>";
113+
returnType = $"System.Threading.Tasks.ValueTask<{returnType}>";
113114
}
114115

115-
source.AppendLine($"public static {asyncModifier}{returnType} {stronglyTypedMethodName}(global::Foundatio.Mediator.IMediator mediator, global::{handler.MessageType.FullName} message, global::System.Threading.CancellationToken cancellationToken)")
116+
source.AppendLine($"public static {asyncModifier}{returnType} {stronglyTypedMethodName}(Foundatio.Mediator.IMediator mediator, {handler.MessageType.FullName} message, System.Threading.CancellationToken cancellationToken)")
116117
.AppendLine("{");
117118

118119
source.IncrementIndent();
119120

120-
source.AppendLine("var serviceProvider = (global::System.IServiceProvider)mediator;");
121+
source.AppendLine("var serviceProvider = (System.IServiceProvider)mediator;");
121122
variables["System.IServiceProvider"] = "serviceProvider";
122123
source.AppendLine();
123124

124125
// build middleware instances
125126
foreach (var m in handler.Middleware.Where(m => m.IsStatic == false))
126127
{
127-
source.AppendLine($"var {m.Identifier.ToCamelCase()} = global::Foundatio.Mediator.Mediator.GetOrCreateMiddleware<{m.FullName}>(serviceProvider);");
128+
source.AppendLine($"var {m.Identifier.ToCamelCase()} = Foundatio.Mediator.Mediator.GetOrCreateMiddleware<{m.FullName}>(serviceProvider);");
128129
}
129-
130-
source.AppendLine();
130+
source.AppendLineIf(handler.Middleware.Any(m => m.IsStatic == false));
131131

132132
// build result variables for before methods
133133
foreach (var m in beforeMiddleware.Where(m => m.Method.HasReturnValue))
134134
{
135-
bool allowNull = m.Method.ReturnType.IsNullable || m.Method.ReturnType.IsReferenceType;
135+
allowNull = m.Method.ReturnType.IsNullable || m.Method.ReturnType.IsReferenceType;
136136
defaultValue = allowNull ? "null" : "default";
137-
var prefix = m.Method.ReturnType.IsTuple ? "" : "global::";
138-
source.AppendLine($"{prefix}{m.Method.ReturnType.FullName}{(allowNull ? "?" : "")} {m.Middleware.Identifier.ToCamelCase()}Result = {defaultValue};");
137+
source.AppendLine($"{m.Method.ReturnType.UnwrappedFullName}{(allowNull ? "?" : "")} {m.Middleware.Identifier.ToCamelCase()}Result = {defaultValue};");
139138
}
139+
source.AppendLineIf(beforeMiddleware.Any(m => m.Method.HasReturnValue));
140140

141-
source.AppendLine();
142-
defaultValue = handler.ReturnType.IsNullable ? "null" : "default";
143-
source.AppendLine($"{handler.ReturnType.UnwrappedFullName} handlerResult = {defaultValue};");
141+
allowNull = handler.ReturnType.IsNullable || handler.ReturnType.IsReferenceType;
142+
defaultValue = handler.ReturnType.IsNullable || handler.ReturnType.IsReferenceType ? "null" : "default";
143+
source.AppendLineIf($"{handler.ReturnType.UnwrappedFullName}{(allowNull ? "?" : "")} handlerResult = {defaultValue};", handler.HasReturnValue);
144144

145145
if (shouldUseTryCatch)
146146
{
147147
source.AppendLine("""
148-
global::System.Exception? exception = null;
148+
System.Exception? exception = null;
149149
150150
try
151151
{
@@ -159,7 +159,7 @@ private static void GenerateHandleMethod(IndentedStringBuilder source, HandlerIn
159159
// call before middleware
160160
foreach (var m in beforeMiddleware)
161161
{
162-
asyncModifier = m.Middleware.IsAsync ? "await " : "";
162+
asyncModifier = m.Method.IsAsync ? "await " : "";
163163
result = m.Method.ReturnType.IsVoid ? "" : $"{m.Middleware.Identifier.ToCamelCase()}Result = ";
164164
accessor = m.Middleware.IsStatic ? m.Middleware.FullName : $"{m.Middleware.Identifier.ToCamelCase()}";
165165
parameters = BuildParameters(m.Method.Parameters);
@@ -169,8 +169,8 @@ private static void GenerateHandleMethod(IndentedStringBuilder source, HandlerIn
169169
source.AppendLineIf(beforeMiddleware.Any());
170170

171171
// call handler
172-
asyncModifier = handler.IsAsync ? "await " : "";
173-
result = handler.ReturnType.IsVoid ? "" : shouldUseTryCatch ? "handlerResult = " : "return ";
172+
asyncModifier = handler.ReturnType.IsTask ? "await " : "";
173+
result = handler.ReturnType.IsVoid ? "" : "handlerResult = ";
174174
accessor = handler.IsStatic ? handler.FullName : $"handlerInstance";
175175
parameters = BuildParameters(handler.Parameters);
176176

@@ -181,7 +181,7 @@ private static void GenerateHandleMethod(IndentedStringBuilder source, HandlerIn
181181
// call after middleware
182182
foreach (var m in afterMiddleware)
183183
{
184-
asyncModifier = m.Middleware.IsAsync ? "await " : "";
184+
asyncModifier = m.Method.IsAsync ? "await " : "";
185185
accessor = m.Middleware.IsStatic ? m.Middleware.FullName : $"{m.Middleware.Identifier.ToCamelCase()}";
186186
parameters = BuildParameters(m.Method.Parameters, variables);
187187

@@ -256,7 +256,7 @@ private static string BuildParameters(EquatableArray<ParameterInfo> parameters,
256256
}
257257
else
258258
{
259-
parameterValues.Add($"serviceProvider.GetRequiredService<global::{param.Type.FullName}>()");
259+
parameterValues.Add($"serviceProvider.GetRequiredService<{param.Type.FullName}>()");
260260
}
261261
}
262262

@@ -281,21 +281,26 @@ private static void GenerateUntypedHandleMethod(IndentedStringBuilder source, Ha
281281
source.AppendLine($"var typedMessage = ({handler.MessageType.FullName})message;");
282282

283283
string stronglyTypedMethodName = GetHandlerMethodName(handler);
284+
string asyncModifier = handler.IsAsync ? "await " : "";
285+
var result = handler.ReturnType.IsVoid ? "" : "var result = ";
284286

285-
if (!handler.ReturnType.IsVoid)
286-
{
287-
source.AppendLine($"var result = {(handler.IsAsync ? "await " : "")}{stronglyTypedMethodName}(mediator, typedMessage, cancellationToken);");
287+
source.AppendLine($"{result}{asyncModifier}{stronglyTypedMethodName}(mediator, typedMessage, cancellationToken);");
288288

289-
if (handler.ReturnType.IsTuple)
290-
{
291-
source.AppendLine("return await PublishCascadingMessagesAsync(mediator, result, responseType);");
292-
}
293-
else
294-
{
295-
GenerateNonTupleResultHandling(source, handler);
296-
}
289+
if (handler.ReturnType.IsTuple)
290+
{
291+
source.AppendLine("return await PublishCascadingMessagesAsync(mediator, result, responseType);");
292+
}
293+
else if (handler.HasReturnValue)
294+
{
295+
GenerateNonTupleResultHandling(source, handler);
296+
}
297+
else
298+
{
299+
source.AppendLine("return null;");
297300
}
298301
}
302+
303+
source.AppendLine("}");
299304
}
300305

301306
private static void GenerateNonTupleResultHandling(IndentedStringBuilder source, HandlerInfo handler)
@@ -413,7 +418,7 @@ private static void GenerateInterceptorMethod(IndentedStringBuilder source, Hand
413418

414419
// Generate method signature
415420
string returnType = GenerateInterceptorReturnType(interceptorIsAsync, isGeneric, expectedResponseTypeName);
416-
string parameters = "this global::Foundatio.Mediator.IMediator mediator, object message, global::System.Threading.CancellationToken cancellationToken = default";
421+
string parameters = "this Foundatio.Mediator.IMediator mediator, object message, System.Threading.CancellationToken cancellationToken = default";
417422
string stronglyTypedMethodName = GetHandlerMethodName(handler);
418423

419424
string asyncModifier = interceptorIsAsync ? "async " : "";
@@ -448,19 +453,19 @@ private static void GenerateInterceptorMethod(IndentedStringBuilder source, Hand
448453
private static string GenerateInterceptorAttribute(CallSiteInfo callSite)
449454
{
450455
var location = callSite.Location;
451-
return $"[global::System.Runtime.CompilerServices.InterceptsLocation({location.Version}, \"{location.Data}\")] // {location.DisplayLocation}";
456+
return $"[System.Runtime.CompilerServices.InterceptsLocation({location.Version}, \"{location.Data}\")] // {location.DisplayLocation}";
452457
}
453458

454459
private static string GenerateInterceptorReturnType(bool isAsync, bool isGeneric, string expectedResponseTypeName)
455460
{
456461
if (isGeneric)
457462
{
458463
// For generic methods, return the exact same type as the original method
459-
return isAsync ? $"global::System.Threading.Tasks.ValueTask<{expectedResponseTypeName}>" : expectedResponseTypeName;
464+
return isAsync ? $"System.Threading.Tasks.ValueTask<{expectedResponseTypeName}>" : expectedResponseTypeName;
460465
}
461466

462467
// For non-generic methods, return the exact same type as the original method
463-
return isAsync ? "global::System.Threading.Tasks.ValueTask" : "void";
468+
return isAsync ? "System.Threading.Tasks.ValueTask" : "void";
464469
}
465470

466471
/// <summary>
@@ -476,7 +481,7 @@ private static string GetSafeCastExpression(string handlerResultVar, HandlerInfo
476481
if (returnType.StartsWith("Foundatio.Mediator.Result<") && returnType != "Foundatio.Mediator.Result")
477482
{
478483
// Check if the value might be a non-generic Result that needs conversion to Result<T>
479-
return $"{handlerResultVar}.Value is global::Foundatio.Mediator.Result result ? ({returnType})result : ({returnType}?){handlerResultVar}.Value ?? default({returnType})!";
484+
return $"{handlerResultVar}.Value is Foundatio.Mediator.Result result ? ({returnType})result : ({returnType}?){handlerResultVar}.Value ?? default({returnType})!";
480485
}
481486

482487
// For reference types, provide a null-coalescing fallback to satisfy non-nullable return types
@@ -553,7 +558,7 @@ private static void GenerateGetOrCreateHandler(IndentedStringBuilder source, Han
553558
private static string GenerateShortCircuitCheck(MiddlewareInfo middleware, string resultVariableName, string hrVariableName)
554559
{
555560
if (middleware.BeforeMethod == null)
556-
return $"if ({resultVariableName} is global::Foundatio.Mediator.HandlerResult {hrVariableName} && {hrVariableName}.IsShortCircuited)";
561+
return $"if ({resultVariableName} is Foundatio.Mediator.HandlerResult {hrVariableName} && {hrVariableName}.IsShortCircuited)";
557562

558563
var methodInfo = middleware.BeforeMethod.Value;
559564

@@ -573,7 +578,7 @@ private static string GenerateShortCircuitCheck(MiddlewareInfo middleware, strin
573578
// Otherwise, fall back to pattern matching for object/object? return types
574579
else
575580
{
576-
return $"if ({resultVariableName} is global::Foundatio.Mediator.HandlerResult {hrVariableName} && {hrVariableName}.IsShortCircuited)";
581+
return $"if ({resultVariableName} is Foundatio.Mediator.HandlerResult {hrVariableName} && {hrVariableName}.IsShortCircuited)";
577582
}
578583
}
579584

0 commit comments

Comments
 (0)