Skip to content

Commit 02c87eb

Browse files
committed
Added initial support for generated validation property attributes
1 parent e4f5648 commit 02c87eb

File tree

2 files changed

+100
-10
lines changed

2 files changed

+100
-10
lines changed

Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.cs

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -165,22 +165,32 @@ private PropertyDeclarationSyntax CreatePropertyDeclaration(GeneratorExecutionCo
165165
IdentifierName("value"))),
166166
ExpressionStatement(InvocationExpression(IdentifierName("OnPropertyChanged"))));
167167

168-
INamedTypeSymbol attributeSymbol = context.Compilation.GetTypeByMetadataName(typeof(AlsoNotifyForAttribute).FullName)!;
168+
INamedTypeSymbol alsoNotifyForAttributeSymbol = context.Compilation.GetTypeByMetadataName(typeof(AlsoNotifyForAttribute).FullName)!;
169+
INamedTypeSymbol? validationAttributeSymbol = context.Compilation.GetTypeByMetadataName("System.ComponentModel.DataAnnotations.ValidationAttribute");
169170

170-
// Add dependent property notifications, if needed
171-
if (fieldSymbol.GetAttributes().FirstOrDefault(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, attributeSymbol)) is AttributeData attributeData &&
172-
attributeData.ConstructorArguments.Length == 1)
171+
List<AttributeSyntax> validationAttributes = new();
172+
173+
foreach (AttributeData attributeData in fieldSymbol.GetAttributes())
173174
{
174-
foreach (TypedConstant attributeArgument in attributeData.ConstructorArguments[0].Values)
175+
// Add dependent property notifications, if needed
176+
if (SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, alsoNotifyForAttributeSymbol))
175177
{
176-
if (attributeArgument.Value is string dependentPropertyName)
178+
foreach (TypedConstant attributeArgument in attributeData.ConstructorArguments[0].Values)
177179
{
178-
// OnPropertyChanged("OtherPropertyName");
179-
setter = setter.AddStatements(ExpressionStatement(
180-
InvocationExpression(IdentifierName("OnPropertyChanged"))
181-
.AddArgumentListArguments(Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(dependentPropertyName))))));
180+
if (attributeArgument.Value is string dependentPropertyName)
181+
{
182+
// OnPropertyChanged("OtherPropertyName");
183+
setter = setter.AddStatements(ExpressionStatement(
184+
InvocationExpression(IdentifierName("OnPropertyChanged"))
185+
.AddArgumentListArguments(Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(dependentPropertyName))))));
186+
}
182187
}
183188
}
189+
else if (validationAttributeSymbol is not null &&
190+
attributeData.AttributeClass?.InheritsFrom(validationAttributeSymbol) == true)
191+
{
192+
validationAttributes.Add(attributeData.AsAttributeSyntax());
193+
}
184194
}
185195

186196
// Construct the generated property as follows:
@@ -189,6 +199,7 @@ private PropertyDeclarationSyntax CreatePropertyDeclaration(GeneratorExecutionCo
189199
// [global::System.CodeDom.Compiler.GeneratedCode("...", "...")]
190200
// [global::System.Diagnostics.DebuggerNonUserCode]
191201
// [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
202+
// <VALIDATION_ATTRIBUTES> // Optional
192203
// public <FIELD_TYPE> <PROPERTY_NAME>
193204
// {
194205
// get => <FIELD_NAME>;
@@ -212,6 +223,7 @@ private PropertyDeclarationSyntax CreatePropertyDeclaration(GeneratorExecutionCo
212223
AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(GetType().Assembly.GetName().Version.ToString())))))),
213224
AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.DebuggerNonUserCode")))),
214225
AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage")))))
226+
.AddAttributeLists(validationAttributes.Select(static a => AttributeList(SingletonSeparatedList(a))).ToArray())
215227
.WithLeadingTrivia(leadingTrivia)
216228
.AddModifiers(Token(SyntaxKind.PublicKeyword))
217229
.AddAccessorListAccessors(

Microsoft.Toolkit.Mvvm.SourceGenerators/Extensions/AttributeDataExtensions.cs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,14 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
56
using System.Collections.Generic;
67
using System.Diagnostics.Contracts;
8+
using System.Linq;
79
using Microsoft.CodeAnalysis;
10+
using Microsoft.CodeAnalysis.CSharp;
11+
using Microsoft.CodeAnalysis.CSharp.Syntax;
12+
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
813

914
namespace Microsoft.Toolkit.Mvvm.SourceGenerators.Extensions
1015
{
@@ -36,5 +41,78 @@ properties.Value.Value is T argumentValue &&
3641

3742
return false;
3843
}
44+
45+
/// <summary>
46+
/// Creates an <see cref="AttributeSyntax"/> node that is equivalent to the input <see cref="AttributeData"/> instance.
47+
/// </summary>
48+
/// <param name="attributeData">The input <see cref="AttributeData"/> instance to process.</param>
49+
/// <returns>An <see cref="AttributeSyntax"/> replicating the data in <paramref name="attributeData"/>.</returns>
50+
[Pure]
51+
public static AttributeSyntax AsAttributeSyntax(this AttributeData attributeData)
52+
{
53+
IdentifierNameSyntax attributeType = IdentifierName(attributeData.AttributeClass!.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat));
54+
AttributeArgumentSyntax[] arguments =
55+
attributeData.ConstructorArguments
56+
.Select(static arg => AttributeArgument(ToExpression(arg))).Concat(
57+
attributeData.NamedArguments
58+
.Select(static arg =>
59+
AttributeArgument(ToExpression(arg.Value))
60+
.WithNameEquals(NameEquals(IdentifierName(arg.Key))))).ToArray();
61+
62+
return Attribute(attributeType, AttributeArgumentList(SeparatedList(SeparatedList(arguments))));
63+
64+
static ExpressionSyntax ToExpression(TypedConstant arg)
65+
{
66+
if (arg.IsNull)
67+
{
68+
return LiteralExpression(SyntaxKind.NullLiteralExpression);
69+
}
70+
71+
if (arg.Kind == TypedConstantKind.Array)
72+
{
73+
string elementType = ((IArrayTypeSymbol)arg.Type!).ElementType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
74+
75+
return
76+
ArrayCreationExpression(
77+
ArrayType(IdentifierName(elementType))
78+
.AddRankSpecifiers(ArrayRankSpecifier(SingletonSeparatedList<ExpressionSyntax>(OmittedArraySizeExpression()))))
79+
.WithInitializer(InitializerExpression(SyntaxKind.ArrayInitializerExpression)
80+
.AddExpressions(arg.Values.Select(ToExpression).ToArray()));
81+
}
82+
83+
switch ((arg.Kind, arg.Value))
84+
{
85+
case (TypedConstantKind.Primitive, string text):
86+
return LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(text));
87+
case (TypedConstantKind.Primitive, bool flag) when flag:
88+
return LiteralExpression(SyntaxKind.TrueLiteralExpression);
89+
case (TypedConstantKind.Primitive, bool):
90+
return LiteralExpression(SyntaxKind.FalseLiteralExpression);
91+
case (TypedConstantKind.Primitive, object value):
92+
return LiteralExpression(SyntaxKind.NumericLiteralExpression, value switch
93+
{
94+
byte b => Literal(b),
95+
char c => Literal(c),
96+
double d => Literal(d),
97+
float f => Literal(f),
98+
int i => Literal(i),
99+
long l => Literal(l),
100+
sbyte sb => Literal(sb),
101+
short sh => Literal(sh),
102+
uint ui => Literal(ui),
103+
ulong ul => Literal(ul),
104+
ushort ush => Literal(ush),
105+
_ => throw new ArgumentException()
106+
});
107+
case (TypedConstantKind.Type, ITypeSymbol type):
108+
return TypeOfExpression(IdentifierName(type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)));
109+
case (TypedConstantKind.Enum, object value):
110+
return CastExpression(
111+
IdentifierName(arg.Type!.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)),
112+
LiteralExpression(SyntaxKind.NumericLiteralExpression, ParseToken(value.ToString())));
113+
default: throw new ArgumentException();
114+
}
115+
}
116+
}
39117
}
40118
}

0 commit comments

Comments
 (0)