Skip to content

Commit cd820d3

Browse files
committed
Fix member forwarding in UnionStructSourceGenerator + add more tests
1 parent 1079eb4 commit cd820d3

File tree

4 files changed

+167
-46
lines changed

4 files changed

+167
-46
lines changed
-1 KB
Binary file not shown.

Medicine.SourceGenerator~/Generators/UnionStructSourceGenerator.cs

Lines changed: 38 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using Microsoft.CodeAnalysis.CSharp.Syntax;
55
using static System.StringComparison;
66
using static Constants;
7+
using static Microsoft.CodeAnalysis.Accessibility;
78

89
[Generator]
910
public sealed class UnionStructSourceGenerator : IIncrementalGenerator
@@ -30,7 +31,6 @@ record struct DerivedInput
3031
public string Name { get; init; }
3132
public string FQN { get; init; }
3233
public EquatableArray<string> Declaration { get; init; }
33-
public string HeaderFieldName { get; init; }
3434
public byte AssignedId { get; init; }
3535
public EquatableArray<string> PubliclyImplementedMembers { get; init; }
3636
public EquatableArray<string> MemberNames { get; init; }
@@ -53,14 +53,13 @@ record struct Derived : IGeneratorTransformOutput
5353
public string DerivedFQN { get; init; }
5454
public string DerivedName { get; init; }
5555
public EquatableArray<string> Declaration { get; init; }
56-
public string HeaderFieldName { get; init; }
57-
public EquatableArray<string> HeaderChainFQNs { get; init; }
5856
public byte? ForcedId { get; init; }
5957

58+
public EquatableIgnore<Func<string, bool>?> HasHeaderInChainFunc { get; init; }
6059
public EquatableIgnore<Func<string, bool>?> ImplementsUnionInterfaceFunc { get; init; }
6160
public EquatableIgnore<Func<DerivedDeferredInput>?> DeferredInputBuilderFunc { get; init; }
6261

63-
// ReSharper disable once NotAccessedField.Local
62+
// ReSharper disable once UnusedAutoPropertyAccessor.Local
6463
public EquatableArray<byte> DerivedTextCheckSumForCache { get; init; }
6564
}
6665

@@ -74,11 +73,9 @@ record struct GeneratorInput : IGeneratorTransformOutput
7473
public EquatableArray<string> BaseDeclaration { get; init; }
7574
public string BaseTypeName { get; init; }
7675
public string BaseTypeFQN { get; init; }
77-
public string InterfaceName { get; init; }
7876
public string InterfaceFQN { get; init; }
7977
public string TypeIDEnumFQN { get; init; }
8078
public string TypeIDFieldName { get; init; }
81-
public string RootTypeName { get; init; }
8279
public string RootTypeFQN { get; init; }
8380
public string RootInterfaceFQN { get; init; }
8481
public string RootTypeIDEnumFQN { get; init; }
@@ -92,7 +89,7 @@ record struct GeneratorInput : IGeneratorTransformOutput
9289
public EquatableIgnore<Func<HeaderFieldInput[]>?> HeaderFieldsBuilderFunc { get; init; }
9390
public EquatableIgnore<Func<DerivedInput[]>?> DerivedStructsBuilderFunc { get; init; }
9491

95-
// ReSharper disable once NotAccessedField.Local
92+
// ReSharper disable once UnusedAutoPropertyAccessor.Local
9693
public EquatableArray<byte> BaseTextCheckSumForCache { get; init; }
9794
}
9895

@@ -132,18 +129,22 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
132129
static GeneratorInput TransformBase(GeneratorAttributeSyntaxContext context, CancellationToken ct)
133130
{
134131
if (context is not { TargetSymbol: ITypeSymbol symbol, TargetNode: StructDeclarationSyntax structDecl })
132+
{
135133
return new()
136134
{
137135
SourceGeneratorError = "Unexpected target shape for [UnionHeader].",
138136
SourceGeneratorErrorLocation = new LocationInfo(context.TargetNode.GetLocation()),
139137
};
138+
}
140139

141140
if (symbol is not INamedTypeSymbol baseSymbol)
141+
{
142142
return new()
143143
{
144144
SourceGeneratorError = "Unexpected target symbol for [UnionHeader].",
145145
SourceGeneratorErrorLocation = new LocationInfo(context.TargetNode.GetLocation()),
146146
};
147+
}
147148

148149
var symbolMembers = symbol.GetMembers().AsArray();
149150
var symbolTypeMembers = symbol.GetTypeMembers().AsArray();
@@ -152,10 +153,12 @@ static GeneratorInput TransformBase(GeneratorAttributeSyntaxContext context, Can
152153
?? symbolTypeMembers.FirstOrDefault(x => x.TypeKind is TypeKind.Interface);
153154

154155
if (interfaceSymbol is null)
156+
{
155157
return new()
156158
{
157159
SourceGeneratorOutputFilename = Utility.GetOutputFilename(structDecl.SyntaxTree.FilePath, symbol.Name, "Union"),
158160
};
161+
}
159162

160163
var typeIDEnumSymbol = symbolTypeMembers.FirstOrDefault(x => x.Name is "TypeIDs");
161164
var typeIDField = symbolMembers.FirstOrDefault(x => x is IFieldSymbol { Name: "TypeID", Type.Name: "TypeIDs" });
@@ -188,13 +191,11 @@ parentInterface is not null &&
188191
BaseDeclaration = Utility.DeconstructTypeDeclaration(structDecl, context.SemanticModel, ct),
189192
BaseTypeName = symbol.Name,
190193
BaseTypeFQN = symbol.FQN,
191-
InterfaceName = interfaceSymbol.Name,
192194
InterfaceFQN = interfaceSymbol.FQN,
193195
TypeIDEnumFQN = isRootTypeIdOwner
194196
? typeIDEnumSymbol?.FQN ?? $"{symbol.FQN}.TypeIDs"
195197
: rootTypeIDEnumSymbol?.FQN ?? $"{rootHeader.FQN}.TypeIDs",
196198
TypeIDFieldName = typeIDField?.Name ?? "TypeID",
197-
RootTypeName = rootHeader.Name,
198199
RootTypeFQN = rootHeader.FQN,
199200
RootInterfaceFQN = rootInterface.FQN,
200201
RootTypeIDEnumFQN = rootTypeIDEnumSymbol?.FQN ?? $"{rootHeader.FQN}.TypeIDs",
@@ -257,24 +258,18 @@ static InterfaceMemberInput[] BuildInterfaceMembers(INamedTypeSymbol interfaceSy
257258

258259
static HeaderFieldInput[] BuildHeaderFields(INamedTypeSymbol headerSymbol)
259260
{
260-
static bool IsAccessible(Accessibility accessibility)
261-
=> accessibility is
262-
Accessibility.Public
263-
or Accessibility.Internal
264-
or Accessibility.ProtectedOrInternal;
265-
266261
var fieldsAndProperties = headerSymbol.GetMembers()
267262
.Where(x => x is IFieldSymbol or IPropertySymbol)
268263
.OrderBy(x => x.Locations.FirstOrDefault()?.SourceSpan.Start ?? int.MaxValue);
269264

270-
var result = new List<HeaderFieldInput>();
265+
var result = new List<HeaderFieldInput>(capacity: 8);
271266
foreach (var member in fieldsAndProperties)
272267
{
273268
switch (member)
274269
{
275270
case IFieldSymbol { IsStatic: false, IsImplicitlyDeclared: false } field:
276271
{
277-
if (!IsAccessible(field.DeclaredAccessibility))
272+
if (!field.IsAccessible)
278273
break;
279274

280275
result.Add(
@@ -290,8 +285,9 @@ or Accessibility.Internal
290285
}
291286
case IPropertySymbol { IsStatic: false, IsIndexer: false } property:
292287
{
293-
bool canGet = property.GetMethod is { } getter && IsAccessible(getter.DeclaredAccessibility);
294-
bool canSet = property.SetMethod is { } setter && IsAccessible(setter.DeclaredAccessibility);
288+
bool canGet = property.GetMethod is { IsAccessible: true };
289+
bool canSet = property.SetMethod is { IsAccessible: true };
290+
295291
if (!canGet && !canSet)
296292
break;
297293

@@ -317,7 +313,7 @@ static Derived TransformDerivedCandidate(GeneratorAttributeSyntaxContext context
317313
if (context.TargetNode is not StructDeclarationSyntax structDecl)
318314
return default;
319315

320-
if (context.SemanticModel.GetDeclaredSymbol(structDecl, ct) is not INamedTypeSymbol symbol)
316+
if (context.SemanticModel.GetDeclaredSymbol(structDecl, ct) is not { } symbol)
321317
return default;
322318

323319
byte? forcedId = context.Attributes
@@ -333,9 +329,8 @@ static Derived TransformDerivedCandidate(GeneratorAttributeSyntaxContext context
333329
DerivedFQN = symbol.FQN,
334330
DerivedName = symbol.Name,
335331
Declaration = Utility.DeconstructTypeDeclaration(structDecl, context.SemanticModel, ct),
336-
HeaderFieldName = GetFirstHeaderField(symbol)?.Name ?? "Header",
337-
HeaderChainFQNs = BuildHeaderChainFQNs(symbol),
338332
ForcedId = forcedId,
333+
HasHeaderInChainFunc = new(headerFQN => HasHeaderInChain(symbol, headerFQN)),
339334
ImplementsUnionInterfaceFunc = new(interfaceFQN => symbol.AllInterfaces.Any(x => x.FQN == interfaceFQN)),
340335
DeferredInputBuilderFunc = new(() => BuildDerivedDeferredInput(symbol)),
341336
DerivedTextCheckSumForCache = structDecl.GetText().GetChecksum().AsArray(),
@@ -345,7 +340,7 @@ static Derived TransformDerivedCandidate(GeneratorAttributeSyntaxContext context
345340
static DerivedDeferredInput BuildDerivedDeferredInput(INamedTypeSymbol symbol)
346341
{
347342
var publicMembers = symbol.GetMembers()
348-
.Where(x => x is { DeclaredAccessibility: Accessibility.Public } and not IMethodSymbol { MethodKind: not MethodKind.Ordinary })
343+
.Where(x => x is { DeclaredAccessibility: Public } and not IMethodSymbol { MethodKind: not MethodKind.Ordinary })
349344
.Select(x => x.Name)
350345
.Distinct()
351346
.ToArray();
@@ -404,7 +399,7 @@ static DerivedInput[] BuildDerivedStructs(Derived[] candidates, string interface
404399
if (candidate.ImplementsUnionInterfaceFunc.Value?.Invoke(rootInterfaceFQN) is not true)
405400
continue;
406401

407-
if (!HasHeaderInChain(candidate, rootHeaderFQN))
402+
if (candidate.HasHeaderInChainFunc.Value?.Invoke(rootHeaderFQN) is not true)
408403
continue;
409404

410405
if (candidate.ForcedId is { } forcedId)
@@ -453,7 +448,7 @@ static DerivedInput[] BuildDerivedStructs(Derived[] candidates, string interface
453448
if (candidate.ImplementsUnionInterfaceFunc.Value?.Invoke(interfaceFQN) is not true)
454449
continue;
455450

456-
if (!HasHeaderInChain(candidate, headerFQN))
451+
if (candidate.HasHeaderInChainFunc.Value?.Invoke(headerFQN) is not true)
457452
continue;
458453

459454
var deferredInput = candidate.DeferredInputBuilderFunc.Value?.Invoke() ?? default;
@@ -464,7 +459,6 @@ static DerivedInput[] BuildDerivedStructs(Derived[] candidates, string interface
464459
Name = candidate.DerivedName,
465460
FQN = candidate.DerivedFQN,
466461
Declaration = candidate.Declaration,
467-
HeaderFieldName = candidate.HeaderFieldName,
468462
AssignedId = assigned.AssignedId,
469463
PubliclyImplementedMembers = deferredInput.PublicMembers,
470464
MemberNames = deferredInput.MemberNames,
@@ -489,36 +483,34 @@ static DerivedInput[] BuildDerivedStructs(Derived[] candidates, string interface
489483
return result;
490484
}
491485

492-
static bool HasHeaderInChain(Derived candidate, string headerFQN)
493-
=> candidate.HeaderChainFQNs.AsArray().Any(x => x.Equals(headerFQN, Ordinal));
494-
495-
static byte GetNextAvailableId(HashSet<byte> usedIds, ref byte nextId)
496-
{
497-
while (usedIds.Contains(nextId))
498-
nextId++;
499-
500-
return nextId++;
501-
}
502-
503-
static EquatableArray<string> BuildHeaderChainFQNs(INamedTypeSymbol symbol)
486+
static bool HasHeaderInChain(INamedTypeSymbol symbol, string headerFQN)
504487
{
505488
var firstHeaderFieldType = GetFirstHeaderFieldType(symbol);
506489
if (firstHeaderFieldType is null)
507-
return [];
490+
return false;
508491

509-
var chain = new List<string>();
510492
var visited = new HashSet<string>(StringComparer.Ordinal);
511493
var current = firstHeaderFieldType;
512494
while (current is not null && current.HasAttribute(UnionHeaderStructAttributeFQN))
513495
{
514496
if (!visited.Add(current.FQN))
515497
break;
516498

517-
chain.Add(current.FQN);
499+
if (current.FQN.Equals(headerFQN, Ordinal))
500+
return true;
501+
518502
current = GetFirstHeaderFieldType(current);
519503
}
520504

521-
return chain.ToArray();
505+
return false;
506+
}
507+
508+
static byte GetNextAvailableId(HashSet<byte> usedIds, ref byte nextId)
509+
{
510+
while (usedIds.Contains(nextId))
511+
nextId++;
512+
513+
return nextId++;
522514
}
523515

524516
static IFieldSymbol? GetFirstHeaderField(INamedTypeSymbol symbol)
@@ -545,10 +537,10 @@ x is
545537
{
546538
Name: "Interface",
547539
TypeKind: TypeKind.Interface,
548-
DeclaredAccessibility: Accessibility.Public,
540+
DeclaredAccessibility: Public,
549541
}
550542
) ??
551-
typeMembers.FirstOrDefault(x => x is { TypeKind: TypeKind.Interface, DeclaredAccessibility: Accessibility.Public });
543+
typeMembers.FirstOrDefault(x => x is { TypeKind: TypeKind.Interface, DeclaredAccessibility: Public });
552544
}
553545

554546
static INamedTypeSymbol GetRootHeader(INamedTypeSymbol headerSymbol, INamedTypeSymbol headerInterface)
@@ -611,10 +603,10 @@ static void GenerateSource(SourceProductionContext context, SourceWriter src, Ge
611603
using (src.Braces)
612604
{
613605
if (headerField.CanGet)
614-
src.Line.Write($"get => {derived.HeaderFieldName}.{headerField.Name};");
606+
src.Line.Write($"get => {m}UnsafeUtility.As<{derived.FQN}, {input.BaseTypeFQN}>(ref this).{headerField.Name};");
615607

616608
if (headerField.CanSet)
617-
src.Line.Write($"set => {derived.HeaderFieldName}.{headerField.Name} = value;");
609+
src.Line.Write($"set => {m}UnsafeUtility.As<{derived.FQN}, {input.BaseTypeFQN}>(ref this).{headerField.Name} = value;");
618610
}
619611

620612
src.Linebreak();

Medicine.SourceGenerator~/Utility/ExtensionMethods.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ public int Hash
8282

8383
public bool IsInMedicineNamespace
8484
=> self.ContainingNamespace is { Name: Constants.Namespace, ContainingNamespace.IsGlobalNamespace: true };
85+
86+
public bool IsAccessible
87+
=> self.DeclaredAccessibility is Accessibility.Public or Accessibility.Internal or Accessibility.ProtectedOrInternal;
8588
}
8689

8790
extension(ISymbol? self)

0 commit comments

Comments
 (0)