Skip to content

Commit 4617e75

Browse files
Merge pull request #352 from reduckted/feature/initialize-commands-code-analyzer
Created a code analyzer to detect commands that have not been initialized
2 parents cd2a7d6 + dd73829 commit 4617e75

File tree

7 files changed

+667
-0
lines changed

7 files changed

+667
-0
lines changed
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
using System;
2+
using System.Collections.Concurrent;
3+
using System.Collections.Generic;
4+
using System.Collections.Immutable;
5+
using System.Linq;
6+
using Microsoft.CodeAnalysis;
7+
using Microsoft.CodeAnalysis.CSharp;
8+
using Microsoft.CodeAnalysis.CSharp.Syntax;
9+
using Microsoft.CodeAnalysis.Diagnostics;
10+
11+
namespace Community.VisualStudio.Toolkit.Analyzers
12+
{
13+
/// <summary>
14+
/// Detects commands that have not been initialized.
15+
/// </summary>
16+
[DiagnosticAnalyzer(LanguageNames.CSharp)]
17+
public class CVST005InitializeCommandsAnalyzer : AnalyzerBase
18+
{
19+
internal const string DiagnosticId = Diagnostics.InitializeCommands;
20+
internal const string CommandNameKey = "CommandName";
21+
22+
private static readonly DiagnosticDescriptor _rule = new(
23+
DiagnosticId,
24+
GetLocalizableString(nameof(Resources.CVST005_Title)),
25+
GetLocalizableString(nameof(Resources.CVST005_MessageFormat)),
26+
"Usage",
27+
DiagnosticSeverity.Warning,
28+
isEnabledByDefault: true,
29+
description: GetLocalizableString(nameof(Resources.CVST005_Description)));
30+
31+
public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } = ImmutableArray.Create(_rule);
32+
33+
public override void Initialize(AnalysisContext context)
34+
{
35+
context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);
36+
context.EnableConcurrentExecution();
37+
38+
context.RegisterCompilationStartAction(OnCompilationStart);
39+
}
40+
41+
private static void OnCompilationStart(CompilationStartAnalysisContext context)
42+
{
43+
INamedTypeSymbol asyncPackageType;
44+
INamedTypeSymbol baseCommandType;
45+
46+
47+
asyncPackageType = context.Compilation.GetTypeByMetadataName(KnownTypeNames.AsyncPackage);
48+
baseCommandType = context.Compilation.GetTypeByMetadataName(KnownTypeNames.BaseCommand);
49+
50+
if ((asyncPackageType is not null) && (baseCommandType is not null))
51+
{
52+
State state = new(asyncPackageType, baseCommandType);
53+
54+
context.RegisterSyntaxNodeAction(
55+
(x) => RecordCommandsAndPackages(state, x),
56+
SyntaxKind.ClassDeclaration
57+
);
58+
59+
context.RegisterCompilationEndAction((x) =>
60+
{
61+
CheckInitialization(state, x);
62+
state.Dispose();
63+
});
64+
}
65+
}
66+
67+
private static void RecordCommandsAndPackages(State state, SyntaxNodeAnalysisContext context)
68+
{
69+
ClassDeclarationSyntax declaration = (ClassDeclarationSyntax)context.Node;
70+
INamedTypeSymbol classSymbol = context.SemanticModel.GetDeclaredSymbol(declaration, context.CancellationToken);
71+
if (classSymbol is not null)
72+
{
73+
// Abstract classes cannot be created, which means that even if this class
74+
// inherits from `BaseCommand`, it's not a command that can be registered.
75+
if (!declaration.Modifiers.Any(SyntaxKind.AbstractKeyword))
76+
{
77+
if (IsCommand(classSymbol, state.BaseCommandType))
78+
{
79+
state.Commands[classSymbol] = false;
80+
}
81+
}
82+
83+
if (classSymbol.IsAssignableTo(state.AsyncPackageType))
84+
{
85+
state.Packages.Add((declaration, context.SemanticModel));
86+
}
87+
}
88+
}
89+
90+
private static bool IsCommand(INamedTypeSymbol classType, INamedTypeSymbol baseCommandType)
91+
{
92+
INamedTypeSymbol? baseType = classType.BaseType;
93+
while (baseType is not null)
94+
{
95+
if (baseType.IsGenericType && baseType.OriginalDefinition.Equals(baseCommandType))
96+
{
97+
return true;
98+
}
99+
100+
baseType = baseType.BaseType;
101+
}
102+
103+
return false;
104+
}
105+
106+
private static void CheckInitialization(State state, CompilationAnalysisContext context)
107+
{
108+
if ((state.Commands.Count == 0) || (state.Packages.Count == 0))
109+
{
110+
return;
111+
}
112+
113+
bool initializeIndividually = false;
114+
foreach ((ClassDeclarationSyntax Class, SemanticModel SemanticModel) package in state.Packages)
115+
{
116+
switch (CheckInitializationInPackage(package.Class, package.SemanticModel, state.Commands))
117+
{
118+
case InitializationMode.Bulk:
119+
// All commands are being initialized in bulk,
120+
// so we don't need to check anything else.
121+
return;
122+
123+
case InitializationMode.Individual:
124+
// Commands are being initialized individually. If we have
125+
// to report diagnostics, we'll tell the code fix that the
126+
// uninitialized commands should be initialized individually.
127+
initializeIndividually = true;
128+
break;
129+
130+
}
131+
}
132+
133+
if (state.Commands.Any((x) => !x.Value))
134+
{
135+
// Usually there's only one package, so just
136+
// report the diagnostics in the first one.
137+
ClassDeclarationSyntax packageToReport = state.Packages
138+
.Select((x) => x.Class)
139+
.OrderBy((x) => x.Identifier.ValueText)
140+
.First();
141+
142+
MethodDeclarationSyntax? initializeAsyncMethod = FindInitializeAsyncMethod(packageToReport);
143+
144+
foreach (INamedTypeSymbol command in state.Commands.Where((x) => !x.Value).Select((x) => x.Key))
145+
{
146+
147+
// If the commands should be initialized individually,
148+
// then we need to include a property in the diagnostic
149+
// that tells the code fix what the name of the command is.
150+
ImmutableDictionary<string, string> properties =
151+
initializeIndividually ?
152+
ImmutableDictionary.CreateRange(new[] { new KeyValuePair<string, string>(CommandNameKey, command.Name) }) :
153+
ImmutableDictionary<string, string>.Empty;
154+
155+
context.ReportDiagnostic(
156+
Diagnostic.Create(
157+
_rule,
158+
(initializeAsyncMethod?.Identifier ?? packageToReport.Identifier).GetLocation(),
159+
properties,
160+
command.Name
161+
)
162+
);
163+
}
164+
}
165+
}
166+
167+
private static InitializationMode? CheckInitializationInPackage(
168+
ClassDeclarationSyntax package,
169+
SemanticModel semanticModel,
170+
ConcurrentDictionary<ISymbol, bool> commands
171+
)
172+
{
173+
InitializationMode? mode = null;
174+
175+
MethodDeclarationSyntax? initializeAsync = FindInitializeAsyncMethod(package);
176+
if (initializeAsync is not null)
177+
{
178+
foreach (StatementSyntax statement in initializeAsync.Body.Statements)
179+
{
180+
if (statement is ExpressionStatementSyntax expression)
181+
{
182+
if (expression.Expression is AwaitExpressionSyntax awaitExpression)
183+
{
184+
if (awaitExpression.Expression is InvocationExpressionSyntax invocation)
185+
{
186+
if (invocation.Expression is MemberAccessExpressionSyntax memberAccess)
187+
{
188+
if (memberAccess.Expression.IsKind(SyntaxKind.ThisExpression))
189+
{
190+
if (memberAccess.Name.Identifier.ValueText == "RegisterCommandsAsync")
191+
{
192+
// The statement is `this.RegisterCommandsAsync()`.
193+
return InitializationMode.Bulk;
194+
}
195+
}
196+
else if (memberAccess.Name.Identifier.ValueText == "InitializeAsync")
197+
{
198+
SymbolInfo symbolInfo = semanticModel.GetSymbolInfo(memberAccess.Expression);
199+
if (commands.ContainsKey(symbolInfo.Symbol))
200+
{
201+
// The statement is `Command.InitializeAsync(package)`.
202+
commands[symbolInfo.Symbol] = true;
203+
mode = InitializationMode.Individual;
204+
}
205+
}
206+
}
207+
}
208+
}
209+
}
210+
}
211+
}
212+
213+
return mode;
214+
}
215+
216+
internal static MethodDeclarationSyntax? FindInitializeAsyncMethod(ClassDeclarationSyntax package)
217+
{
218+
foreach (MethodDeclarationSyntax method in package.Members.OfType<MethodDeclarationSyntax>())
219+
{
220+
if (method.Modifiers.Any(SyntaxKind.ProtectedKeyword) && method.Modifiers.Any(SyntaxKind.OverrideKeyword))
221+
{
222+
if (string.Equals(method.Identifier.ValueText, "InitializeAsync"))
223+
{
224+
return method;
225+
}
226+
}
227+
}
228+
229+
return null;
230+
}
231+
232+
private enum InitializationMode
233+
{
234+
Individual,
235+
Bulk
236+
}
237+
238+
private class State : IDisposable
239+
{
240+
public State(INamedTypeSymbol asyncPackageType, INamedTypeSymbol baseCommandType)
241+
{
242+
AsyncPackageType = asyncPackageType;
243+
BaseCommandType = baseCommandType;
244+
}
245+
246+
public INamedTypeSymbol AsyncPackageType { get; }
247+
248+
public INamedTypeSymbol BaseCommandType { get; }
249+
250+
public ConcurrentDictionary<ISymbol, bool> Commands { get; } = new();
251+
252+
public ConcurrentBag<(ClassDeclarationSyntax Class, SemanticModel SemanticModel)> Packages { get; } = new();
253+
254+
public void Dispose()
255+
{
256+
// ConcurrentBag stores data per-thread, and thata data remains in memory
257+
// until it is removed from the bag. Prevent a memory leak by emptying the bag.
258+
while (Packages.Count > 0)
259+
{
260+
Packages.TryTake(out _);
261+
}
262+
}
263+
}
264+
}
265+
}

0 commit comments

Comments
 (0)