Skip to content

Commit eb0ee18

Browse files
committed
support nested types, improve code generated
1 parent f2a9dc6 commit eb0ee18

19 files changed

+342
-193
lines changed

src/Equatable.SourceGenerator/EquatableGenerator.cs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,13 @@ private static bool SyntacticPredicate(SyntaxNode syntaxNode, CancellationToken
7272
if (context.TargetSymbol is not INamedTypeSymbol targetSymbol)
7373
return null;
7474

75+
var fullyQualified = targetSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
7576
var classNamespace = targetSymbol.ContainingNamespace.ToDisplayString();
7677
var className = targetSymbol.Name;
7778

79+
// support nested types
80+
var containingTypes = GetContainingTypes(targetSymbol);
81+
7882
var baseHashCode = GetBaseHashCodeMethod(targetSymbol);
7983
var baseEquals = GetBaseEqualsMethod(targetSymbol);
8084
var baseEquatable = GetBaseEquatableType(targetSymbol);
@@ -99,8 +103,10 @@ private static bool SyntacticPredicate(SyntaxNode syntaxNode, CancellationToken
99103
seedHash = (seedHash * HashFactor) + GetFNVHashCode(property.PropertyName);
100104

101105
var entity = new EquatableClass(
106+
FullyQualified: fullyQualified,
102107
EntityNamespace: classNamespace,
103108
EntityName: className,
109+
ContainingTypes: containingTypes,
104110
Properties: propertyArray,
105111
IsRecord: targetSymbol.IsRecord,
106112
IsValueType: targetSymbol.IsValueType,
@@ -144,7 +150,8 @@ private static IEnumerable<IPropertySymbol> GetProperties(INamedTypeSymbol targe
144150

145151
private static EquatableProperty CreateProperty(IPropertySymbol propertySymbol)
146152
{
147-
var propertyType = propertySymbol.Type.ToDisplayString();
153+
var format = SymbolDisplayFormat.FullyQualifiedFormat.WithMiscellaneousOptions(SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier);
154+
var propertyType = propertySymbol.Type.ToDisplayString(format);
148155
var propertyName = propertySymbol.Name;
149156

150157
// look for custom equality
@@ -286,6 +293,31 @@ private static bool IsValueType(INamedTypeSymbol targetSymbol)
286293
};
287294
}
288295

296+
private static EquatableArray<ContainingClass> GetContainingTypes(INamedTypeSymbol targetSymbol)
297+
{
298+
if (targetSymbol.ContainingType is null)
299+
return Array.Empty<ContainingClass>();
300+
301+
var list = new List<ContainingClass>();
302+
var currentSymbol = targetSymbol.ContainingType;
303+
304+
while (currentSymbol != null)
305+
{
306+
var containingClass = new ContainingClass(
307+
EntityName: currentSymbol.Name,
308+
IsRecord: currentSymbol.IsRecord,
309+
IsValueType: currentSymbol.IsValueType
310+
);
311+
312+
list.Add(containingClass);
313+
314+
currentSymbol = currentSymbol.ContainingType;
315+
}
316+
317+
list.Reverse();
318+
319+
return list.ToArray();
320+
}
289321

290322
private static IMethodSymbol? GetBaseHashCodeMethod(INamedTypeSymbol targetSymbol)
291323
{

src/Equatable.SourceGenerator/EquatableWriter.cs

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ public static class EquatableWriter
88
{
99
public static string Generate(EquatableClass entityClass)
1010
{
11-
if (entityClass == null)
11+
if (entityClass is null)
1212
throw new ArgumentNullException(nameof(entityClass));
1313

1414
var codeBuilder = new IndentedStringBuilder();
@@ -23,6 +23,19 @@ public static string Generate(EquatableClass entityClass)
2323
.AppendLine("{")
2424
.IncrementIndent();
2525

26+
// support nested types
27+
foreach (var containingClass in entityClass.ContainingTypes)
28+
{
29+
codeBuilder
30+
.Append("partial ")
31+
.AppendIf("record ", containingClass.IsRecord)
32+
.AppendIf("class ", !containingClass.IsValueType)
33+
.AppendIf("struct ", containingClass.IsValueType)
34+
.AppendLine(containingClass.EntityName)
35+
.AppendLine("{")
36+
.IncrementIndent();
37+
}
38+
2639
codeBuilder
2740
.Append("partial ")
2841
.AppendIf("record ", entityClass.IsRecord)
@@ -34,7 +47,7 @@ public static string Generate(EquatableClass entityClass)
3447
{
3548
codeBuilder
3649
.Append(" : global::System.IEquatable<")
37-
.Append(entityClass.EntityName)
50+
.Append(entityClass.FullyQualified)
3851
.AppendIf("?", !entityClass.IsValueType)
3952
.Append(">");
4053
}
@@ -50,7 +63,17 @@ public static string Generate(EquatableClass entityClass)
5063

5164
codeBuilder
5265
.DecrementIndent()
53-
.AppendLine("}") // class
66+
.AppendLine("}"); // class
67+
68+
// support nested types
69+
foreach (var containingClass in entityClass.ContainingTypes)
70+
{
71+
codeBuilder
72+
.DecrementIndent()
73+
.AppendLine("}");
74+
}
75+
76+
codeBuilder
5477
.DecrementIndent()
5578
.AppendLine("}"); // namespace
5679

@@ -69,7 +92,7 @@ private static void GenerateEquatable(IndentedStringBuilder codeBuilder, Equatab
6992
.Append("public ")
7093
.AppendIf("virtual ", entityClass.IsRecord && !entityClass.IsSealed)
7194
.Append("bool Equals(")
72-
.Append(entityClass.EntityName)
95+
.Append(entityClass.FullyQualified)
7396
.AppendIf("?", !entityClass.IsValueType)
7497
.AppendLine(" other)")
7598
.AppendLine("{")
@@ -83,7 +106,7 @@ private static void GenerateEquatable(IndentedStringBuilder codeBuilder, Equatab
83106
}
84107
else
85108
{
86-
codeBuilder.Append("return other is not null");
109+
codeBuilder.Append("return !(other is null)");
87110
wrote = true;
88111
}
89112

@@ -195,7 +218,7 @@ private static void GenerateEquatable(IndentedStringBuilder codeBuilder, Equatab
195218

196219
private static void GenerateEquatableFunctions(IndentedStringBuilder codeBuilder, EquatableClass entityClass)
197220
{
198-
if (entityClass == null)
221+
if (entityClass is null)
199222
return;
200223

201224
if (entityClass.Properties.Any(p => p.ComparerType == ComparerTypes.Dictionary))
@@ -207,7 +230,7 @@ private static void GenerateEquatableFunctions(IndentedStringBuilder codeBuilder
207230
.AppendLine("if (global::System.Object.ReferenceEquals(left, right))")
208231
.AppendLine(" return true;")
209232
.AppendLine()
210-
.AppendLine("if (left == null || right == null)")
233+
.AppendLine("if (left is null || right is null)")
211234
.AppendLine(" return false;")
212235
.AppendLine()
213236
.AppendLine("if (left.Count != right.Count)")
@@ -244,7 +267,7 @@ private static void GenerateEquatableFunctions(IndentedStringBuilder codeBuilder
244267
.AppendLine("if (global::System.Object.ReferenceEquals(left, right))")
245268
.AppendLine(" return true;")
246269
.AppendLine()
247-
.AppendLine("if (left == null || right == null)")
270+
.AppendLine("if (left is null || right is null)")
248271
.AppendLine(" return false;")
249272
.AppendLine()
250273
.AppendLine("if (left is global::System.Collections.Generic.ISet<T> leftSet)")
@@ -269,7 +292,7 @@ private static void GenerateEquatableFunctions(IndentedStringBuilder codeBuilder
269292
.AppendLine("if (global::System.Object.ReferenceEquals(left, right))")
270293
.AppendLine(" return true;")
271294
.AppendLine()
272-
.AppendLine("if (left == null || right == null)")
295+
.AppendLine("if (left is null || right is null)")
273296
.AppendLine(" return false;")
274297
.AppendLine()
275298
.AppendLine("return global::System.Linq.Enumerable.SequenceEqual(left, right, global::System.Collections.Generic.EqualityComparer<T>.Default);")
@@ -300,14 +323,14 @@ private static void GenerateEquals(IndentedStringBuilder codeBuilder, EquatableC
300323
{
301324
codeBuilder
302325
.Append("return obj is ")
303-
.Append(entityClass.EntityName)
326+
.Append(entityClass.FullyQualified)
304327
.AppendLine(" instance && Equals(instance);");
305328
}
306329
else
307330
{
308331
codeBuilder
309332
.Append("return Equals(obj as ")
310-
.Append(entityClass.EntityName)
333+
.Append(entityClass.FullyQualified)
311334
.AppendLine(");");
312335
}
313336

@@ -324,16 +347,16 @@ private static void GenerateEquals(IndentedStringBuilder codeBuilder, EquatableC
324347
.Append(ThisAssembly.InformationalVersion)
325348
.AppendLine("\")]")
326349
.Append("public static bool operator ==(")
327-
.Append(entityClass.EntityName)
350+
.Append(entityClass.FullyQualified)
328351
.AppendIf("?", !entityClass.IsValueType)
329352
.Append(" left, ")
330-
.Append(entityClass.EntityName)
353+
.Append(entityClass.FullyQualified)
331354
.AppendIf("?", !entityClass.IsValueType)
332355
.AppendLine(" right)")
333356
.AppendLine("{")
334357
.IncrementIndent()
335358
.Append("return global::System.Collections.Generic.EqualityComparer<")
336-
.Append(entityClass.EntityName)
359+
.Append(entityClass.FullyQualified)
337360
.AppendIf("?", !entityClass.IsValueType)
338361
.AppendLine(">.Default.Equals(left, right);")
339362
.DecrementIndent()
@@ -348,10 +371,10 @@ private static void GenerateEquals(IndentedStringBuilder codeBuilder, EquatableC
348371
.Append(ThisAssembly.InformationalVersion)
349372
.AppendLine("\")]")
350373
.Append("public static bool operator !=(")
351-
.Append(entityClass.EntityName)
374+
.Append(entityClass.FullyQualified)
352375
.AppendIf("?", !entityClass.IsValueType)
353376
.Append(" left, ")
354-
.Append(entityClass.EntityName)
377+
.Append(entityClass.FullyQualified)
355378
.AppendIf("?", !entityClass.IsValueType)
356379
.AppendLine(" right)")
357380
.AppendLine("{")
@@ -460,7 +483,7 @@ private static void GenerateHashCode(IndentedStringBuilder codeBuilder, Equatabl
460483

461484
private static void GenerateHashCodeFunctions(IndentedStringBuilder codeBuilder, EquatableClass entityClass)
462485
{
463-
if (entityClass == null)
486+
if (entityClass is null)
464487
return;
465488

466489
if (entityClass.Properties.Any(p => p.ComparerType == ComparerTypes.Dictionary))
@@ -469,7 +492,7 @@ private static void GenerateHashCodeFunctions(IndentedStringBuilder codeBuilder,
469492
.AppendLine("static int DictionaryHashCode<TKey, TValue>(global::System.Collections.Generic.IDictionary<TKey, TValue>? items)")
470493
.AppendLine("{")
471494
.IncrementIndent()
472-
.AppendLine("if (items == null)")
495+
.AppendLine("if (items is null)")
473496
.AppendLine(" return 0;")
474497
.AppendLine();
475498

@@ -501,7 +524,7 @@ private static void GenerateHashCodeFunctions(IndentedStringBuilder codeBuilder,
501524
.AppendLine("static int HashSetHashCode<T>(global::System.Collections.Generic.IEnumerable<T>? items)")
502525
.AppendLine("{")
503526
.IncrementIndent()
504-
.AppendLine("if (items == null)")
527+
.AppendLine("if (items is null)")
505528
.AppendLine(" return 0;")
506529
.AppendLine()
507530
.Append("int hashCode = ")
@@ -524,7 +547,7 @@ private static void GenerateHashCodeFunctions(IndentedStringBuilder codeBuilder,
524547
.AppendLine("static int SequenceHashCode<T>(global::System.Collections.Generic.IEnumerable<T>? items)")
525548
.AppendLine("{")
526549
.IncrementIndent()
527-
.AppendLine("if (items == null)")
550+
.AppendLine("if (items is null)")
528551
.AppendLine(" return 0;")
529552
.AppendLine()
530553
.Append("int hashCode = ")
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
namespace Equatable.SourceGenerator.Models;
2+
3+
public record ContainingClass(
4+
string EntityName,
5+
bool IsRecord,
6+
bool IsValueType
7+
);
Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
namespace Equatable.SourceGenerator.Models;
22

33
public record EquatableClass(
4-
string EntityNamespace,
5-
string EntityName,
6-
EquatableArray<EquatableProperty> Properties,
7-
bool IsRecord,
8-
bool IsValueType,
9-
bool IsSealed,
10-
bool IncludeBaseEqualsMethod,
11-
bool IncludeBaseHashMethod,
12-
int SeedHash
4+
string FullyQualified,
5+
string EntityNamespace,
6+
string EntityName,
7+
EquatableArray<ContainingClass> ContainingTypes,
8+
EquatableArray<EquatableProperty> Properties,
9+
bool IsRecord,
10+
bool IsValueType,
11+
bool IsSealed,
12+
bool IncludeBaseEqualsMethod,
13+
bool IncludeBaseHashMethod,
14+
int SeedHash
1315
);

test/Equatable.Entities/Nested.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ namespace Equatable.Entities;
44

55
public partial class Nested
66
{
7-
//[Equatable]
7+
[Equatable]
88
public partial class Animal
99
{
1010
public int Id { get; set; }

test/Equatable.Generator.Tests/EquatableGeneratorTest.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,35 @@ public partial class Audit
357357
.ScrubLinesContaining("GeneratedCodeAttribute");
358358
}
359359

360+
[Fact]
361+
public Task GenerateNestedComparer()
362+
{
363+
var source = @"
364+
using Equatable.Attributes;
365+
366+
namespace Equatable.Entities;
367+
368+
public partial class Nested
369+
{
370+
[Equatable]
371+
public partial class Animal
372+
{
373+
public int Id { get; set; }
374+
public string? Name { get; set; }
375+
public string? Type { get; set; }
376+
}
377+
}
378+
";
379+
380+
var (diagnostics, output) = GetGeneratedOutput<EquatableGenerator>(source);
381+
382+
diagnostics.Should().BeEmpty();
383+
384+
return Verifier
385+
.Verify(output)
386+
.UseDirectory("Snapshots")
387+
.ScrubLinesContaining("GeneratedCodeAttribute");
388+
}
360389

361390
private static (ImmutableArray<Diagnostic> Diagnostics, string Output) GetGeneratedOutput<T>(string source)
362391
where T : IIncrementalGenerator, new()

test/Equatable.Generator.Tests/EquatableWriterTest.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ public class EquatableWriterTest
99
public async Task GenerateBasicUser()
1010
{
1111
var entityClass = new EquatableClass(
12+
FullyQualified: "global::Equatable.Entities.User",
1213
EntityNamespace: "Equatable.Entities",
1314
EntityName: "User",
15+
ContainingTypes: Array.Empty<ContainingClass>(),
1416
Properties: new EquatableArray<EquatableProperty>([
1517
new EquatableProperty("Id", "int"),
1618
new EquatableProperty("FirstName", "string?"),
@@ -38,8 +40,10 @@ public async Task GenerateBasicUser()
3840
public async Task GenerateUserStringSequence()
3941
{
4042
var entityClass = new EquatableClass(
43+
FullyQualified: "global::Equatable.Entities.User",
4144
EntityNamespace: "Equatable.Entities",
4245
EntityName: "User",
46+
ContainingTypes: Array.Empty<ContainingClass>(),
4347
Properties: new EquatableArray<EquatableProperty>([
4448
new EquatableProperty("Id", "int"),
4549
new EquatableProperty("FirstName", "string?"),
@@ -68,8 +72,10 @@ public async Task GenerateUserStringSequence()
6872
public async Task GenerateUserImportHashSetDictionary()
6973
{
7074
var entityClass = new EquatableClass(
75+
FullyQualified: "global::Equatable.Entities.UserImport",
7176
EntityNamespace: "Equatable.Entities",
7277
EntityName: "UserImport",
78+
ContainingTypes: Array.Empty<ContainingClass>(),
7379
Properties: new EquatableArray<EquatableProperty>([
7480
new EquatableProperty("EmailAddress", "string", ComparerTypes.String, "OrdinalIgnoreCase"),
7581
new EquatableProperty("DisplayName", "string?"),

0 commit comments

Comments
 (0)