Skip to content

Commit fe3a243

Browse files
Add two more cases for 'simplify linq expression'. (#76413)
Fixes #74968. Finds and fixes `.Where(a => a is Type).Cast<Type>()` and `.Where(a => a is Type).Select(b => (Type)b)` and converts both to `.OfType<Type>()`
2 parents e42a7b6 + 2f3531e commit fe3a243

File tree

15 files changed

+729
-10
lines changed

15 files changed

+729
-10
lines changed

src/Analyzers/CSharp/Analyzers/CSharpAnalyzers.projitems

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
<Compile Include="$(MSBuildThisFileDirectory)SimplifyInterpolation\CSharpSimplifyInterpolationHelpers.cs" />
7474
<Compile Include="$(MSBuildThisFileDirectory)SimplifyInterpolation\CSharpSimplifyInterpolationDiagnosticAnalyzer.cs" />
7575
<Compile Include="$(MSBuildThisFileDirectory)SimplifyLinqExpression\CSharpSimplifyLinqExpressionDiagnosticAnalyzer.cs" />
76+
<Compile Include="$(MSBuildThisFileDirectory)SimplifyLinqExpression\CSharpSimplifyLinqTypeCheckAndCastDiagnosticAnalyzer.cs" />
7677
<Compile Include="$(MSBuildThisFileDirectory)SimplifyPropertyPattern\CSharpSimplifyPropertyPatternDiagnosticAnalyzer.cs" />
7778
<Compile Include="$(MSBuildThisFileDirectory)SimplifyPropertyPattern\SimplifyPropertyPatternHelpers.cs" />
7879
<Compile Include="$(MSBuildThisFileDirectory)ConvertProgram\ConvertToProgramMainDiagnosticAnalyzer.cs" />
Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Diagnostics.CodeAnalysis;
7+
using System.Linq;
8+
using System.Threading;
9+
using Microsoft.CodeAnalysis.CodeStyle;
10+
using Microsoft.CodeAnalysis.CSharp.Extensions;
11+
using Microsoft.CodeAnalysis.CSharp.Syntax;
12+
using Microsoft.CodeAnalysis.Diagnostics;
13+
14+
namespace Microsoft.CodeAnalysis.CSharp.SimplifyLinqExpression;
15+
16+
[DiagnosticAnalyzer(LanguageNames.CSharp)]
17+
internal sealed class CSharpSimplifyLinqTypeCheckAndCastDiagnosticAnalyzer()
18+
: AbstractBuiltInCodeStyleDiagnosticAnalyzer(
19+
IDEDiagnosticIds.SimplifyLinqTypeCheckAndCastDiagnosticId,
20+
EnforceOnBuildValues.SimplifyLinqExpression,
21+
option: null,
22+
title: new LocalizableResourceString(nameof(AnalyzersResources.Simplify_LINQ_expression), AnalyzersResources.ResourceManager, typeof(AnalyzersResources)))
23+
{
24+
public override DiagnosticAnalyzerCategory GetAnalyzerCategory()
25+
=> DiagnosticAnalyzerCategory.SemanticSpanAnalysis;
26+
27+
protected override void InitializeWorker(AnalysisContext context)
28+
{
29+
context.RegisterCompilationStartAction(context =>
30+
{
31+
var enumerableType = context.Compilation.GetTypeByMetadataName(typeof(Enumerable).FullName!);
32+
if (enumerableType is null)
33+
return;
34+
35+
context.RegisterSyntaxNodeAction(context => AnalyzeInvocationExpression(context, enumerableType), SyntaxKind.InvocationExpression);
36+
});
37+
}
38+
39+
private static bool TryGetSingleLambdaParameter(
40+
LambdaExpressionSyntax lambda,
41+
[NotNullWhen(true)] out ParameterSyntax? lambdaParameter)
42+
{
43+
lambdaParameter = null;
44+
var whereParameters = lambda switch
45+
{
46+
ParenthesizedLambdaExpressionSyntax parenthesizedLambda => parenthesizedLambda.ParameterList.Parameters,
47+
SimpleLambdaExpressionSyntax simpleLambda => [simpleLambda.Parameter],
48+
_ => [],
49+
};
50+
51+
if (whereParameters is not [var parameter])
52+
return false;
53+
54+
lambdaParameter = parameter;
55+
return true;
56+
}
57+
58+
private static bool AnalyzeWhereMethod(
59+
SemanticModel semanticModel,
60+
LambdaExpressionSyntax whereLambda,
61+
CancellationToken cancellationToken,
62+
[NotNullWhen(true)] out ITypeSymbol? whereType)
63+
{
64+
whereType = null;
65+
66+
// has to look like `a => a is ...` or `(T a) => a is ...`
67+
if (!TryGetSingleLambdaParameter(whereLambda, out var parameter))
68+
return false;
69+
70+
// Body needs to be `a is SomeType`
71+
var parameterName = parameter.Identifier.ValueText;
72+
if (whereLambda.Body is not BinaryExpressionSyntax(kind: SyntaxKind.IsExpression)
73+
{
74+
Left: IdentifierNameSyntax leftIdentifier,
75+
Right: TypeSyntax whereTypeSyntax
76+
})
77+
{
78+
return false;
79+
}
80+
81+
// Value being checked needs to be the parameter passed in.
82+
if (leftIdentifier.Identifier.ValueText != parameterName)
83+
return false;
84+
85+
whereType = semanticModel.GetTypeInfo(whereTypeSyntax, cancellationToken).Type;
86+
return whereType != null;
87+
}
88+
89+
private bool AnalyzeInvocationExpression(
90+
InvocationExpressionSyntax invocationExpression,
91+
[NotNullWhen(true)] out LambdaExpressionSyntax? whereLambda,
92+
[NotNullWhen(true)] out InvocationExpressionSyntax? whereInvocation,
93+
[NotNullWhen(true)] out SimpleNameSyntax? caseOrSelectName,
94+
[NotNullWhen(true)] out TypeSyntax? caseOrSelectType)
95+
{
96+
whereLambda = null;
97+
whereInvocation = null;
98+
caseOrSelectName = null;
99+
caseOrSelectType = null;
100+
101+
// Both forms need to be accessed off of `.Where(... => ...)`
102+
// Needs to look like `.Where(...).Cast<...>()`
103+
if (invocationExpression is not
104+
{
105+
Expression: MemberAccessExpressionSyntax
106+
{
107+
Expression: InvocationExpressionSyntax
108+
{
109+
// Needs to be `.Where(... => ...)`
110+
ArgumentList.Arguments: [{ Expression: LambdaExpressionSyntax whereLambda1 }],
111+
Expression: MemberAccessExpressionSyntax
112+
{
113+
Name: IdentifierNameSyntax { Identifier.ValueText: nameof(Enumerable.Where) },
114+
},
115+
} whereInvocation1,
116+
},
117+
})
118+
{
119+
return false;
120+
}
121+
122+
whereLambda = whereLambda1;
123+
whereInvocation = whereInvocation1;
124+
125+
if (invocationExpression is
126+
{
127+
// Needs to be `.Cast<T>()`
128+
ArgumentList.Arguments: [],
129+
Expression: MemberAccessExpressionSyntax
130+
{
131+
Name: GenericNameSyntax
132+
{
133+
Identifier.ValueText: nameof(Enumerable.Cast),
134+
TypeArgumentList.Arguments: [var castTypeArgument]
135+
} castName,
136+
},
137+
})
138+
{
139+
caseOrSelectName = castName;
140+
caseOrSelectType = castTypeArgument;
141+
return true;
142+
}
143+
144+
// Needs to be `.Select(a => (T)a)`
145+
if (invocationExpression is
146+
{
147+
ArgumentList.Arguments: [
148+
{
149+
// a => (T)a
150+
Expression: LambdaExpressionSyntax
151+
{
152+
ExpressionBody: CastExpressionSyntax
153+
{
154+
Type: var lambdaCastType,
155+
Expression: IdentifierNameSyntax castIdentifier,
156+
} lambdaCast,
157+
} selectLambda
158+
}],
159+
Expression: MemberAccessExpressionSyntax
160+
{
161+
Name: IdentifierNameSyntax
162+
{
163+
Identifier.ValueText: nameof(Enumerable.Select),
164+
} selectName,
165+
},
166+
} && TryGetSingleLambdaParameter(selectLambda, out var selectLambdaParameter) &&
167+
selectLambdaParameter.Identifier.ValueText == castIdentifier.Identifier.ValueText)
168+
{
169+
caseOrSelectName = selectName;
170+
caseOrSelectType = lambdaCastType;
171+
return true;
172+
}
173+
174+
return false;
175+
}
176+
177+
private void AnalyzeInvocationExpression(
178+
SyntaxNodeAnalysisContext context, INamedTypeSymbol enumerableType)
179+
{
180+
var cancellationToken = context.CancellationToken;
181+
var semanticModel = context.SemanticModel;
182+
183+
if (ShouldSkipAnalysis(context, notification: null))
184+
return;
185+
186+
var invocationExpression = (InvocationExpressionSyntax)context.Node;
187+
188+
if (!AnalyzeInvocationExpression(invocationExpression,
189+
out var whereLambda,
190+
out var whereInvocation,
191+
out var castOrSelectName,
192+
out var castTypeArgument))
193+
{
194+
return;
195+
}
196+
197+
if (!AnalyzeWhereMethod(semanticModel, whereLambda, cancellationToken, out var whereType))
198+
return;
199+
200+
// Ensure the `is SomeType` and `Cast<SomeType>` are the same type.
201+
var castType = semanticModel.GetTypeInfo(castTypeArgument, cancellationToken).Type;
202+
if (castType is null)
203+
return;
204+
205+
if (!whereType.Equals(castType))
206+
return;
207+
208+
var castOrSelectSymbol = semanticModel.GetSymbolInfo(invocationExpression, cancellationToken).Symbol;
209+
var whereSymbol = semanticModel.GetSymbolInfo(whereInvocation, cancellationToken).Symbol;
210+
211+
if (!enumerableType.Equals(castOrSelectSymbol?.OriginalDefinition.ContainingType) ||
212+
!enumerableType.Equals(whereSymbol?.OriginalDefinition.ContainingType))
213+
{
214+
return;
215+
}
216+
217+
context.ReportDiagnostic(Diagnostic.Create(
218+
Descriptor,
219+
castOrSelectName.Identifier.GetLocation(),
220+
additionalLocations: [invocationExpression.GetLocation(), castTypeArgument.GetLocation()]));
221+
}
222+
}

src/Analyzers/CSharp/CodeFixes/CSharpCodeFixes.projitems

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
<Compile Include="$(MSBuildThisFileDirectory)RemoveUnnecessaryNullableDirective\CSharpRemoveUnnecessaryNullableDirectiveCodeFixProvider.cs" />
9393
<Compile Include="$(MSBuildThisFileDirectory)RemoveUnusedLocalFunction\CSharpRemoveUnusedLocalFunctionCodeFixProvider.cs" />
9494
<Compile Include="$(MSBuildThisFileDirectory)ReplaceDefaultLiteral\CSharpReplaceDefaultLiteralCodeFixProvider.cs" />
95+
<Compile Include="$(MSBuildThisFileDirectory)SimplifyLinqExpression\CSharpSimplifyLinqTypeCheckAndCastCodeFixProvider.cs" />
9596
<Compile Include="$(MSBuildThisFileDirectory)UnsealClass\CSharpUnsealClassCodeFixProvider.cs" />
9697
<Compile Include="$(MSBuildThisFileDirectory)UpdateProjectToAllowUnsafe\CSharpUpdateProjectToAllowUnsafeCodeFixProvider.cs" />
9798
<Compile Include="$(MSBuildThisFileDirectory)UpgradeProject\CSharpUpgradeProjectCodeFixProvider.cs" />
@@ -184,4 +185,7 @@
184185
<ItemGroup Condition="'$(DefaultLanguageSourceExtension)' != '' AND '$(BuildingInsideVisualStudio)' != 'true'">
185186
<ExpectedCompile Include="$(MSBuildThisFileDirectory)**\*$(DefaultLanguageSourceExtension)" />
186187
</ItemGroup>
188+
<ItemGroup>
189+
<Folder Include="$(MSBuildThisFileDirectory)SimplifyLinqExpression\" />
190+
</ItemGroup>
187191
</Project>
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Immutable;
6+
using System.Composition;
7+
using System.Diagnostics.CodeAnalysis;
8+
using System.Linq;
9+
using System.Threading;
10+
using System.Threading.Tasks;
11+
using Microsoft.CodeAnalysis.CodeFixes;
12+
using Microsoft.CodeAnalysis.CSharp.Syntax;
13+
using Microsoft.CodeAnalysis.Diagnostics;
14+
using Microsoft.CodeAnalysis.Editing;
15+
using Microsoft.CodeAnalysis.Shared.Extensions;
16+
17+
namespace Microsoft.CodeAnalysis.CSharp.SimplifyLinqExpression;
18+
19+
using static SyntaxFactory;
20+
21+
[ExportCodeFixProvider(LanguageNames.CSharp, Name = PredefinedCodeFixProviderNames.SimplifyLinqTypeCheckAndCast), Shared]
22+
[method: ImportingConstructor]
23+
[method: SuppressMessage("RoslynDiagnosticsReliability", "RS0033:Importing constructor should be [Obsolete]", Justification = "Used in test code: https://github.com/dotnet/roslyn/issues/42814")]
24+
internal sealed class CSharpSimplifyLinqTypeCheckAndCastCodeFixProvider()
25+
: SyntaxEditorBasedCodeFixProvider
26+
{
27+
public override ImmutableArray<string> FixableDiagnosticIds
28+
=> [IDEDiagnosticIds.SimplifyLinqTypeCheckAndCastDiagnosticId];
29+
30+
public override Task RegisterCodeFixesAsync(CodeFixContext context)
31+
{
32+
RegisterCodeFix(context, AnalyzersResources.Simplify_LINQ_expression, nameof(AnalyzersResources.Simplify_LINQ_expression));
33+
return Task.CompletedTask;
34+
}
35+
36+
protected override Task FixAllAsync(
37+
Document document,
38+
ImmutableArray<Diagnostic> diagnostics,
39+
SyntaxEditor editor,
40+
CancellationToken cancellationToken)
41+
{
42+
// Because the pattern is very specific (`.Where(a => a is Type).Cast<Type>()`), we know that no diagnostic can
43+
// be nested in another. So we don't have to process these inside-out like we do with other fixers.
44+
foreach (var diagnostic in diagnostics)
45+
{
46+
var castOrSelectInvocation = (InvocationExpressionSyntax)diagnostic.AdditionalLocations[0].FindNode(getInnermostNodeForTie: true, cancellationToken);
47+
var typeSyntax = (TypeSyntax)diagnostic.AdditionalLocations[1].FindNode(getInnermostNodeForTie: true, cancellationToken);
48+
49+
var castOrSelectMemberAccess = (MemberAccessExpressionSyntax)castOrSelectInvocation.Expression;
50+
var castOrSelectName = castOrSelectMemberAccess.Name;
51+
var castOrSelectNameToken = castOrSelectName.Identifier;
52+
53+
var ofTypeToken = Identifier(nameof(Enumerable.OfType)).WithTriviaFrom(castOrSelectNameToken);
54+
if (castOrSelectName is GenericNameSyntax)
55+
{
56+
// Change .Cast<T>() to .OfType<T>()
57+
editor.ReplaceNode(
58+
castOrSelectName,
59+
castOrSelectName.ReplaceToken(castOrSelectNameToken, ofTypeToken));
60+
}
61+
else
62+
{
63+
// Change .Select(...) to .OfType<T>()
64+
editor.ReplaceNode(
65+
castOrSelectName,
66+
GenericName(ofTypeToken).AddTypeArgumentListArguments(typeSyntax.WithoutTrivia()));
67+
editor.ReplaceNode(
68+
castOrSelectInvocation.ArgumentList,
69+
castOrSelectInvocation.ArgumentList.WithArguments([]));
70+
}
71+
72+
var whereInvocation = (InvocationExpressionSyntax)castOrSelectMemberAccess.Expression;
73+
var whereMemberAccess = (MemberAccessExpressionSyntax)whereInvocation.Expression;
74+
75+
// Snip out the `.Where(...)` portion so that `expr.Where(...).OfType<T>()` becomes `expr.OfType<T>()`
76+
editor.ReplaceNode(whereInvocation, whereMemberAccess.Expression);
77+
}
78+
79+
return Task.CompletedTask;
80+
}
81+
}

src/Analyzers/CSharp/Tests/CSharpAnalyzers.UnitTests.projitems

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
<Compile Include="$(MSBuildThisFileDirectory)RemoveUnusedLocalFunction\RemoveUnusedLocalFunctionTests.cs" />
9898
<Compile Include="$(MSBuildThisFileDirectory)ReplaceDefaultLiteral\ReplaceDefaultLiteralTests.cs" />
9999
<Compile Include="$(MSBuildThisFileDirectory)SimplifyLinqExpression\CSharpSimplifyLinqExpressionFixAllTests.cs" />
100+
<Compile Include="$(MSBuildThisFileDirectory)SimplifyLinqExpression\CSharpSimplifyLinqTypeCheckAndCastTests.cs" />
100101
<Compile Include="$(MSBuildThisFileDirectory)SimplifyLinqExpression\CSharpSimplifyLinqExpressionTests.cs" />
101102
<Compile Include="$(MSBuildThisFileDirectory)TransposeRecordKeyword\TransposeRecordKeywordTests.cs" />
102103
<Compile Include="$(MSBuildThisFileDirectory)UnsealClass\UnsealClassTests.cs" />

0 commit comments

Comments
 (0)