Skip to content

Commit 449491a

Browse files
committed
Refactoring
1 parent 402c2c2 commit 449491a

File tree

4 files changed

+162
-24
lines changed

4 files changed

+162
-24
lines changed

src/Foundatio.Mediator.SourceGenerator/DIRegistrationGenerator.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System.Text;
2+
using Foundatio.Mediator.Utility;
23

34
namespace Foundatio.Mediator;
45

@@ -55,7 +56,9 @@ public static string GenerateDIRegistration(List<HandlerInfo> handlers, List<Mid
5556
// Register the handler under all message types in its hierarchy
5657
foreach (string? messageTypeName in handler.MessageTypeHierarchy)
5758
{
58-
source.AppendLine($" services.AddKeyedSingleton<HandlerRegistration>(\"{messageTypeName}\",");
59+
// Convert compile-time format to runtime format for DI registration
60+
string runtimeTypeName = TypeNameHelper.ConvertToRuntimeTypeName(messageTypeName);
61+
source.AppendLine($" services.AddKeyedSingleton<HandlerRegistration>(\"{runtimeTypeName}\",");
5962
source.AppendLine($" new HandlerRegistration(");
6063
source.AppendLine($" \"{handler.MessageTypeName}\","); // Keep the primary message type name for identification
6164

src/Foundatio.Mediator.SourceGenerator/HandlerWrapperGenerator.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Microsoft.CodeAnalysis;
22
using System.Text;
3+
using Foundatio.Mediator.Utility;
34

45
namespace Foundatio.Mediator;
56

@@ -121,9 +122,10 @@ public static string GenerateHandlerWrapper(HandlerInfo handler, string wrapperC
121122
public static string GetWrapperClassName(HandlerInfo handler)
122123
{
123124
// Create a deterministic wrapper class name based on handler type and method
124-
string handlerTypeName = handler.HandlerTypeName.Split('.').Last().Replace("<", "_").Replace(">", "_").Replace(",", "_");
125+
// Extract the simple type name from the full type name, handling both . and + separators
126+
string handlerTypeName = TypeNameHelper.GetSimpleTypeName(handler.HandlerTypeName);
125127
string methodName = handler.MethodName;
126-
string messageTypeName = handler.MessageTypeName.Split('.').Last().Replace("<", "_").Replace(">", "_").Replace(",", "_");
128+
string messageTypeName = TypeNameHelper.GetSimpleTypeName(handler.MessageTypeName);
127129
return $"{handlerTypeName}_{methodName}_{messageTypeName}_StaticWrapper";
128130
}
129131

src/Foundatio.Mediator.SourceGenerator/MediatorImplementationGenerator.cs

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -57,21 +57,21 @@ public static string GenerateMediatorImplementation(List<HandlerInfo> handlers)
5757
source.AppendLine(" var allHandlers = new List<HandlerRegistration>();");
5858
source.AppendLine();
5959
source.AppendLine(" // Add handlers for the exact message type");
60-
source.AppendLine(" var exactHandlers = _serviceProvider.GetKeyedServices<HandlerRegistration>(messageType.FullName);");
60+
source.AppendLine(" var exactHandlers = GetHandlersForType(messageType);");
6161
source.AppendLine(" allHandlers.AddRange(exactHandlers);");
6262
source.AppendLine();
6363
source.AppendLine(" // Add handlers for all implemented interfaces");
6464
source.AppendLine(" foreach (var interfaceType in messageType.GetInterfaces())");
6565
source.AppendLine(" {");
66-
source.AppendLine(" var interfaceHandlers = _serviceProvider.GetKeyedServices<HandlerRegistration>(interfaceType.FullName);");
66+
source.AppendLine(" var interfaceHandlers = GetHandlersForType(interfaceType);");
6767
source.AppendLine(" allHandlers.AddRange(interfaceHandlers);");
6868
source.AppendLine(" }");
6969
source.AppendLine();
7070
source.AppendLine(" // Add handlers for all base classes");
7171
source.AppendLine(" var currentType = messageType.BaseType;");
7272
source.AppendLine(" while (currentType != null && currentType != typeof(object))");
7373
source.AppendLine(" {");
74-
source.AppendLine(" var baseHandlers = _serviceProvider.GetKeyedServices<HandlerRegistration>(currentType.FullName);");
74+
source.AppendLine(" var baseHandlers = GetHandlersForType(currentType);");
7575
source.AppendLine(" allHandlers.AddRange(baseHandlers);");
7676
source.AppendLine(" currentType = currentType.BaseType;");
7777
source.AppendLine(" }");
@@ -80,18 +80,26 @@ public static string GenerateMediatorImplementation(List<HandlerInfo> handlers)
8080
source.AppendLine(" }");
8181
source.AppendLine();
8282

83+
// Helper method to get handlers for a specific type
84+
source.AppendLine(" [DebuggerStepThrough]");
85+
source.AppendLine(" private IEnumerable<HandlerRegistration> GetHandlersForType(Type type)");
86+
source.AppendLine(" {");
87+
source.AppendLine(" return _serviceProvider.GetKeyedServices<HandlerRegistration>(type.FullName);");
88+
source.AppendLine(" }");
89+
source.AppendLine();
90+
8391
// Generate InvokeAsync method
8492
source.AppendLine(" public async ValueTask InvokeAsync(object message, CancellationToken cancellationToken = default)");
8593
source.AppendLine(" {");
86-
source.AppendLine(" var messageTypeName = message.GetType().FullName;");
87-
source.AppendLine(" var handlers = _serviceProvider.GetKeyedServices<HandlerRegistration>(messageTypeName);");
94+
source.AppendLine(" var messageType = message.GetType();");
95+
source.AppendLine(" var handlers = GetHandlersForType(messageType);");
8896
source.AppendLine(" var handlersList = handlers.ToList();");
8997
source.AppendLine();
9098
source.AppendLine(" if (handlersList.Count == 0)");
91-
source.AppendLine(" throw new InvalidOperationException($\"No handler found for message type {messageTypeName}\");");
99+
source.AppendLine(" throw new InvalidOperationException($\"No handler found for message type {messageType.FullName}\");");
92100
source.AppendLine();
93101
source.AppendLine(" if (handlersList.Count > 1)");
94-
source.AppendLine(" throw new InvalidOperationException($\"Multiple handlers found for message type {messageTypeName}. Use PublishAsync for multiple handlers.\");");
102+
source.AppendLine(" throw new InvalidOperationException($\"Multiple handlers found for message type {messageType.FullName}. Use PublishAsync for multiple handlers.\");");
95103
source.AppendLine();
96104
source.AppendLine(" var handler = handlersList.First();");
97105
source.AppendLine(" await handler.HandleAsync(this, message, cancellationToken, null);");
@@ -101,19 +109,19 @@ public static string GenerateMediatorImplementation(List<HandlerInfo> handlers)
101109
// Generate Invoke method (sync)
102110
source.AppendLine(" public void Invoke(object message, CancellationToken cancellationToken = default)");
103111
source.AppendLine(" {");
104-
source.AppendLine(" var messageTypeName = message.GetType().FullName;");
105-
source.AppendLine(" var handlers = _serviceProvider.GetKeyedServices<HandlerRegistration>(messageTypeName);");
112+
source.AppendLine(" var messageType = message.GetType();");
113+
source.AppendLine(" var handlers = GetHandlersForType(messageType);");
106114
source.AppendLine(" var handlersList = handlers.ToList();");
107115
source.AppendLine();
108116
source.AppendLine(" if (handlersList.Count == 0)");
109-
source.AppendLine(" throw new InvalidOperationException($\"No handler found for message type {messageTypeName}\");");
117+
source.AppendLine(" throw new InvalidOperationException($\"No handler found for message type {messageType.FullName}\");");
110118
source.AppendLine();
111119
source.AppendLine(" if (handlersList.Count > 1)");
112-
source.AppendLine(" throw new InvalidOperationException($\"Multiple handlers found for message type {messageTypeName}. Use Publish for multiple handlers.\");");
120+
source.AppendLine(" throw new InvalidOperationException($\"Multiple handlers found for message type {messageType.FullName}. Use Publish for multiple handlers.\");");
113121
source.AppendLine();
114122
source.AppendLine(" var handler = handlersList.First();");
115123
source.AppendLine(" if (handler.IsAsync)");
116-
source.AppendLine(" throw new InvalidOperationException($\"Cannot use synchronous Invoke with async-only handler for message type {messageTypeName}. Use InvokeAsync instead.\");");
124+
source.AppendLine(" throw new InvalidOperationException($\"Cannot use synchronous Invoke with async-only handler for message type {messageType.FullName}. Use InvokeAsync instead.\");");
117125
source.AppendLine();
118126
source.AppendLine(" handler.Handle!(this, message, cancellationToken, null);");
119127
source.AppendLine(" }");
@@ -122,15 +130,15 @@ public static string GenerateMediatorImplementation(List<HandlerInfo> handlers)
122130
// Generate InvokeAsync<TResponse> method
123131
source.AppendLine(" public async ValueTask<TResponse> InvokeAsync<TResponse>(object message, CancellationToken cancellationToken = default)");
124132
source.AppendLine(" {");
125-
source.AppendLine(" var messageTypeName = message.GetType().FullName;");
126-
source.AppendLine(" var handlers = _serviceProvider.GetKeyedServices<HandlerRegistration>(messageTypeName);");
133+
source.AppendLine(" var messageType = message.GetType();");
134+
source.AppendLine(" var handlers = GetHandlersForType(messageType);");
127135
source.AppendLine(" var handlersList = handlers.ToList();");
128136
source.AppendLine();
129137
source.AppendLine(" if (handlersList.Count == 0)");
130-
source.AppendLine(" throw new InvalidOperationException($\"No handler found for message type {messageTypeName}\");");
138+
source.AppendLine(" throw new InvalidOperationException($\"No handler found for message type {messageType.FullName}\");");
131139
source.AppendLine();
132140
source.AppendLine(" if (handlersList.Count > 1)");
133-
source.AppendLine(" throw new InvalidOperationException($\"Multiple handlers found for message type {messageTypeName}. Use PublishAsync for multiple handlers.\");");
141+
source.AppendLine(" throw new InvalidOperationException($\"Multiple handlers found for message type {messageType.FullName}. Use PublishAsync for multiple handlers.\");");
134142
source.AppendLine();
135143
source.AppendLine(" var handler = handlersList.First();");
136144
source.AppendLine(" var result = await handler.HandleAsync(this, message, cancellationToken, typeof(TResponse));");
@@ -142,19 +150,19 @@ public static string GenerateMediatorImplementation(List<HandlerInfo> handlers)
142150
// Generate Invoke<TResponse> method (sync)
143151
source.AppendLine(" public TResponse Invoke<TResponse>(object message, CancellationToken cancellationToken = default)");
144152
source.AppendLine(" {");
145-
source.AppendLine(" var messageTypeName = message.GetType().FullName;");
146-
source.AppendLine(" var handlers = _serviceProvider.GetKeyedServices<HandlerRegistration>(messageTypeName);");
153+
source.AppendLine(" var messageType = message.GetType();");
154+
source.AppendLine(" var handlers = GetHandlersForType(messageType);");
147155
source.AppendLine(" var handlersList = handlers.ToList();");
148156
source.AppendLine();
149157
source.AppendLine(" if (handlersList.Count == 0)");
150-
source.AppendLine(" throw new InvalidOperationException($\"No handler found for message type {messageTypeName}\");");
158+
source.AppendLine(" throw new InvalidOperationException($\"No handler found for message type {messageType.FullName}\");");
151159
source.AppendLine();
152160
source.AppendLine(" if (handlersList.Count > 1)");
153-
source.AppendLine(" throw new InvalidOperationException($\"Multiple handlers found for message type {messageTypeName}. Use Publish for multiple handlers.\");");
161+
source.AppendLine(" throw new InvalidOperationException($\"Multiple handlers found for message type {messageType.FullName}. Use Publish for multiple handlers.\");");
154162
source.AppendLine();
155163
source.AppendLine(" var handler = handlersList.First();");
156164
source.AppendLine(" if (handler.IsAsync)");
157-
source.AppendLine(" throw new InvalidOperationException($\"Cannot use synchronous Invoke with async-only handler for message type {messageTypeName}. Use InvokeAsync instead.\");");
165+
source.AppendLine(" throw new InvalidOperationException($\"Cannot use synchronous Invoke with async-only handler for message type {messageType.FullName}. Use InvokeAsync instead.\");");
158166
source.AppendLine();
159167
source.AppendLine(" object result = handler.Handle!(this, message, cancellationToken, typeof(TResponse));");
160168
source.AppendLine(" return (TResponse)result;");
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
using Microsoft.CodeAnalysis;
2+
using System.Text;
3+
4+
namespace Foundatio.Mediator.Utility;
5+
6+
internal static class TypeNameHelper
7+
{
8+
/// <summary>
9+
/// Gets the runtime-compatible type name from an ITypeSymbol.
10+
/// This builds the type name properly using the Roslyn API to distinguish between
11+
/// namespace separators (.) and nested type separators (+).
12+
/// </summary>
13+
/// <param name="typeSymbol">The type symbol to get the name for</param>
14+
/// <returns>Type name compatible with Type.FullName at runtime</returns>
15+
public static string GetRuntimeTypeName(ITypeSymbol typeSymbol)
16+
{
17+
// Use the proper Roslyn way to build runtime type names
18+
var parts = new List<string>();
19+
20+
// Build the type hierarchy from innermost to outermost
21+
var currentType = typeSymbol;
22+
while (currentType != null)
23+
{
24+
// Get the type name without namespace
25+
string typeName = currentType.Name;
26+
27+
// Handle generic types
28+
if (currentType is INamedTypeSymbol namedType && namedType.TypeArguments.Length > 0)
29+
{
30+
typeName += "`" + namedType.TypeArguments.Length;
31+
}
32+
33+
parts.Insert(0, typeName);
34+
35+
// Move to containing type (for nested types)
36+
currentType = currentType.ContainingType;
37+
}
38+
39+
// Build the full type name
40+
string namespaceName = typeSymbol.ContainingNamespace?.ToDisplayString();
41+
if (!string.IsNullOrEmpty(namespaceName) && namespaceName != "<global namespace>")
42+
{
43+
// Namespace parts use dots
44+
string typePart = string.Join("+", parts); // Nested types use +
45+
return namespaceName + "." + typePart;
46+
}
47+
else
48+
{
49+
// No namespace, just the type parts
50+
return string.Join("+", parts);
51+
}
52+
}
53+
54+
/// <summary>
55+
/// For string-based type names from ToDisplayString(), convert nested type separators.
56+
/// This handles the case where ITypeSymbol.ToDisplayString() uses dots for nested types
57+
/// but Type.FullName uses + for nested types.
58+
/// </summary>
59+
/// <param name="displayTypeName">Type name from ITypeSymbol.ToDisplayString()</param>
60+
/// <returns>Type name compatible with Type.FullName for runtime lookup</returns>
61+
public static string ConvertToRuntimeTypeName(string displayTypeName)
62+
{
63+
// For most cases, ITypeSymbol.ToDisplayString() produces the correct format
64+
// The main exception is nested types where dots should become +
65+
66+
// Conservative heuristic: only convert if we have at least 5 parts suggesting deep nesting
67+
// like "Namespace1.Namespace2.Namespace3.OuterClass.InnerClass"
68+
var parts = displayTypeName.Split('.');
69+
if (parts.Length < 5)
70+
{
71+
// For simple cases, assume no nested types and return as-is
72+
return displayTypeName;
73+
}
74+
75+
// For complex cases, apply the nested type logic
76+
int lastDotIndex = displayTypeName.LastIndexOf('.');
77+
if (lastDotIndex <= 0 || lastDotIndex == displayTypeName.Length - 1)
78+
{
79+
return displayTypeName;
80+
}
81+
82+
// Check if the part after the last dot looks like a type name (starts with uppercase)
83+
string lastPart = displayTypeName.Substring(lastDotIndex + 1);
84+
if (lastPart.Length > 0 && char.IsUpper(lastPart[0]))
85+
{
86+
// Check if the part before the last dot also looks like it could contain a type name
87+
string beforeLastDot = displayTypeName.Substring(0, lastDotIndex);
88+
int secondLastDotIndex = beforeLastDot.LastIndexOf('.');
89+
90+
if (secondLastDotIndex >= 0)
91+
{
92+
string potentialTypeName = beforeLastDot.Substring(secondLastDotIndex + 1);
93+
// If the potential type name starts with uppercase, treat the last dot as a nested type separator
94+
if (potentialTypeName.Length > 0 && char.IsUpper(potentialTypeName[0]))
95+
{
96+
return beforeLastDot + "+" + lastPart;
97+
}
98+
}
99+
}
100+
101+
// No nested type pattern detected, return as-is
102+
return displayTypeName;
103+
}
104+
105+
/// <summary>
106+
/// Gets the simple type name from a full type name, handling both . and + separators.
107+
/// This is useful for generating clean class names in code generation.
108+
/// </summary>
109+
/// <param name="fullTypeName">The full type name including namespace and nested type separators</param>
110+
/// <returns>Simple type name suitable for use as a class name</returns>
111+
public static string GetSimpleTypeName(string fullTypeName)
112+
{
113+
// Get the last part of the type name, handling both . and + separators
114+
int lastDotIndex = fullTypeName.LastIndexOf('.');
115+
int lastPlusIndex = fullTypeName.LastIndexOf('+');
116+
int lastSeparatorIndex = Math.Max(lastDotIndex, lastPlusIndex);
117+
118+
string simpleName = lastSeparatorIndex >= 0
119+
? fullTypeName.Substring(lastSeparatorIndex + 1)
120+
: fullTypeName;
121+
122+
// Clean up the name for use as a class name
123+
return simpleName.Replace("<", "_").Replace(">", "_").Replace(",", "_").Replace("+", "_");
124+
}
125+
}

0 commit comments

Comments
 (0)