|
2 | 2 | // The .NET Foundation licenses this file to you under the MIT license.
|
3 | 3 | // See the LICENSE file in the project root for more information.
|
4 | 4 |
|
5 |
| -using System.Collections.Generic; |
6 |
| -using System.Diagnostics.CodeAnalysis; |
| 5 | +using System.Collections.Immutable; |
7 | 6 | using System.Linq;
|
| 7 | +using CommunityToolkit.Mvvm.SourceGenerators.ComponentModel.Models; |
| 8 | +using CommunityToolkit.Mvvm.SourceGenerators.Diagnostics; |
| 9 | +using CommunityToolkit.Mvvm.SourceGenerators.Extensions; |
8 | 10 | using Microsoft.CodeAnalysis;
|
9 | 11 | using Microsoft.CodeAnalysis.CSharp;
|
10 | 12 | using Microsoft.CodeAnalysis.CSharp.Syntax;
|
11 |
| -using CommunityToolkit.Mvvm.SourceGenerators.Extensions; |
12 |
| -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; |
13 | 13 | using static CommunityToolkit.Mvvm.SourceGenerators.Diagnostics.DiagnosticDescriptors;
|
| 14 | +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; |
14 | 15 |
|
15 | 16 | namespace CommunityToolkit.Mvvm.SourceGenerators;
|
16 | 17 |
|
17 | 18 | /// <summary>
|
18 | 19 | /// A source generator for the <c>ObservableRecipientAttribute</c> type.
|
19 | 20 | /// </summary>
|
20 |
| -[Generator] |
21 |
| -public sealed class ObservableRecipientGenerator : TransitiveMembersGenerator |
| 21 | +[Generator(LanguageNames.CSharp)] |
| 22 | +public sealed class ObservableRecipientGenerator : TransitiveMembersGenerator2<ObservableRecipientInfo> |
22 | 23 | {
|
23 | 24 | /// <summary>
|
24 | 25 | /// Initializes a new instance of the <see cref="ObservableRecipientGenerator"/> class.
|
25 | 26 | /// </summary>
|
26 | 27 | public ObservableRecipientGenerator()
|
27 |
| - : base("CommunityToolkit.Mvvm.ComponentModel.ObservableRecipientAttribute") |
| 28 | + : base("global::CommunityToolkit.Mvvm.ComponentModel.ObservableRecipientAttribute") |
28 | 29 | {
|
29 | 30 | }
|
30 | 31 |
|
31 | 32 | /// <inheritdoc/>
|
32 |
| - protected override DiagnosticDescriptor TargetTypeErrorDescriptor => ObservableRecipientGeneratorError; |
| 33 | + protected override ObservableRecipientInfo GetInfo(INamedTypeSymbol typeSymbol, AttributeData attributeData) |
| 34 | + { |
| 35 | + string typeName = typeSymbol.Name; |
| 36 | + bool hasExplicitConstructors = !(typeSymbol.InstanceConstructors.Length == 1 && typeSymbol.InstanceConstructors[0] is { Parameters.IsEmpty: true, IsImplicitlyDeclared: true }); |
| 37 | + bool isAbstract = typeSymbol.IsAbstract; |
| 38 | + bool isObservableValidator = typeSymbol.InheritsFrom("global::CommunityToolkit.Mvvm.ComponentModel.ObservableValidator"); |
| 39 | + |
| 40 | + return new( |
| 41 | + typeName, |
| 42 | + hasExplicitConstructors, |
| 43 | + isAbstract, |
| 44 | + isObservableValidator); |
| 45 | + } |
33 | 46 |
|
34 | 47 | /// <inheritdoc/>
|
35 |
| - protected override bool ValidateTargetType( |
36 |
| - GeneratorExecutionContext context, |
37 |
| - AttributeData attributeData, |
38 |
| - ClassDeclarationSyntax classDeclaration, |
39 |
| - INamedTypeSymbol classDeclarationSymbol, |
40 |
| - [NotNullWhen(false)] out DiagnosticDescriptor? descriptor) |
| 48 | + protected override bool ValidateTargetType(INamedTypeSymbol typeSymbol, ObservableRecipientInfo info, out ImmutableArray<Diagnostic> diagnostics) |
41 | 49 | {
|
42 |
| - INamedTypeSymbol observableRecipientSymbol = context.Compilation.GetTypeByMetadataName("CommunityToolkit.Mvvm.ComponentModel.ObservableRecipient")!; |
43 |
| - INamedTypeSymbol observableObjectSymbol = context.Compilation.GetTypeByMetadataName("CommunityToolkit.Mvvm.ComponentModel.ObservableObject")!; |
44 |
| - INamedTypeSymbol observableObjectAttributeSymbol = context.Compilation.GetTypeByMetadataName("CommunityToolkit.Mvvm.ComponentModel.ObservableObjectAttribute")!; |
45 |
| - INamedTypeSymbol iNotifyPropertyChangedSymbol = context.Compilation.GetTypeByMetadataName("System.ComponentModel.INotifyPropertyChanged")!; |
| 50 | + ImmutableArray<Diagnostic>.Builder builder = ImmutableArray.CreateBuilder<Diagnostic>(); |
46 | 51 |
|
47 | 52 | // Check if the type already inherits from ObservableRecipient
|
48 |
| - if (classDeclarationSymbol.InheritsFrom(observableRecipientSymbol)) |
| 53 | + if (typeSymbol.InheritsFrom("global::CommunityToolkit.Mvvm.ComponentModel.ObservableRecipient")) |
49 | 54 | {
|
50 |
| - descriptor = DuplicateObservableRecipientError; |
| 55 | + builder.Add(DuplicateObservableRecipientError, typeSymbol, typeSymbol); |
| 56 | + |
| 57 | + diagnostics = builder.ToImmutable(); |
51 | 58 |
|
52 | 59 | return false;
|
53 | 60 | }
|
54 | 61 |
|
55 | 62 | // In order to use [ObservableRecipient], the target type needs to inherit from ObservableObject,
|
56 | 63 | // or be annotated with [ObservableObject] or [INotifyPropertyChanged] (with additional helpers).
|
57 |
| - if (!classDeclarationSymbol.InheritsFrom(observableObjectSymbol) && |
58 |
| - !classDeclarationSymbol.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, observableObjectAttributeSymbol)) && |
59 |
| - !classDeclarationSymbol.GetAttributes().Any(a => |
60 |
| - SymbolEqualityComparer.Default.Equals(a.AttributeClass, iNotifyPropertyChangedSymbol) && |
| 64 | + if (!typeSymbol.InheritsFrom("global::CommunityToolkit.Mvvm.ComponentModel.ObservableObject") && |
| 65 | + !typeSymbol.GetAttributes().Any(static a => a.AttributeClass?.HasFullyQualifiedName("global::CommunityToolkit.Mvvm.ComponentModel.ObservableObjectAttribute") == true) && |
| 66 | + !typeSymbol.GetAttributes().Any(static a => |
| 67 | + a.AttributeClass?.HasFullyQualifiedName("global::CommunityToolkit.Mvvm.ComponentModel.INotifyPropertyChangedAttribute") == true && |
61 | 68 | !a.HasNamedArgument("IncludeAdditionalHelperMethods", false)))
|
62 | 69 | {
|
63 |
| - descriptor = MissingBaseObservableObjectFunctionalityError; |
| 70 | + builder.Add(MissingBaseObservableObjectFunctionalityError, typeSymbol, typeSymbol); |
| 71 | + |
| 72 | + diagnostics = builder.ToImmutable(); |
64 | 73 |
|
65 | 74 | return false;
|
66 | 75 | }
|
67 | 76 |
|
68 |
| - descriptor = null; |
| 77 | + diagnostics = builder.ToImmutable(); |
69 | 78 |
|
70 | 79 | return true;
|
71 | 80 | }
|
72 | 81 |
|
73 | 82 | /// <inheritdoc/>
|
74 |
| - protected override IEnumerable<MemberDeclarationSyntax> FilterDeclaredMembers( |
75 |
| - GeneratorExecutionContext context, |
76 |
| - AttributeData attributeData, |
77 |
| - ClassDeclarationSyntax classDeclaration, |
78 |
| - INamedTypeSymbol classDeclarationSymbol, |
79 |
| - ClassDeclarationSyntax sourceDeclaration) |
| 83 | + protected override ImmutableArray<MemberDeclarationSyntax> FilterDeclaredMembers(ObservableRecipientInfo info, ClassDeclarationSyntax classDeclaration) |
80 | 84 | {
|
| 85 | + ImmutableArray<MemberDeclarationSyntax>.Builder builder = ImmutableArray.CreateBuilder<MemberDeclarationSyntax>(); |
| 86 | + |
81 | 87 | // If the target type has no constructors, generate constructors as well
|
82 |
| - if (classDeclarationSymbol.InstanceConstructors.Length == 1 && |
83 |
| - classDeclarationSymbol.InstanceConstructors[0] is |
84 |
| - { |
85 |
| - Parameters: { IsEmpty: true }, |
86 |
| - DeclaringSyntaxReferences: { IsEmpty: true }, |
87 |
| - IsImplicitlyDeclared: true |
88 |
| - }) |
| 88 | + if (!info.HasExplicitConstructors) |
89 | 89 | {
|
90 |
| - foreach (ConstructorDeclarationSyntax ctor in sourceDeclaration.Members.OfType<ConstructorDeclarationSyntax>()) |
| 90 | + foreach (ConstructorDeclarationSyntax ctor in classDeclaration.Members.OfType<ConstructorDeclarationSyntax>()) |
91 | 91 | {
|
92 | 92 | string text = ctor.NormalizeWhitespace().ToFullString();
|
93 |
| - string replaced = text.Replace("ObservableRecipient", classDeclarationSymbol.Name); |
| 93 | + string replaced = text.Replace("ObservableRecipient", info.TypeName); |
94 | 94 |
|
95 | 95 | // Adjust the visibility of the constructors based on whether the target type is abstract.
|
96 | 96 | // If that is not the case, the constructors have to be declared as public and not protected.
|
97 |
| - if (!classDeclarationSymbol.IsAbstract) |
| 97 | + if (!info.IsAbstract) |
98 | 98 | {
|
99 | 99 | replaced = replaced.Replace("protected", "public");
|
100 | 100 | }
|
101 | 101 |
|
102 |
| - yield return (ConstructorDeclarationSyntax)ParseMemberDeclaration(replaced)!; |
| 102 | + builder.Add((ConstructorDeclarationSyntax)ParseMemberDeclaration(replaced)!); |
103 | 103 | }
|
104 | 104 | }
|
105 | 105 |
|
106 |
| - INamedTypeSymbol observableValidatorSymbol = context.Compilation.GetTypeByMetadataName("CommunityToolkit.Mvvm.ComponentModel.ObservableValidator")!; |
107 |
| - |
108 | 106 | // Skip the SetProperty overloads if the target type inherits from ObservableValidator, to avoid conflicts
|
109 |
| - if (classDeclarationSymbol.InheritsFrom(observableValidatorSymbol)) |
| 107 | + if (info.IsObservableValidator) |
110 | 108 | {
|
111 |
| - foreach (MemberDeclarationSyntax member in sourceDeclaration.Members.Where(static member => member is not ConstructorDeclarationSyntax)) |
| 109 | + foreach (MemberDeclarationSyntax member in classDeclaration.Members.Where(static member => member is not ConstructorDeclarationSyntax)) |
112 | 110 | {
|
113 |
| - if (member is not MethodDeclarationSyntax { Identifier: { ValueText: "SetProperty" } }) |
| 111 | + if (member is not MethodDeclarationSyntax { Identifier.ValueText: "SetProperty" }) |
114 | 112 | {
|
115 |
| - yield return member; |
| 113 | + builder.Add(member); |
116 | 114 | }
|
117 | 115 | }
|
118 | 116 |
|
119 |
| - yield break; |
| 117 | + return builder.ToImmutable(); |
120 | 118 | }
|
121 | 119 |
|
122 | 120 | // If the target type has at least one custom constructor, only generate methods
|
123 |
| - foreach (MemberDeclarationSyntax member in sourceDeclaration.Members.Where(static member => member is not ConstructorDeclarationSyntax)) |
| 121 | + foreach (MemberDeclarationSyntax member in classDeclaration.Members.Where(static member => member is not ConstructorDeclarationSyntax)) |
124 | 122 | {
|
125 |
| - yield return member; |
| 123 | + builder.Add(member); |
126 | 124 | }
|
| 125 | + |
| 126 | + return builder.ToImmutable(); |
127 | 127 | }
|
128 | 128 | }
|
0 commit comments