Skip to content

Commit 28000a1

Browse files
Saalvagetritao
authored andcommitted
Fix #1251 three parameter equality operator
- Operators in generic classes do not attempt to generate as extension methods anymore - Empty `...Extensions` classes are no longer generated - `string` as a template argument is correctly cast - `MarshalCharAsManagedChar` option also generates correct casts - Suppress warning regarding returning struct field by ref - Eliminate some tabs that snuck into the test C++ header
1 parent d7faf5f commit 28000a1

File tree

7 files changed

+120
-26
lines changed

7 files changed

+120
-26
lines changed

src/AST/TypeExtensions.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,22 @@ public static bool IsDependentPointer(this Type type)
439439
return false;
440440
}
441441

442+
public static bool IsTemplate(this Type type)
443+
{
444+
if (type is TemplateParameterType or TemplateParameterSubstitutionType)
445+
return true;
446+
447+
var ptr = type;
448+
while (ptr is PointerType pType)
449+
{
450+
ptr = pType.Pointee;
451+
if (ptr is TemplateParameterType or TemplateParameterSubstitutionType)
452+
return true;
453+
}
454+
455+
return false;
456+
}
457+
442458
public static Module GetModule(this Type type)
443459
{
444460
Declaration declaration;

src/Generator/Generators/CSharp/CSharpMarshal.cs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ public override bool VisitPointerType(PointerType pointer, TypeQualifiers quals)
601601
if (Context.Context.Options.MarshalCharAsManagedChar &&
602602
primitive == PrimitiveType.Char)
603603
{
604-
Context.Return.Write($"({typePrinter.PrintNative(pointer)})");
604+
Context.Return.StringBuilder.Insert(0, $"({typePrinter.PrintNative(pointer)}) ");
605605
if (isConst)
606606
Context.Return.Write("&");
607607
Context.Return.Write(param.Name);
@@ -643,8 +643,13 @@ public override bool VisitPointerType(PointerType pointer, TypeQualifiers quals)
643643
}
644644
else
645645
{
646-
Context.Before.WriteLine("var {0} = {1}.{2};",
647-
arg, Context.Parameter.Name, Helpers.InstanceIdentifier);
646+
Context.Before.Write($"var {arg} = ");
647+
if (pointer.Pointee.IsTemplate())
648+
Context.Before.Write($"(({Context.Parameter.Type}) (object) {Context.Parameter.Name})");
649+
else
650+
Context.Before.WriteLine(Context.Parameter.Name);
651+
Context.Before.WriteLine($".{Helpers.InstanceIdentifier};");
652+
648653
Context.Return.Write($"new {typePrinter.IntPtrType}(&{arg})");
649654
}
650655

@@ -805,7 +810,12 @@ private void MarshalRefClass(Class @class)
805810

806811
private void MarshalValueClass()
807812
{
808-
Context.Return.Write("{0}.{1}", Context.Parameter.Name, Helpers.InstanceIdentifier);
813+
if (Context.Parameter.Type.IsTemplate())
814+
Context.Return.Write($"(({Context.Parameter.Type}) (object) {Context.Parameter.Name})");
815+
else
816+
Context.Return.Write(Context.Parameter.Name);
817+
818+
Context.Return.Write($".{Helpers.InstanceIdentifier}");
809819
}
810820

811821
public override bool VisitFieldDecl(Field field)

src/Generator/Generators/CSharp/CSharpSources.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ public override void Process()
8585
GenerateUsings();
8686

8787
WriteLine("#pragma warning disable CS0109 // Member does not hide an inherited member; new keyword is not required");
88+
WriteLine("#pragma warning disable CS9084 // Struct member returns 'this' or other instance members by reference");
8889
NewLine();
8990

9091
if (!string.IsNullOrEmpty(Module.OutputNamespace))

src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,19 @@ public override bool VisitClassDecl(Class @class)
5252
if (!methodsWithDependentPointers.Any())
5353
return false;
5454

55+
var hasMethods = false;
5556
var classExtensions = new Class { Name = $"{@class.Name}Extensions", IsStatic = true };
5657
foreach (var specialization in @class.Specializations.Where(s => s.IsGenerated))
5758
foreach (var method in methodsWithDependentPointers.Where(
5859
m => m.SynthKind == FunctionSynthKind.None))
5960
{
6061
var specializedMethod = specialization.Methods.FirstOrDefault(
6162
m => m.InstantiatedFrom == method);
62-
if (specializedMethod == null)
63+
if (specializedMethod == null || specializedMethod.IsOperator)
6364
continue;
6465

66+
hasMethods = true;
67+
6568
Method extensionMethod = GetExtensionMethodForDependentPointer(specializedMethod);
6669
classExtensions.Methods.Add(extensionMethod);
6770
extensionMethod.Namespace = classExtensions;
@@ -75,6 +78,10 @@ public override bool VisitClassDecl(Class @class)
7578
extensionMethod.GenerationKind = GenerationKind.Generate;
7679
}
7780
}
81+
82+
if (!hasMethods)
83+
return false;
84+
7885
classExtensions.Namespace = @class.Namespace;
7986
classExtensions.OriginalClass = @class;
8087
extensions.Add(classExtensions);

src/Generator/Types/Std/Stdlib.CSharp.cs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,16 +329,24 @@ public override void CSharpMarshalToNative(CSharpMarshalContext ctx)
329329
ctx.Return.Write($@"{qualifiedBasicString}Extensions.{
330330
Helpers.InternalStruct}.{assign.Name}(new {
331331
typePrinter.IntPtrType}(&{
332-
ctx.ReturnVarName}), {ctx.Parameter.Name})");
332+
ctx.ReturnVarName}), ");
333+
if (ctx.Parameter.Type.IsTemplate())
334+
ctx.Return.Write("(string) (object) ");
335+
ctx.Return.Write($"{ctx.Parameter.Name})");
333336
ctx.ReturnVarName = string.Empty;
334337
}
335338
else
336339
{
337340
var varBasicString = $"__basicString{ctx.ParameterIndex}";
338341
ctx.Before.WriteLine($@"var {varBasicString} = new {
339342
basicString.Visit(typePrinter)}();");
340-
ctx.Before.WriteLine($@"{qualifiedBasicString}Extensions.{
341-
assign.Name}({varBasicString}, {ctx.Parameter.Name});");
343+
344+
ctx.Before.Write($@"{qualifiedBasicString}Extensions.{
345+
assign.Name}({varBasicString}, ");
346+
if (ctx.Parameter.Type.IsTemplate())
347+
ctx.Before.Write("(string) (object) ");
348+
ctx.Before.WriteLine($"{ctx.Parameter.Name});");
349+
342350
ctx.Return.Write($"{varBasicString}.{Helpers.InstanceIdentifier}");
343351
ctx.Cleanup.WriteLine($@"{varBasicString}.Dispose({
344352
(!Type.IsAddress() || ctx.Parameter?.IsIndirect == true ? "disposing: true, callNativeDtor:false" : string.Empty)});");

tests/dotnet/CSharp/CSharp.Tests.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,4 +2003,25 @@ public void TestValueTypeOutParameter()
20032003
Assert.AreEqual(2, unionTestA.A);
20042004
Assert.AreEqual(2, unionTestB.B);
20052005
}
2006+
2007+
[TestCase("hi")]
2008+
[TestCase(2u)]
2009+
public void TestOptional<T>(T value)
2010+
{
2011+
Assert.That(new CSharp.Optional<T>() != new CSharp.Optional<T>(value));
2012+
Assert.That(new CSharp.Optional<T>() != value);
2013+
Assert.That(new CSharp.Optional<T>() == new CSharp.Optional<T>());
2014+
Assert.That(new CSharp.Optional<T>(value) == new CSharp.Optional<T>(value));
2015+
Assert.That(new CSharp.Optional<T>(value) == value);
2016+
}
2017+
2018+
[Test]
2019+
public void TestOptionalIntPtr()
2020+
{
2021+
Assert.That(new CSharp.Optional<IntPtr>() != new CSharp.Optional<IntPtr>(IntPtr.MaxValue));
2022+
Assert.That(new CSharp.Optional<IntPtr>() != IntPtr.MaxValue);
2023+
Assert.That(new CSharp.Optional<IntPtr>() == new CSharp.Optional<IntPtr>());
2024+
Assert.That(new CSharp.Optional<IntPtr>(IntPtr.MaxValue) == new CSharp.Optional<IntPtr>(IntPtr.MaxValue));
2025+
Assert.That(new CSharp.Optional<IntPtr>(IntPtr.MaxValue) == IntPtr.MaxValue);
2026+
}
20062027
}

tests/dotnet/CSharp/CSharp.h

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -733,22 +733,22 @@ class DLL_API TestParamToInterfacePassBaseOne
733733

734734
class DLL_API TestParamToInterfacePassBaseTwo
735735
{
736-
int m;
736+
int m;
737737
public:
738-
int getM();
739-
void setM(int n);
740-
const TestParamToInterfacePassBaseTwo& operator++();
741-
TestParamToInterfacePassBaseTwo();
742-
TestParamToInterfacePassBaseTwo(int n);
738+
int getM();
739+
void setM(int n);
740+
const TestParamToInterfacePassBaseTwo& operator++();
741+
TestParamToInterfacePassBaseTwo();
742+
TestParamToInterfacePassBaseTwo(int n);
743743
};
744744

745745
class DLL_API TestParamToInterfacePass : public TestParamToInterfacePassBaseOne, public TestParamToInterfacePassBaseTwo
746746
{
747747
public:
748-
TestParamToInterfacePassBaseTwo addM(TestParamToInterfacePassBaseTwo b);
749-
TestParamToInterfacePassBaseTwo operator+(TestParamToInterfacePassBaseTwo b);
750-
TestParamToInterfacePass(TestParamToInterfacePassBaseTwo b);
751-
TestParamToInterfacePass();
748+
TestParamToInterfacePassBaseTwo addM(TestParamToInterfacePassBaseTwo b);
749+
TestParamToInterfacePassBaseTwo operator+(TestParamToInterfacePassBaseTwo b);
750+
TestParamToInterfacePass(TestParamToInterfacePassBaseTwo b);
751+
TestParamToInterfacePass();
752752
};
753753

754754
class DLL_API HasProtectedVirtual
@@ -973,18 +973,18 @@ class DLL_API ClassWithVirtualBase : public virtual Foo
973973

974974
namespace NamespaceA
975975
{
976-
CS_VALUE_TYPE class DLL_API A
977-
{
978-
};
976+
CS_VALUE_TYPE class DLL_API A
977+
{
978+
};
979979
}
980980

981981
namespace NamespaceB
982982
{
983-
class DLL_API B
984-
{
985-
public:
986-
void Function(CS_OUT NamespaceA::A &a);
987-
};
983+
class DLL_API B
984+
{
985+
public:
986+
void Function(CS_OUT NamespaceA::A &a);
987+
};
988988
}
989989

990990
class DLL_API HasPrivateVirtualProperty
@@ -1607,6 +1607,37 @@ DLL_API extern PointerTester* PointerToClass;
16071607
union DLL_API UnionTester {
16081608
float a;
16091609
int b;
1610+
inline bool operator ==(const UnionTester& other) const {
1611+
return b == other.b;
1612+
}
16101613
};
16111614

16121615
int DLL_API ValueTypeOutParameter(CS_OUT UnionTester* testerA, CS_OUT UnionTester* testerB);
1616+
1617+
template <class T>
1618+
class Optional {
1619+
public:
1620+
T m_value;
1621+
bool m_hasValue;
1622+
1623+
Optional() {
1624+
m_hasValue = false;
1625+
}
1626+
1627+
Optional(T value) {
1628+
m_value = std::move(value);
1629+
m_hasValue = true;
1630+
}
1631+
1632+
inline bool operator ==(const Optional<T>& rhs) const {
1633+
return (m_hasValue == rhs.m_hasValue && (!m_hasValue || m_value == rhs.m_value));
1634+
}
1635+
1636+
inline bool operator ==(const T& rhs) const {
1637+
return (m_hasValue && m_value == rhs);
1638+
}
1639+
};
1640+
1641+
// We just need a method that uses various instantiations of Optional.
1642+
inline void DLL_API InstantiateOptionalTemplate(Optional<unsigned int>, Optional<std::string>,
1643+
Optional<TestComparison>, Optional<char*>, Optional<UnionTester>) { }

0 commit comments

Comments
 (0)