Skip to content

Commit 7387438

Browse files
committed
Update EnumExtensions generator
1 parent b492595 commit 7387438

File tree

6 files changed

+559
-75
lines changed

6 files changed

+559
-75
lines changed

NetEscapades.EnumGenerators/src/NetEscapades.EnumGenerators/EnumGenerator.cs

Lines changed: 32 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,31 @@ public class EnumGenerator : IIncrementalGenerator
1313

1414
public void Initialize(IncrementalGeneratorInitializationContext context)
1515
{
16-
context.RegisterPostInitializationOutput(ctx => ctx.AddSource(
16+
context.RegisterPostInitializationOutput(static ctx => ctx.AddSource(
1717
"EnumExtensionsAttribute.g.cs", SourceText.From(SourceGenerationHelper.Attribute, Encoding.UTF8)));
1818

19-
IncrementalValuesProvider<EnumDeclarationSyntax> enumDeclarations = context.SyntaxProvider
19+
IncrementalValuesProvider<EnumToGenerate?> enumsToGenerate = context.SyntaxProvider
2020
.CreateSyntaxProvider(
2121
predicate: static (s, _) => IsSyntaxTargetForGeneration(s),
2222
transform: static (ctx, _) => GetSemanticTargetForGeneration(ctx))
23-
.Where(static m => m is not null)!;
24-
25-
IncrementalValueProvider<(Compilation, ImmutableArray<EnumDeclarationSyntax>)> compilationAndEnums
26-
= context.CompilationProvider.Combine(enumDeclarations.Collect());
27-
28-
context.RegisterSourceOutput(compilationAndEnums,
29-
static (spc, source) => Execute(source.Item1, source.Item2, spc));
23+
.Where(static m => m is not null);
24+
25+
// If you're targeting the .NET 7 SDK, use this version instead:
26+
// IncrementalValuesProvider<EnumToGenerate?> enumsToGenerate = context.SyntaxProvider
27+
// .ForAttributeWithMetadataName(
28+
// "NetEscapades.EnumGenerators.EnumExtensionsAttribute",
29+
// predicate: static (s, _) => true,
30+
// transform: static (ctx, _) => GetEnumToGenerate(ctx.SemanticModel, ctx.TargetNode))
31+
// .Where(static m => m is not null);
32+
33+
context.RegisterSourceOutput(enumsToGenerate,
34+
static (spc, source) => Execute(source, spc));
3035
}
3136

3237
static bool IsSyntaxTargetForGeneration(SyntaxNode node)
3338
=> node is EnumDeclarationSyntax m && m.AttributeLists.Count > 0;
3439

35-
static EnumDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
40+
static EnumToGenerate? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
3641
{
3742
// we know the node is a EnumDeclarationSyntax thanks to IsSyntaxTargetForGeneration
3843
var enumDeclarationSyntax = (EnumDeclarationSyntax)context.Node;
@@ -55,7 +60,7 @@ static bool IsSyntaxTargetForGeneration(SyntaxNode node)
5560
if (fullName == EnumExtensionsAttribute)
5661
{
5762
// return the enum
58-
return enumDeclarationSyntax;
63+
return GetEnumToGenerate(context.SemanticModel, enumDeclarationSyntax);
5964
}
6065
}
6166
}
@@ -64,67 +69,36 @@ static bool IsSyntaxTargetForGeneration(SyntaxNode node)
6469
return null;
6570
}
6671

67-
static void Execute(Compilation compilation, ImmutableArray<EnumDeclarationSyntax> enums, SourceProductionContext context)
72+
static void Execute(EnumToGenerate? enumToGenerate, SourceProductionContext context)
6873
{
69-
if (enums.IsDefaultOrEmpty)
70-
{
71-
// nothing to do yet
72-
return;
73-
}
74-
75-
// I'm not sure if this is actually necessary, but `[LoggerMessage]` does it, so seems like a good idea!
76-
IEnumerable<EnumDeclarationSyntax> distinctEnums = enums.Distinct();
77-
78-
// Convert each EnumDeclarationSyntax to an EnumToGenerate
79-
List<EnumToGenerate> enumsToGenerate = GetTypesToGenerate(compilation, distinctEnums, context.CancellationToken);
80-
81-
// If there were errors in the EnumDeclarationSyntax, we won't create an
82-
// EnumToGenerate for it, so make sure we have something to generate
83-
if (enumsToGenerate.Count > 0)
74+
if (enumToGenerate is { } value)
8475
{
8576
// generate the source code and add it to the output
86-
string result = SourceGenerationHelper.GenerateExtensionClass(enumsToGenerate);
87-
context.AddSource("EnumExtensions.g.cs", SourceText.From(result, Encoding.UTF8));
77+
string result = SourceGenerationHelper.GenerateExtensionClass(value);
78+
context.AddSource($"EnumExtensions.{value.Name}.g.cs", SourceText.From(result, Encoding.UTF8));
8879
}
8980
}
9081

91-
static List<EnumToGenerate> GetTypesToGenerate(Compilation compilation, IEnumerable<EnumDeclarationSyntax> enums, CancellationToken ct)
82+
static EnumToGenerate? GetEnumToGenerate(SemanticModel semanticModel, SyntaxNode enumDeclarationSyntax)
9283
{
93-
var enumsToGenerate = new List<EnumToGenerate>();
94-
INamedTypeSymbol? enumAttribute = compilation.GetTypeByMetadataName(EnumExtensionsAttribute);
95-
if (enumAttribute == null)
84+
if (semanticModel.GetDeclaredSymbol(enumDeclarationSyntax) is not INamedTypeSymbol enumSymbol)
9685
{
97-
// nothing to do if this type isn't available
98-
return enumsToGenerate;
86+
// something went wrong
87+
return null;
9988
}
10089

101-
foreach (var enumDeclarationSyntax in enums)
102-
{
103-
// stop if we're asked to
104-
ct.ThrowIfCancellationRequested();
105-
106-
SemanticModel semanticModel = compilation.GetSemanticModel(enumDeclarationSyntax.SyntaxTree);
107-
if (semanticModel.GetDeclaredSymbol(enumDeclarationSyntax) is not INamedTypeSymbol enumSymbol)
108-
{
109-
// something went wrong
110-
continue;
111-
}
112-
113-
string enumName = enumSymbol.ToString();
114-
ImmutableArray<ISymbol> enumMembers = enumSymbol.GetMembers();
115-
var members = new List<string>(enumMembers.Length);
90+
string enumName = enumSymbol.ToString();
91+
ImmutableArray<ISymbol> enumMembers = enumSymbol.GetMembers();
92+
var members = new List<string>(enumMembers.Length);
11693

117-
foreach (ISymbol member in enumMembers)
94+
foreach (ISymbol member in enumMembers)
95+
{
96+
if (member is IFieldSymbol field && field.ConstantValue is not null)
11897
{
119-
if (member is IFieldSymbol field && field.ConstantValue is not null)
120-
{
121-
members.Add(member.Name);
122-
}
98+
members.Add(member.Name);
12399
}
124-
125-
enumsToGenerate.Add(new EnumToGenerate(enumName, members));
126100
}
127101

128-
return enumsToGenerate;
102+
return new EnumToGenerate(enumName, members);
129103
}
130104
}
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
namespace NetEscapades.EnumGenerators;
22

3-
public readonly struct EnumToGenerate
3+
public readonly record struct EnumToGenerate
44
{
55
public readonly string Name;
6-
public readonly List<string> Values;
6+
public readonly EquatableArray<string> Values;
77

88
public EnumToGenerate(string name, List<string> values)
99
{
1010
Name = name;
11-
Values = values;
11+
Values = new(values.ToArray());
1212
}
1313
}
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// <copyright file="EquatableArray.cs" company="Datadog">
2+
// Unless explicitly stated otherwise all files in this repository are licensed under the Apache 2 License.
3+
// This product includes software developed at Datadog (https://www.datadoghq.com/). Copyright 2017 Datadog, Inc.
4+
// </copyright>
5+
6+
using System.Collections;
7+
8+
namespace NetEscapades.EnumGenerators;
9+
10+
/// <summary>
11+
/// An immutable, equatable array. This is equivalent to <see cref="Array"/> but with value equality support.
12+
/// </summary>
13+
/// <typeparam name="T">The type of values in the array.</typeparam>
14+
public readonly struct EquatableArray<T> : IEquatable<EquatableArray<T>>, IEnumerable<T>
15+
where T : IEquatable<T>
16+
{
17+
/// <summary>
18+
/// The underlying <typeparamref name="T"/> array.
19+
/// </summary>
20+
private readonly T[]? _array;
21+
22+
/// <summary>
23+
/// Initializes a new instance of the <see cref="EquatableArray{T}"/> struct.
24+
/// </summary>
25+
/// <param name="array">The input array to wrap.</param>
26+
public EquatableArray(T[] array)
27+
{
28+
_array = array;
29+
}
30+
31+
/// <summary>
32+
/// Gets the length of the array, or 0 if the array is null
33+
/// </summary>
34+
public int Count => _array?.Length ?? 0;
35+
36+
/// <summary>
37+
/// Checks whether two <see cref="EquatableArray{T}"/> values are the same.
38+
/// </summary>
39+
/// <param name="left">The first <see cref="EquatableArray{T}"/> value.</param>
40+
/// <param name="right">The second <see cref="EquatableArray{T}"/> value.</param>
41+
/// <returns>Whether <paramref name="left"/> and <paramref name="right"/> are equal.</returns>
42+
public static bool operator ==(EquatableArray<T> left, EquatableArray<T> right)
43+
{
44+
return left.Equals(right);
45+
}
46+
47+
/// <summary>
48+
/// Checks whether two <see cref="EquatableArray{T}"/> values are not the same.
49+
/// </summary>
50+
/// <param name="left">The first <see cref="EquatableArray{T}"/> value.</param>
51+
/// <param name="right">The second <see cref="EquatableArray{T}"/> value.</param>
52+
/// <returns>Whether <paramref name="left"/> and <paramref name="right"/> are not equal.</returns>
53+
public static bool operator !=(EquatableArray<T> left, EquatableArray<T> right)
54+
{
55+
return !left.Equals(right);
56+
}
57+
58+
/// <inheritdoc/>
59+
public bool Equals(EquatableArray<T> array)
60+
{
61+
return AsSpan().SequenceEqual(array.AsSpan());
62+
}
63+
64+
/// <inheritdoc/>
65+
public override bool Equals(object? obj)
66+
{
67+
return obj is EquatableArray<T> array && Equals(this, array);
68+
}
69+
70+
/// <inheritdoc/>
71+
public override int GetHashCode()
72+
{
73+
if (_array is not T[] array)
74+
{
75+
return 0;
76+
}
77+
78+
HashCode hashCode = default;
79+
80+
foreach (T item in array)
81+
{
82+
hashCode.Add(item);
83+
}
84+
85+
return hashCode.ToHashCode();
86+
}
87+
88+
/// <summary>
89+
/// Returns a <see cref="ReadOnlySpan{T}"/> wrapping the current items.
90+
/// </summary>
91+
/// <returns>A <see cref="ReadOnlySpan{T}"/> wrapping the current items.</returns>
92+
public ReadOnlySpan<T> AsSpan()
93+
{
94+
return _array.AsSpan();
95+
}
96+
97+
/// <summary>
98+
/// Returns the underlying wrapped array.
99+
/// </summary>
100+
/// <returns>Returns the underlying array.</returns>
101+
public T[]? AsArray()
102+
{
103+
return _array;
104+
}
105+
106+
/// <inheritdoc/>
107+
IEnumerator<T> IEnumerable<T>.GetEnumerator()
108+
{
109+
return ((IEnumerable<T>)(_array ?? Array.Empty<T>())).GetEnumerator();
110+
}
111+
112+
/// <inheritdoc/>
113+
IEnumerator IEnumerable.GetEnumerator()
114+
{
115+
return ((IEnumerable<T>)(_array ?? Array.Empty<T>())).GetEnumerator();
116+
}
117+
}

0 commit comments

Comments
 (0)