Skip to content

Commit 3caa8c5

Browse files
committed
Restore removed specializations
All specializations which only use pointers as their type arguments need at most one internal representation since pointers are mapped to IntPtr. This was achieved by removing the unneeded specializations from their containing list. This was, however, a bug because specializations were thus removed not only as internal structures but in their entirety. Signed-off-by: Dimitar Dobrev <[email protected]>
1 parent 304d673 commit 3caa8c5

File tree

8 files changed

+124
-77
lines changed

8 files changed

+124
-77
lines changed

src/AST/Type.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ public override int GetHashCode() =>
551551
/// <summary>
552552
/// Represents a template argument within a class template specialization.
553553
/// </summary>
554-
public struct TemplateArgument
554+
public class TemplateArgument
555555
{
556556
/// The kind of template argument we're storing.
557557
public enum ArgumentKind

src/Generator/Generators/CSharp/CSharpSources.cs

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ classTemplate.OriginalNamespace is Class &&
312312
GenerateClassTemplateSpecializationsInternals(
313313
nestedTemplate, nestedTemplate.Specializations);
314314

315-
foreach (var specialization in generated)
315+
foreach (var specialization in generated.KeepSingleAllPointersSpecialization())
316316
GenerateClassInternals(specialization);
317317

318318
foreach (var group in generated.SelectMany(s => s.Classes).Where(
@@ -558,14 +558,32 @@ private IEnumerable<Function> GatherClassInternalFunctions(Class @class,
558558
functions.AddRange(GatherClassInternalFunctions(@base.Class, false));
559559

560560
var currentSpecialization = @class as ClassTemplateSpecialization;
561-
Class template;
562-
if (currentSpecialization != null &&
563-
(template = currentSpecialization.TemplatedDecl.TemplatedClass)
564-
.GetSpecializedClassesToGenerate().Count() == 1)
565-
foreach (var specialization in template.Specializations.Where(s => s.IsGenerated))
566-
GatherClassInternalFunctions(specialization, includeCtors, functions);
567-
else
568-
GatherClassInternalFunctions(@class, includeCtors, functions);
561+
if (currentSpecialization != null)
562+
{
563+
Class template = currentSpecialization.TemplatedDecl.TemplatedClass;
564+
IEnumerable<ClassTemplateSpecialization> specializations = null;
565+
if (template.GetSpecializedClassesToGenerate().Count() == 1)
566+
specializations = template.Specializations.Where(s => s.IsGenerated);
567+
else
568+
{
569+
Func<TemplateArgument, bool> allPointers = (TemplateArgument a) =>
570+
a.Type.Type?.Desugar().IsAddress() == true;
571+
if (currentSpecialization.Arguments.All(allPointers))
572+
{
573+
specializations = template.Specializations.Where(
574+
s => s.IsGenerated && s.Arguments.All(allPointers));
575+
}
576+
}
577+
578+
if (specializations != null)
579+
{
580+
foreach (var specialization in specializations)
581+
GatherClassInternalFunctions(specialization, includeCtors, functions);
582+
return functions;
583+
}
584+
}
585+
586+
GatherClassInternalFunctions(@class, includeCtors, functions);
569587

570588
return functions;
571589
}

src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,37 @@ public static void GenerateNativeConstructorsByValue(
2323
{
2424
var printedClass = @class.Visit(gen.TypePrinter);
2525
if (@class.IsDependent)
26-
foreach (var specialization in @class.GetSpecializedClassesToGenerate(
27-
).Where(s => s.IsGenerated))
26+
{
27+
IEnumerable<Class> specializations =
28+
@class.GetSpecializedClassesToGenerate().Where(s => s.IsGenerated);
29+
if (@class.IsTemplate)
30+
specializations = specializations.KeepSingleAllPointersSpecialization();
31+
foreach (var specialization in specializations)
2832
gen.GenerateNativeConstructorByValue(specialization, printedClass.Type);
33+
}
2934
else
35+
{
3036
gen.GenerateNativeConstructorByValue(@class, printedClass.Type);
37+
}
38+
}
39+
40+
public static IEnumerable<Class> KeepSingleAllPointersSpecialization(
41+
this IEnumerable<Class> specializations)
42+
{
43+
Func<TemplateArgument, bool> allPointers = (TemplateArgument a) =>
44+
a.Type.Type?.Desugar().IsAddress() == true;
45+
var groups = (from ClassTemplateSpecialization spec in specializations
46+
group spec by spec.Arguments.All(allPointers)
47+
into @group
48+
select @group).ToList();
49+
foreach (var group in groups)
50+
{
51+
if (group.Key)
52+
yield return group.First();
53+
else
54+
foreach (var specialization in group)
55+
yield return specialization;
56+
}
3157
}
3258

3359
public static void GenerateField(this CSharpSources gen, Class @class,
@@ -112,7 +138,7 @@ private static void WriteTemplateSpecializationCheck(CSharpSources gen,
112138
Enumerable.Range(0, @class.TemplateParameters.Count).Select(
113139
i =>
114140
{
115-
CppSharp.AST.Type type = specialization.Arguments[i].Type.Type.Desugar();
141+
CppSharp.AST.Type type = specialization.Arguments[i].Type.Type;
116142
return type.IsPointerToPrimitiveType() ?
117143
$"__{@class.TemplateParameters[i].Name}.FullName == \"System.IntPtr\"" :
118144
$"__{@class.TemplateParameters[i].Name}.IsAssignableFrom(typeof({type}))";

src/Generator/Generators/CSharp/CSharpTypePrinter.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,14 +560,14 @@ public TypePrinterResult VisitTemplateArgument(TemplateArgument a)
560560
{
561561
if (a.Type.Type == null)
562562
return a.Integral.ToString(CultureInfo.InvariantCulture);
563-
var type = a.Type.Type.Desugar();
563+
var type = a.Type.Type;
564564
PrimitiveType pointee;
565565
if (type.IsPointerToPrimitiveType(out pointee) && !type.IsConstCharString())
566566
{
567567
return $@"CppSharp.Runtime.Pointer<{(pointee == PrimitiveType.Void ? IntPtrType :
568568
VisitPrimitiveType(pointee, new TypeQualifiers()).Type)}>";
569569
}
570-
return (type.IsPrimitiveType(PrimitiveType.Void)) ? "object" : type.Visit(this).Type;
570+
return type.IsPrimitiveType(PrimitiveType.Void) ? "object" : type.Visit(this).Type;
571571
}
572572

573573
public override TypePrinterResult VisitParameterDecl(Parameter parameter)

src/Generator/Passes/CheckDuplicatedNamesPass.cs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,16 @@ public bool Equals(TemplateArgument x, TemplateArgument y)
159159
if (x.Kind != TemplateArgument.ArgumentKind.Type ||
160160
y.Kind != TemplateArgument.ArgumentKind.Type)
161161
return x.Equals(y);
162-
return x.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps,
163-
ParameterTypeComparer.GeneratorKind).Equals(
164-
y.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps,
165-
ParameterTypeComparer.GeneratorKind));
162+
Type left = x.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps,
163+
ParameterTypeComparer.GeneratorKind);
164+
Type right = y.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps,
165+
ParameterTypeComparer.GeneratorKind);
166+
// consider Type and const Type the same
167+
if (left.IsReference() && !left.IsPointerToPrimitiveType())
168+
left = left.GetPointee();
169+
if (right.IsReference() && !right.IsPointerToPrimitiveType())
170+
right = right.GetPointee();
171+
return left.Equals(right);
166172
}
167173

168174
public int GetHashCode(TemplateArgument obj)

src/Generator/Passes/DelegatesPass.cs

Lines changed: 50 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,38 @@ public override bool VisitClassDecl(Class @class)
4646
return base.VisitClassDecl(@class);
4747
}
4848

49+
public override bool VisitClassTemplateSpecializationDecl(ClassTemplateSpecialization specialization)
50+
{
51+
if (!base.VisitClassTemplateSpecializationDecl(specialization) ||
52+
!specialization.IsGenerated || !specialization.TemplatedDecl.TemplatedDecl.IsGenerated)
53+
return false;
54+
55+
foreach (TemplateArgument arg in specialization.Arguments.Where(
56+
a => a.Kind == TemplateArgument.ArgumentKind.Type))
57+
{
58+
arg.Type = CheckForDelegate(arg.Type, specialization);
59+
}
60+
61+
return true;
62+
}
63+
4964
public override bool VisitMethodDecl(Method method)
5065
{
5166
if (!base.VisitMethodDecl(method) || !method.IsVirtual || method.Ignore)
5267
return false;
5368

54-
method.FunctionType = CheckForDelegate(method.FunctionType, method);
69+
var functionType = new FunctionType
70+
{
71+
CallingConvention = method.CallingConvention,
72+
IsDependent = method.IsDependent,
73+
ReturnType = method.ReturnType
74+
};
75+
76+
functionType.Parameters.AddRange(
77+
method.GatherInternalParams(Context.ParserOptions.IsItaniumLikeAbi, true));
78+
79+
method.FunctionType = CheckForDelegate(new QualifiedType(functionType),
80+
method.Namespace, @private: true);
5581

5682
return true;
5783
}
@@ -61,7 +87,8 @@ public override bool VisitFunctionDecl(Function function)
6187
if (!base.VisitFunctionDecl(function) || function.Ignore)
6288
return false;
6389

64-
function.ReturnType = CheckForDelegate(function.ReturnType, function);
90+
function.ReturnType = CheckForDelegate(function.ReturnType,
91+
function.Namespace);
6592
return true;
6693
}
6794

@@ -71,7 +98,8 @@ public override bool VisitParameterDecl(Parameter parameter)
7198
parameter.Namespace.Ignore)
7299
return false;
73100

74-
parameter.QualifiedType = CheckForDelegate(parameter.QualifiedType, parameter);
101+
parameter.QualifiedType = CheckForDelegate(parameter.QualifiedType,
102+
parameter.Namespace);
75103

76104
return true;
77105
}
@@ -81,7 +109,8 @@ public override bool VisitProperty(Property property)
81109
if (!base.VisitProperty(property))
82110
return false;
83111

84-
property.QualifiedType = CheckForDelegate(property.QualifiedType, property);
112+
property.QualifiedType = CheckForDelegate(property.QualifiedType,
113+
property.Namespace);
85114

86115
return true;
87116
}
@@ -91,12 +120,14 @@ public override bool VisitFieldDecl(Field field)
91120
if (!base.VisitFieldDecl(field))
92121
return false;
93122

94-
field.QualifiedType = CheckForDelegate(field.QualifiedType, field);
123+
field.QualifiedType = CheckForDelegate(field.QualifiedType,
124+
field.Namespace);
95125

96126
return true;
97127
}
98128

99-
private QualifiedType CheckForDelegate(QualifiedType type, ITypedDecl decl)
129+
private QualifiedType CheckForDelegate(QualifiedType type,
130+
DeclarationContext declarationContext, bool @private = false)
100131
{
101132
if (type.Type is TypedefType)
102133
return type;
@@ -109,22 +140,21 @@ private QualifiedType CheckForDelegate(QualifiedType type, ITypedDecl decl)
109140
if (pointee is TypedefType)
110141
return type;
111142

112-
var functionType = pointee.Desugar() as FunctionType;
113-
if (functionType == null)
143+
desugared = pointee.Desugar();
144+
FunctionType functionType = desugared as FunctionType;
145+
if (functionType == null && !desugared.IsPointerTo(out functionType))
114146
return type;
115147

116-
TypedefDecl @delegate = GetDelegate(type, decl);
148+
TypedefDecl @delegate = GetDelegate(functionType, declarationContext, @private);
117149
return new QualifiedType(new TypedefType { Declaration = @delegate });
118150
}
119151

120-
private TypedefDecl GetDelegate(QualifiedType type, ITypedDecl typedDecl)
152+
private TypedefDecl GetDelegate(FunctionType functionType,
153+
DeclarationContext declarationContext, bool @private = false)
121154
{
122-
FunctionType newFunctionType = GetNewFunctionType(typedDecl, type);
123-
124-
var delegateName = GetDelegateName(newFunctionType);
125-
var access = typedDecl is Method ? AccessSpecifier.Private : AccessSpecifier.Public;
126-
var decl = (Declaration) typedDecl;
127-
Module module = decl.TranslationUnit.Module;
155+
var delegateName = GetDelegateName(functionType);
156+
var access = @private ? AccessSpecifier.Private : AccessSpecifier.Public;
157+
Module module = declarationContext.TranslationUnit.Module;
128158
var existingDelegate = delegates.Find(t => Match(t, delegateName, module));
129159
if (existingDelegate != null)
130160
{
@@ -135,18 +165,18 @@ private TypedefDecl GetDelegate(QualifiedType type, ITypedDecl typedDecl)
135165

136166
// Check if there is an existing delegate with a different calling convention
137167
if (((FunctionType) existingDelegate.Type.GetPointee()).CallingConvention ==
138-
newFunctionType.CallingConvention)
168+
functionType.CallingConvention)
139169
return existingDelegate;
140170

141171
// Add a new delegate with the calling convention appended to its name
142-
delegateName += '_' + newFunctionType.CallingConvention.ToString();
172+
delegateName += '_' + functionType.CallingConvention.ToString();
143173
existingDelegate = delegates.Find(t => Match(t, delegateName, module));
144174
if (existingDelegate != null)
145175
return existingDelegate;
146176
}
147177

148-
var namespaceDelegates = GetDeclContextForDelegates(decl.Namespace);
149-
var delegateType = new QualifiedType(new PointerType(new QualifiedType(newFunctionType)));
178+
var namespaceDelegates = GetDeclContextForDelegates(declarationContext);
179+
var delegateType = new QualifiedType(new PointerType(new QualifiedType(functionType)));
150180
existingDelegate = new TypedefDecl
151181
{
152182
Access = access,
@@ -160,30 +190,6 @@ private TypedefDecl GetDelegate(QualifiedType type, ITypedDecl typedDecl)
160190
return existingDelegate;
161191
}
162192

163-
private FunctionType GetNewFunctionType(ITypedDecl decl, QualifiedType type)
164-
{
165-
var functionType = new FunctionType();
166-
var method = decl as Method;
167-
if (method != null && method.FunctionType == type)
168-
{
169-
functionType.Parameters.AddRange(
170-
method.GatherInternalParams(Context.ParserOptions.IsItaniumLikeAbi, true));
171-
functionType.CallingConvention = method.CallingConvention;
172-
functionType.IsDependent = method.IsDependent;
173-
functionType.ReturnType = method.ReturnType;
174-
}
175-
else
176-
{
177-
var funcTypeParam = (FunctionType) decl.Type.Desugar().GetFinalPointee().Desugar();
178-
functionType = new FunctionType(funcTypeParam);
179-
}
180-
181-
for (int i = 0; i < functionType.Parameters.Count; i++)
182-
functionType.Parameters[i].Name = $"_{i}";
183-
184-
return functionType;
185-
}
186-
187193
private static bool Match(TypedefDecl t, string delegateName, Module module)
188194
{
189195
return t.Name == delegateName &&

src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,11 @@ public override bool VisitClassDecl(Class @class)
7272
foreach (var method in methodsWithDependentPointers.Where(
7373
m => m.SynthKind == FunctionSynthKind.None))
7474
{
75-
var specializedMethod = specialization.Methods.First(
75+
var specializedMethod = specialization.Methods.FirstOrDefault(
7676
m => m.InstantiatedFrom == method);
77+
if (specializedMethod == null)
78+
continue;
79+
7780
Method extensionMethod = GetExtensionMethodForDependentPointer(specializedMethod);
7881
classExtensions.Methods.Add(extensionMethod);
7982
extensionMethod.Namespace = classExtensions;

src/Generator/Passes/TrimSpecializationsPass.cs

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -127,22 +127,10 @@ private void CleanSpecializations(Class template)
127127
s => !s.IsExplicitlyGenerated && internalSpecializations.Contains(s)))
128128
specialization.GenerationKind = GenerationKind.Internal;
129129

130-
Func<TemplateArgument, bool> allPointers =
131-
a => a.Type.Type != null && a.Type.Type.IsAddress();
132-
var groups = (from specialization in template.Specializations
133-
group specialization by specialization.Arguments.All(allPointers)
134-
into @group
135-
select @group).ToList();
136-
137-
foreach (var group in groups.Where(g => g.Key))
138-
foreach (var specialization in group.Skip(1))
139-
template.Specializations.Remove(specialization);
140-
141130
for (int i = template.Specializations.Count - 1; i >= 0; i--)
142131
{
143132
var specialization = template.Specializations[i];
144-
if (specialization is ClassTemplatePartialSpecialization &&
145-
!specialization.Arguments.All(allPointers))
133+
if (specialization is ClassTemplatePartialSpecialization)
146134
template.Specializations.RemoveAt(i);
147135
}
148136

0 commit comments

Comments
 (0)