@@ -72,6 +72,8 @@ private static bool SyntacticPredicate(SyntaxNode syntaxNode, CancellationToken
7272 if ( context . TargetSymbol is not INamedTypeSymbol targetSymbol )
7373 return null ;
7474
75+ var diagnostics = new List < Diagnostic > ( ) ;
76+
7577 var fullyQualified = targetSymbol . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) ;
7678 var classNamespace = targetSymbol . ContainingNamespace . ToDisplayString ( ) ;
7779 var className = targetSymbol . Name ;
@@ -86,7 +88,7 @@ private static bool SyntacticPredicate(SyntaxNode syntaxNode, CancellationToken
8688 var propertySymbols = GetProperties ( targetSymbol , baseHashCode == null && baseEquatable == null ) ;
8789
8890 var propertyArray = propertySymbols
89- . Select ( CreateProperty )
91+ . Select ( symbol => CreateProperty ( diagnostics , symbol ) )
9092 . ToArray ( ) ?? [ ] ;
9193
9294 // the seed value of the hash code method
@@ -116,7 +118,7 @@ private static bool SyntacticPredicate(SyntaxNode syntaxNode, CancellationToken
116118 SeedHash : seedHash
117119 ) ;
118120
119- return new EquatableContext ( entity , null ) ;
121+ return new EquatableContext ( entity , diagnostics . ToArray ( ) ) ;
120122 }
121123
122124
@@ -148,7 +150,7 @@ private static IEnumerable<IPropertySymbol> GetProperties(INamedTypeSymbol targe
148150 return properties . Values ;
149151 }
150152
151- private static EquatableProperty CreateProperty ( IPropertySymbol propertySymbol )
153+ private static EquatableProperty CreateProperty ( List < Diagnostic > diagnostics , IPropertySymbol propertySymbol )
152154 {
153155 var format = SymbolDisplayFormat . FullyQualifiedFormat . WithMiscellaneousOptions ( SymbolDisplayMiscellaneousOptions . IncludeNullableReferenceTypeModifier ) ;
154156 var propertyType = propertySymbol . Type . ToDisplayString ( format ) ;
@@ -172,6 +174,17 @@ private static EquatableProperty CreateProperty(IPropertySymbol propertySymbol)
172174 if ( ! comparerType . HasValue )
173175 continue ;
174176
177+ var diagnostic = ValidateComparer ( propertySymbol , comparerType ) ;
178+ if ( diagnostic != null )
179+ {
180+ diagnostics . Add ( diagnostic ) ;
181+
182+ return new EquatableProperty (
183+ propertyName ,
184+ propertyType ,
185+ ComparerTypes . Default ) ;
186+ }
187+
175188 return new EquatableProperty (
176189 propertyName ,
177190 propertyType ,
@@ -186,6 +199,63 @@ private static EquatableProperty CreateProperty(IPropertySymbol propertySymbol)
186199 ComparerTypes . Default ) ;
187200 }
188201
202+ private static Diagnostic ? ValidateComparer ( IPropertySymbol propertySymbol , ComparerTypes ? comparerType )
203+ {
204+ // don't need to validate these types
205+ if ( comparerType is null or ComparerTypes . Default or ComparerTypes . Reference or ComparerTypes . Custom )
206+ return null ;
207+
208+ if ( comparerType == ComparerTypes . String )
209+ {
210+ if ( IsString ( propertySymbol . Type ) )
211+ return null ;
212+
213+ return Diagnostic . Create (
214+ DiagnosticDescriptors . InvalidStringEqualityAttributeUsage ,
215+ propertySymbol . Locations . FirstOrDefault ( ) ,
216+ propertySymbol . Name
217+ ) ;
218+ }
219+
220+ if ( comparerType == ComparerTypes . Dictionary )
221+ {
222+ if ( propertySymbol . Type . AllInterfaces . Any ( IsDictionary ) )
223+ return null ;
224+
225+ return Diagnostic . Create (
226+ DiagnosticDescriptors . InvalidDictionaryEqualityAttributeUsage ,
227+ propertySymbol . Locations . FirstOrDefault ( ) ,
228+ propertySymbol . Name
229+ ) ;
230+ }
231+
232+ if ( comparerType == ComparerTypes . HashSet )
233+ {
234+ if ( propertySymbol . Type . AllInterfaces . Any ( IsEnumerable ) )
235+ return null ;
236+
237+ return Diagnostic . Create (
238+ DiagnosticDescriptors . InvalidHashSetEqualityAttributeUsage ,
239+ propertySymbol . Locations . FirstOrDefault ( ) ,
240+ propertySymbol . Name
241+ ) ;
242+ }
243+
244+ if ( comparerType == ComparerTypes . Sequence )
245+ {
246+ if ( propertySymbol . Type . AllInterfaces . Any ( IsEnumerable ) )
247+ return null ;
248+
249+ return Diagnostic . Create (
250+ DiagnosticDescriptors . InvalidSequenceEqualityAttributeUsage ,
251+ propertySymbol . Locations . FirstOrDefault ( ) ,
252+ propertySymbol . Name
253+ ) ;
254+ }
255+
256+ return null ;
257+ }
258+
189259
190260 private static ( ComparerTypes ? comparerType , string ? comparerName , string ? comparerInstance ) GetComparer ( AttributeData ? attribute )
191261 {
@@ -293,6 +363,62 @@ private static bool IsValueType(INamedTypeSymbol targetSymbol)
293363 } ;
294364 }
295365
366+ private static bool IsEnumerable ( INamedTypeSymbol targetSymbol )
367+ {
368+ return targetSymbol is
369+ {
370+ Name : "IEnumerable" ,
371+ IsGenericType : true ,
372+ TypeArguments . Length : 1 ,
373+ TypeParameters . Length : 1 ,
374+ ContainingNamespace :
375+ {
376+ Name : "Generic" ,
377+ ContainingNamespace :
378+ {
379+ Name : "Collections" ,
380+ ContainingNamespace :
381+ {
382+ Name : "System"
383+ }
384+ }
385+ }
386+ } ;
387+ }
388+
389+ private static bool IsDictionary ( INamedTypeSymbol targetSymbol )
390+ {
391+ return targetSymbol is
392+ {
393+ Name : "IDictionary" ,
394+ IsGenericType : true ,
395+ TypeArguments . Length : 2 ,
396+ TypeParameters . Length : 2 ,
397+ ContainingNamespace :
398+ {
399+ Name : "Generic" ,
400+ ContainingNamespace :
401+ {
402+ Name : "Collections" ,
403+ ContainingNamespace :
404+ {
405+ Name : "System"
406+ }
407+ }
408+ }
409+ } ;
410+ }
411+
412+ private static bool IsString ( ITypeSymbol targetSymbol )
413+ {
414+ return targetSymbol is
415+ {
416+ Name : nameof ( String ) ,
417+ ContainingNamespace . Name : "System"
418+ } ;
419+ }
420+
421+
296422 private static EquatableArray < ContainingClass > GetContainingTypes ( INamedTypeSymbol targetSymbol )
297423 {
298424 if ( targetSymbol . ContainingType is null )
0 commit comments