Skip to content

Commit 5ba929f

Browse files
committed
Improved code generation for IRecipient generator
1 parent 9184ba5 commit 5ba929f

File tree

3 files changed

+89
-57
lines changed

3 files changed

+89
-57
lines changed

Microsoft.Toolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs

Lines changed: 76 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,13 @@ public void Execute(GeneratorExecutionContext context)
6363

6464
foreach (INamedTypeSymbol classSymbol in syntaxReceiver.GatheredInfo)
6565
{
66-
// Create a static method to register all messages for a given recipient type.
66+
// Create a static factory method to register all messages for a given recipient type.
67+
// This follows the same pattern used in ObservableValidatorValidateAllPropertiesGenerator,
68+
// with the same advantages mentioned there (type safety, more AOT-friendly, etc.).
6769
// There are two versions that are generated: a non-generic one doing the registration
6870
// with no tokens, which is the most common scenario and will help particularly in AOT
6971
// scenarios, and a generic version that will support all other cases with custom tokens.
72+
// Note: the generic overload has a different name to simplify the lookup with reflection.
7073
// This code takes a class symbol and produces a compilation unit as follows:
7174
//
7275
// // Licensed to the .NET Foundation under one or more agreements.
@@ -86,17 +89,27 @@ public void Execute(GeneratorExecutionContext context)
8689
// {
8790
// [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
8891
// [global::System.Obsolete("This method is not intended to be called directly by user code")]
89-
// public static void RegisterAll(IMessenger messenger, <RECIPIENT_TYPE> recipient)
92+
// public static global::System.Action<IMessenger, object> CreateAllMessagesRegistrator(<RECIPIENT_TYPE> _)
9093
// {
91-
// <BODY>
94+
// static void RegisterAll(IMessenger messenger, <INSTANCE_TYPE> instance)
95+
// {
96+
// <BODY>
97+
// }
98+
//
99+
// return static (m, r) => RegisterAll(m, (<INSTANCE_TYPE>)r);
92100
// }
93101
//
94102
// [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
95103
// [global::System.Obsolete("This method is not intended to be called directly by user code")]
96-
// public static void RegisterAll<TToken>(IMessenger messenger, <RECIPIENT_TYPE> recipient, TToken token)
104+
// public static global::System.Action<IMessenger, object, TToken> CreateAllMessagesRegistratorWithToken<TToken>(<RECIPIENT_TYPE> _)
97105
// where TToken : global::System.IEquatable<TToken>
98106
// {
99-
// <BODY>
107+
// static void RegisterAll(IMessenger messenger, <INSTANCE_TYPE> instance, TToken token)
108+
// {
109+
// <BODY>
110+
// }
111+
//
112+
// return static (m, r, t) => RegisterAll(m, (<INSTANCE_TYPE>)r, t);
100113
// }
101114
// }
102115
// }
@@ -112,8 +125,10 @@ public void Execute(GeneratorExecutionContext context)
112125
Token(SyntaxKind.StaticKeyword),
113126
Token(SyntaxKind.PartialKeyword)).AddAttributeLists(classAttributes).AddMembers(
114127
MethodDeclaration(
115-
PredefinedType(Token(SyntaxKind.VoidKeyword)),
116-
Identifier("RegisterAll")).AddAttributeLists(
128+
GenericName("global::System.Action").AddTypeArgumentListArguments(
129+
IdentifierName("IMessenger"),
130+
PredefinedType(Token(SyntaxKind.ObjectKeyword))),
131+
Identifier("CreateAllMessagesRegistrator")).AddAttributeLists(
117132
AttributeList(SingletonSeparatedList(
118133
Attribute(IdentifierName("global::System.ComponentModel.EditorBrowsable")).AddArgumentListArguments(
119134
AttributeArgument(ParseExpression("global::System.ComponentModel.EditorBrowsableState.Never"))))),
@@ -124,12 +139,35 @@ public void Execute(GeneratorExecutionContext context)
124139
Literal("This method is not intended to be called directly by user code"))))))).AddModifiers(
125140
Token(SyntaxKind.PublicKeyword),
126141
Token(SyntaxKind.StaticKeyword)).AddParameterListParameters(
127-
Parameter(Identifier("messenger")).WithType(IdentifierName("IMessenger")),
128-
Parameter(Identifier("recipient")).WithType(IdentifierName(classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))))
129-
.WithBody(Block(EnumerateRegistrationStatements(classSymbol, iRecipientSymbol).ToArray())),
142+
Parameter(Identifier("_")).WithType(IdentifierName(classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))))
143+
.WithBody(Block(
144+
LocalFunctionStatement(
145+
PredefinedType(Token(SyntaxKind.VoidKeyword)),
146+
Identifier("RegisterAll"))
147+
.AddModifiers(Token(SyntaxKind.StaticKeyword))
148+
.AddParameterListParameters(
149+
Parameter(Identifier("messenger")).WithType(IdentifierName("IMessenger")),
150+
Parameter(Identifier("recipient")).WithType(IdentifierName(classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))))
151+
.WithBody(Block(EnumerateRegistrationStatements(classSymbol, iRecipientSymbol).ToArray())),
152+
ReturnStatement(
153+
ParenthesizedLambdaExpression()
154+
.AddModifiers(Token(SyntaxKind.StaticKeyword))
155+
.AddParameterListParameters(
156+
Parameter(Identifier("m")),
157+
Parameter(Identifier("r")))
158+
.WithExpressionBody(
159+
InvocationExpression(IdentifierName("RegisterAll"))
160+
.AddArgumentListArguments(
161+
Argument(IdentifierName("m")),
162+
Argument(CastExpression(
163+
IdentifierName(classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)),
164+
IdentifierName("r")))))))),
130165
MethodDeclaration(
131-
PredefinedType(Token(SyntaxKind.VoidKeyword)),
132-
Identifier("RegisterAll")).AddAttributeLists(
166+
GenericName("global::System.Action").AddTypeArgumentListArguments(
167+
IdentifierName("IMessenger"),
168+
PredefinedType(Token(SyntaxKind.ObjectKeyword)),
169+
IdentifierName("TToken")),
170+
Identifier("CreateAllMessagesRegistratorWithToken")).AddAttributeLists(
133171
AttributeList(SingletonSeparatedList(
134172
Attribute(IdentifierName("global::System.ComponentModel.EditorBrowsable")).AddArgumentListArguments(
135173
AttributeArgument(ParseExpression("global::System.ComponentModel.EditorBrowsableState.Never"))))),
@@ -140,14 +178,36 @@ public void Execute(GeneratorExecutionContext context)
140178
Literal("This method is not intended to be called directly by user code"))))))).AddModifiers(
141179
Token(SyntaxKind.PublicKeyword),
142180
Token(SyntaxKind.StaticKeyword)).AddParameterListParameters(
143-
Parameter(Identifier("messenger")).WithType(IdentifierName("IMessenger")),
144-
Parameter(Identifier("recipient")).WithType(IdentifierName(classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))),
145-
Parameter(Identifier("token")).WithType(IdentifierName("TToken")))
181+
Parameter(Identifier("_")).WithType(IdentifierName(classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))))
146182
.AddTypeParameterListParameters(TypeParameter("TToken"))
147183
.AddConstraintClauses(
148184
TypeParameterConstraintClause("TToken")
149185
.AddConstraints(TypeConstraint(GenericName("global::System.IEquatable").AddTypeArgumentListArguments(IdentifierName("TToken")))))
150-
.WithBody(Block(EnumerateRegistrationStatementsWithTokens(classSymbol, iRecipientSymbol).ToArray())))))
186+
.WithBody(Block(
187+
LocalFunctionStatement(
188+
PredefinedType(Token(SyntaxKind.VoidKeyword)),
189+
Identifier("RegisterAll"))
190+
.AddModifiers(Token(SyntaxKind.StaticKeyword))
191+
.AddParameterListParameters(
192+
Parameter(Identifier("messenger")).WithType(IdentifierName("IMessenger")),
193+
Parameter(Identifier("recipient")).WithType(IdentifierName(classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))),
194+
Parameter(Identifier("token")).WithType(IdentifierName("TToken")))
195+
.WithBody(Block(EnumerateRegistrationStatementsWithTokens(classSymbol, iRecipientSymbol).ToArray())),
196+
ReturnStatement(
197+
ParenthesizedLambdaExpression()
198+
.AddModifiers(Token(SyntaxKind.StaticKeyword))
199+
.AddParameterListParameters(
200+
Parameter(Identifier("m")),
201+
Parameter(Identifier("r")),
202+
Parameter(Identifier("t")))
203+
.WithExpressionBody(
204+
InvocationExpression(IdentifierName("RegisterAll"))
205+
.AddArgumentListArguments(
206+
Argument(IdentifierName("m")),
207+
Argument(CastExpression(
208+
IdentifierName(classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)),
209+
IdentifierName("r"))),
210+
Argument(IdentifierName("t"))))))))))
151211
.NormalizeWhitespace()
152212
.ToFullString();
153213

Microsoft.Toolkit.Mvvm/Messaging/IMessengerExtensions.cs

Lines changed: 9 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,13 @@ public static void RegisterAll(this IMessenger messenger, object recipient)
8686
{
8787
// We use this method as a callback for the conditional weak table, which will handle
8888
// thread-safety for us. This first callback will try to find a generated method for the
89-
// target recipient type, and just create a delegate wrapping that method if it is found.
89+
// target recipient type, and just invoke it to get the delegate to cache and use later.
9090
static Action<IMessenger, object>? LoadRegistrationMethodsForType(Type recipientType)
9191
{
9292
if (recipientType.Assembly.GetType("Microsoft.Toolkit.Mvvm.Messaging.__Internals.__IMessengerExtensions") is Type extensionsType &&
93-
extensionsType.GetMethod("RegisterAll", new[] { typeof(IMessenger), recipientType }) is MethodInfo methodInfo)
93+
extensionsType.GetMethod("CreateAllMessagesRegistrator", new[] { recipientType }) is MethodInfo methodInfo)
9494
{
95-
Type delegateType = typeof(Action<,>).MakeGenericType(typeof(IMessenger), recipientType);
96-
97-
// Create the delegate and use an unsafe cast to achieve input covariance (as detailed below)
98-
return Unsafe.As<Action<IMessenger, object>>(methodInfo.CreateDelegate(delegateType));
95+
return (Action<IMessenger, object>)methodInfo.Invoke(null, new object?[] { null })!;
9996
}
10097

10198
return null;
@@ -135,49 +132,21 @@ public static void RegisterAll<TToken>(this IMessenger messenger, object recipie
135132
{
136133
// We use this method as a callback for the conditional weak table, which will handle
137134
// thread-safety for us. This first callback will try to find a generated method for the
138-
// target recipient type, and just create a delegate wrapping that method if it is found.
135+
// target recipient type, and just invoke it to get the delegate to cache and use later.
136+
// In this case we also need to create a generic instantiation of the target method first.
139137
static Action<IMessenger, object, TToken> LoadRegistrationMethodsForType(Type recipientType)
140138
{
141-
if (recipientType.Assembly.GetType("Microsoft.Toolkit.Mvvm.Messaging.__Internals.__IMessengerExtensions") is Type extensionsType)
139+
if (recipientType.Assembly.GetType("Microsoft.Toolkit.Mvvm.Messaging.__Internals.__IMessengerExtensions") is Type extensionsType &&
140+
extensionsType.GetMethod("CreateAllMessagesRegistratorWithToken", new[] { recipientType }) is MethodInfo methodInfo)
142141
{
143-
#if NETSTANDARD2_0
144-
// .NET Standard 2.0 doesn't have Type.MakeGenericMethodParameter, so we need to iterate manually
145-
foreach (MethodInfo methodInfo in extensionsType.GetMethods(BindingFlags.Static | BindingFlags.Public))
146-
{
147-
if (methodInfo.Name is "RegisterAll" &&
148-
methodInfo.IsGenericMethod &&
149-
methodInfo.GetParameters()[1].ParameterType == recipientType)
150-
{
151-
return CreateGenericDelegate(recipientType, methodInfo);
152-
}
153-
}
154-
#else
155-
// On .NET Standard 2.1 and up, we can directly look for the target method in one call
156-
Type[] methodTypes = new[] { typeof(IMessenger), recipientType, Type.MakeGenericMethodParameter(0) };
142+
MethodInfo genericMethodInfo = methodInfo.MakeGenericMethod(typeof(TToken));
157143

158-
if (extensionsType.GetMethod("RegisterAll", methodTypes) is MethodInfo methodInfo)
159-
{
160-
return CreateGenericDelegate(recipientType, methodInfo);
161-
}
162-
#endif
144+
return (Action<IMessenger, object, TToken>)genericMethodInfo.Invoke(null, new object?[] { null })!;
163145
}
164146

165147
return LoadRegistrationMethodsForTypeFallback(recipientType);
166148
}
167149

168-
// A shared method to create a generic delegate from an identified method
169-
static Action<IMessenger, object, TToken> CreateGenericDelegate(Type recipientType, MethodInfo methodInfo)
170-
{
171-
MethodInfo genericMethodInfo = methodInfo.MakeGenericMethod(typeof(TToken));
172-
Type delegateType = typeof(Action<,,>).MakeGenericType(typeof(IMessenger), recipientType, typeof(TToken));
173-
174-
// We need an unsafe cast here like we did in StrongReferenceMessenger to be able to treat the new delegate
175-
// type as if it was covariant in its input recipient. This allows us to keep the type-specific overloads in
176-
// the generated code while still creating non-generic delegates here. This code is technically safe since
177-
// we have control over what types we're working with, and we know the type conversions will always be valid.
178-
return Unsafe.As<Action<IMessenger, object, TToken>>(genericMethodInfo.CreateDelegate(delegateType));
179-
}
180-
181150
// Fallback method when a generated method is not found.
182151
// This method is only invoked once per recipient type and token type, so we're not
183152
// worried about making it super efficient, and we can use the LINQ code for clarity.

UnitTests/UnitTests.NetCore/Mvvm/Test_IRecipientGenerator.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#pragma warning disable CS0618
66

7+
using System;
78
using Microsoft.Toolkit.Mvvm.Messaging;
89
using Microsoft.VisualStudio.TestTools.UnitTesting;
910

@@ -22,7 +23,9 @@ public void Test_IRecipientGenerator_GeneratedRegistration()
2223
var messageA = new MessageA();
2324
var messageB = new MessageB();
2425

25-
Microsoft.Toolkit.Mvvm.Messaging.__Internals.__IMessengerExtensions.RegisterAll(messenger, recipient, 42);
26+
Action<IMessenger, object, int> registrator = Microsoft.Toolkit.Mvvm.Messaging.__Internals.__IMessengerExtensions.CreateAllMessagesRegistratorWithToken<int>(recipient);
27+
28+
registrator(messenger, recipient, 42);
2629

2730
Assert.IsTrue(messenger.IsRegistered<MessageA, int>(recipient, 42));
2831
Assert.IsTrue(messenger.IsRegistered<MessageB, int>(recipient, 42));

0 commit comments

Comments
 (0)