99[ DiagnosticAnalyzer ( LanguageNames . CSharp ) ]
1010public sealed class NullComparisonAnalyzer : DiagnosticAnalyzer
1111{
12- const string ExtensionsDefine = "MEDICINE_EXTENSIONS_LIB" ;
13- const string UnityObjectFqn = "global::UnityEngine.Object" ;
14-
1512 public static readonly DiagnosticDescriptor MED026 = new (
1613 id : nameof ( MED026 ) ,
1714 title : "Use faster IsNull extension method" ,
@@ -39,43 +36,49 @@ public override void Initialize(AnalysisContext context)
3936 context . EnableConcurrentExecution ( ) ;
4037
4138 context . RegisterCompilationStartAction ( startContext =>
42- {
43- if ( ! HasExtensionsSymbol ( startContext . Compilation ) )
44- return ;
45-
46- startContext . RegisterSyntaxNodeAction ( AnalyzeBinary , SyntaxKind . EqualsExpression , SyntaxKind . NotEqualsExpression ) ;
47- startContext . RegisterSyntaxNodeAction ( AnalyzeIsPattern , SyntaxKind . IsPatternExpression ) ;
48- } ) ;
39+ {
40+ if ( ! HasExtensionsSymbol ( startContext . Compilation ) )
41+ return ;
42+
43+ var knownSymbols = new KnownSymbols ( startContext . Compilation ) ;
44+ startContext . RegisterSyntaxNodeAction (
45+ syntaxContext => AnalyzeBinary ( syntaxContext , knownSymbols . UnityObject ) ,
46+ SyntaxKind . EqualsExpression ,
47+ SyntaxKind . NotEqualsExpression
48+ ) ;
49+
50+ startContext . RegisterSyntaxNodeAction (
51+ syntaxContext => AnalyzeIsPattern ( syntaxContext , knownSymbols . UnityObject ) ,
52+ SyntaxKind . IsPatternExpression
53+ ) ;
54+ }
55+ ) ;
4956 }
5057
5158 static bool HasExtensionsSymbol ( Compilation compilation )
5259 {
5360 foreach ( var tree in compilation . SyntaxTrees )
54- {
55- if ( tree . Options is not CSharpParseOptions options )
56- continue ;
57-
58- foreach ( var name in options . PreprocessorSymbolNames )
59- if ( name == ExtensionsDefine )
60- return true ;
61- }
61+ if ( tree . Options is CSharpParseOptions options )
62+ foreach ( var name in options . PreprocessorSymbolNames )
63+ if ( name is Constants . MedicineExtensionsDefine )
64+ return true ;
6265
6366 return false ;
6467 }
6568
66- static void AnalyzeBinary ( SyntaxNodeAnalysisContext context )
69+ static void AnalyzeBinary ( SyntaxNodeAnalysisContext context , INamedTypeSymbol unityObjectSymbol )
6770 {
6871 if ( context . Node is not BinaryExpressionSyntax binary )
6972 return ;
7073
71- if ( ! TryGetUnityOperand ( binary . Left , binary . Right , context . SemanticModel , context . CancellationToken , out _ ) )
74+ if ( ! TryGetUnityOperand ( binary . Left , binary . Right , context . SemanticModel , context . CancellationToken , unityObjectSymbol , out _ ) )
7275 return ;
7376
7477 var descriptor = binary . IsKind ( SyntaxKind . EqualsExpression ) ? MED026 : MED027 ;
7578 context . ReportDiagnostic ( Diagnostic . Create ( descriptor , binary . GetLocation ( ) ) ) ;
7679 }
7780
78- static void AnalyzeIsPattern ( SyntaxNodeAnalysisContext context )
81+ static void AnalyzeIsPattern ( SyntaxNodeAnalysisContext context , INamedTypeSymbol unityObjectSymbol )
7982 {
8083 if ( context . Node is not IsPatternExpressionSyntax isPattern )
8184 return ;
@@ -84,15 +87,12 @@ static void AnalyzeIsPattern(SyntaxNodeAnalysisContext context)
8487
8588 if ( IsNullPattern ( pattern ) )
8689 {
87- if ( IsUnityObjectExpression ( isPattern . Expression , context . SemanticModel , context . CancellationToken ) )
90+ if ( IsUnityObjectExpression ( isPattern . Expression , context . SemanticModel , context . CancellationToken , unityObjectSymbol ) )
8891 context . ReportDiagnostic ( Diagnostic . Create ( MED026 , isPattern . GetLocation ( ) ) ) ;
89-
90- return ;
9192 }
92-
93- if ( IsNotNullPattern ( pattern ) )
93+ else if ( IsNotNullPattern ( pattern ) )
9494 {
95- if ( IsUnityObjectExpression ( isPattern . Expression , context . SemanticModel , context . CancellationToken ) )
95+ if ( IsUnityObjectExpression ( isPattern . Expression , context . SemanticModel , context . CancellationToken , unityObjectSymbol ) )
9696 context . ReportDiagnostic ( Diagnostic . Create ( MED027 , isPattern . GetLocation ( ) ) ) ;
9797 }
9898 }
@@ -102,14 +102,15 @@ static bool TryGetUnityOperand(
102102 ExpressionSyntax right ,
103103 SemanticModel model ,
104104 CancellationToken ct ,
105+ INamedTypeSymbol unityObjectSymbol ,
105106 out ExpressionSyntax unityOperand
106107 )
107108 {
108109 unityOperand = null ! ;
109110
110111 if ( IsNullLiteral ( left ) )
111112 {
112- if ( IsUnityObjectExpression ( right , model , ct ) )
113+ if ( IsUnityObjectExpression ( right , model , ct , unityObjectSymbol ) )
113114 {
114115 unityOperand = right ;
115116 return true ;
@@ -120,7 +121,7 @@ out ExpressionSyntax unityOperand
120121
121122 if ( IsNullLiteral ( right ) )
122123 {
123- if ( ! IsUnityObjectExpression ( left , model , ct ) )
124+ if ( ! IsUnityObjectExpression ( left , model , ct , unityObjectSymbol ) )
124125 return false ;
125126
126127 unityOperand = left ;
@@ -159,27 +160,38 @@ static PatternSyntax UnwrapParenthesizedPattern(PatternSyntax pattern)
159160 return pattern ;
160161 }
161162
162- static bool IsUnityObjectExpression ( ExpressionSyntax expression , SemanticModel model , CancellationToken ct )
163+ static bool IsUnityObjectExpression (
164+ ExpressionSyntax expression ,
165+ SemanticModel model ,
166+ CancellationToken ct ,
167+ INamedTypeSymbol unityObjectSymbol
168+ )
163169 {
164170 var typeInfo = model . GetTypeInfo ( expression , ct ) ;
165171 var type = typeInfo . Type ?? typeInfo . ConvertedType ;
166- return IsUnityObjectType ( type ) ;
172+ return IsUnityObjectType ( type , unityObjectSymbol ) ;
167173 }
168174
169- static bool IsUnityObjectType ( ITypeSymbol ? type )
175+ static bool IsUnityObjectType ( ITypeSymbol ? type , INamedTypeSymbol unityObjectSymbol )
170176 {
171- if ( type is null or IErrorTypeSymbol )
172- return false ;
173-
174- if ( type is ITypeParameterSymbol typeParameter )
177+ switch ( type )
175178 {
176- foreach ( var constraint in typeParameter . ConstraintTypes )
177- if ( IsUnityObjectType ( constraint ) )
178- return true ;
179+ case null or IErrorTypeSymbol :
180+ {
181+ return false ;
182+ }
183+ case ITypeParameterSymbol typeParameter :
184+ {
185+ foreach ( var constraint in typeParameter . ConstraintTypes )
186+ if ( IsUnityObjectType ( constraint , unityObjectSymbol ) )
187+ return true ;
179188
180- return false ;
189+ return false ;
190+ }
191+ default :
192+ {
193+ return type . Is ( unityObjectSymbol ) || type . InheritsFrom ( unityObjectSymbol ) ;
194+ }
181195 }
182-
183- return type . Is ( UnityObjectFqn ) || type . InheritsFrom ( UnityObjectFqn ) ;
184196 }
185- }
197+ }
0 commit comments