Skip to content

Commit 29adf57

Browse files
IIFEBuild Agent
andauthored
Correctly marshal constant arrays in C++/CLI (#1346)
Co-authored-by: Build Agent <[email protected]>
1 parent a169605 commit 29adf57

File tree

4 files changed

+226
-47
lines changed

4 files changed

+226
-47
lines changed

src/Generator/Generators/CLI/CLIMarshal.cs

Lines changed: 53 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ public void WriteClassInstance(Class @class, string instance, bool ownNativeInst
305305
Context.Return.Write("({0} == nullptr) ? nullptr : gcnew ",
306306
instance);
307307

308-
Context.Return.Write("{0}(", QualifiedIdentifier(@class));
308+
Context.Return.Write("::{0}(", QualifiedIdentifier(@class));
309309
Context.Return.Write("(::{0}*)", @class.QualifiedOriginalName);
310310
Context.Return.Write("{0}{1})", instance, ownNativeInstance ? ", true" : "");
311311
}
@@ -433,39 +433,42 @@ public override bool VisitArrayType(ArrayType array, TypeQualifiers quals)
433433
case ArrayType.ArraySize.Constant:
434434
if (string.IsNullOrEmpty(Context.ReturnVarName))
435435
{
436-
const string pinnedPtr = "__pinnedPtr";
437-
Context.Before.WriteLine("cli::pin_ptr<{0}> {1} = &{2}[0];",
438-
array.Type, pinnedPtr, Context.Parameter.Name);
439-
const string arrayPtr = "__arrayPtr";
440-
Context.Before.WriteLine("{0}* {1} = {2};", array.Type, arrayPtr, pinnedPtr);
441-
Context.Return.Write("({0} (&)[{1}]) {2}", array.Type, array.Size, arrayPtr);
436+
string arrayPtrRet = $"__{Context.ParameterIndex}ArrayPtr";
437+
Context.Before.WriteLine($"{array.Type} {arrayPtrRet}[{array.Size}];");
438+
439+
Context.ReturnVarName = arrayPtrRet;
440+
441+
Context.Return.Write(arrayPtrRet);
442442
}
443-
else
443+
444+
bool isPointerToPrimitive = array.Type.IsPointerToPrimitiveType(PrimitiveType.Void);
445+
bool isPrimitive = array.Type.IsPrimitiveType();
446+
var supportBefore = Context.Before;
447+
supportBefore.WriteLine("if ({0} != nullptr)", Context.Parameter.Name);
448+
supportBefore.WriteOpenBraceAndIndent();
449+
450+
supportBefore.WriteLine($"if ({Context.Parameter.Name}->Length != {array.Size})");
451+
supportBefore.WriteOpenBraceAndIndent();
452+
supportBefore.WriteLine($"throw gcnew System::InvalidOperationException(\"Source array size must equal destination array size.\");");
453+
supportBefore.UnindentAndWriteCloseBrace();
454+
455+
string nativeVal = string.Empty;
456+
if (isPointerToPrimitive)
457+
{
458+
nativeVal = ".ToPointer()";
459+
}
460+
else if (!isPrimitive)
444461
{
445-
bool isPointerToPrimitive = array.Type.IsPointerToPrimitiveType(PrimitiveType.Void);
446-
bool isPrimitive = array.Type.IsPrimitiveType();
447-
var supportBefore = Context.Before;
448-
supportBefore.WriteLine("if ({0} != nullptr)", Context.ArgName);
449-
supportBefore.WriteOpenBraceAndIndent();
450-
451-
string nativeVal = string.Empty;
452-
if (isPointerToPrimitive)
453-
{
454-
nativeVal = ".ToPointer()";
455-
}
456-
else if (!isPrimitive)
457-
{
458-
nativeVal = "->NativePtr";
459-
}
460-
461-
supportBefore.WriteLine("for (int i = 0; i < {0}; i++)", array.Size);
462-
supportBefore.WriteLineIndent("{0}[i] = {1}{2}[i]{3};",
463-
Context.ReturnVarName,
464-
isPointerToPrimitive || isPrimitive ? string.Empty : "*",
465-
Context.ArgName,
466-
nativeVal);
467-
supportBefore.UnindentAndWriteCloseBrace();
462+
nativeVal = "->NativePtr";
468463
}
464+
465+
supportBefore.WriteLine("for (int i = 0; i < {0}; i++)", array.Size);
466+
supportBefore.WriteLineIndent("{0}[i] = {1}{2}[i]{3};",
467+
Context.ReturnVarName,
468+
isPointerToPrimitive || isPrimitive ? string.Empty : "*",
469+
Context.Parameter.Name,
470+
nativeVal);
471+
supportBefore.UnindentAndWriteCloseBrace();
469472
break;
470473
default:
471474
Context.Return.Write("null");
@@ -778,7 +781,8 @@ private void MarshalValueClassProperty(Property property, string marshalVar)
778781
{
779782
ArgName = fieldRef,
780783
ParameterIndex = Context.ParameterIndex++,
781-
MarshalVarPrefix = Context.MarshalVarPrefix
784+
MarshalVarPrefix = Context.MarshalVarPrefix,
785+
ReturnVarName = $"{marshalVar}.{property.Field.OriginalName}"
782786
};
783787

784788
var marshal = new CLIMarshalManagedToNativePrinter(marshalCtx);
@@ -789,23 +793,26 @@ private void MarshalValueClassProperty(Property property, string marshalVar)
789793
if (!string.IsNullOrWhiteSpace(marshal.Context.Before))
790794
Context.Before.Write(marshal.Context.Before);
791795

792-
Type type;
793-
Class @class;
794-
var isRef = property.Type.IsPointerTo(out type) &&
795-
!(type.TryGetClass(out @class) && @class.IsValueType) &&
796-
!type.IsPrimitiveType();
797-
798-
if (isRef)
796+
if (!string.IsNullOrWhiteSpace(marshal.Context.Return))
799797
{
800-
Context.Before.WriteLine("if ({0} != nullptr)", fieldRef);
801-
Context.Before.Indent();
802-
}
798+
Type type;
799+
Class @class;
800+
var isRef = property.Type.IsPointerTo(out type) &&
801+
!(type.TryGetClass(out @class) && @class.IsValueType) &&
802+
!type.IsPrimitiveType();
803803

804-
Context.Before.WriteLine("{0}.{1} = {2};", marshalVar,
805-
property.Field.OriginalName, marshal.Context.Return);
804+
if (isRef)
805+
{
806+
Context.Before.WriteLine("if ({0} != nullptr)", fieldRef);
807+
Context.Before.Indent();
808+
}
806809

807-
if (isRef)
808-
Context.Before.Unindent();
810+
Context.Before.WriteLine("{0}.{1} = {2};", marshalVar,
811+
property.Field.OriginalName, marshal.Context.Return);
812+
813+
if (isRef)
814+
Context.Before.Unindent();
815+
}
809816
}
810817

811818
public override bool VisitFieldDecl(Field field)

tests/CLI/CLI.Tests.cs

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using CppSharp.Utils;
22
using NUnit.Framework;
33
using CLI;
4+
using System.Text;
5+
using System;
46

57
public class CLITests : GeneratorTestFixture
68
{
@@ -66,4 +68,114 @@ public void TestVectorPointerGetter()
6668
Assert.AreEqual("VectorPointerGetter", list[0]);
6769
}
6870
}
71+
72+
[Test]
73+
public void TestMultipleConstantArraysParamsTestMethod()
74+
{
75+
byte[] bytes = Encoding.ASCII.GetBytes("TestMulti");
76+
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
77+
78+
byte[] bytes2 = Encoding.ASCII.GetBytes("TestMulti2");
79+
sbyte[] sbytes2 = Array.ConvertAll(bytes2, q => Convert.ToSByte(q));
80+
81+
string s = CLI.CLI.MultipleConstantArraysParamsTestMethod(sbytes, sbytes2);
82+
Assert.AreEqual("TestMultiTestMulti2", s);
83+
}
84+
85+
[Test]
86+
public void TestMultipleConstantArraysParamsTestMethodLongerSourceArray()
87+
{
88+
byte[] bytes = Encoding.ASCII.GetBytes("TestMultipleConstantArraysParamsTestMethodLongerSourceArray");
89+
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
90+
91+
Assert.Throws<InvalidOperationException>(() => CLI.CLI.MultipleConstantArraysParamsTestMethod(sbytes, new sbyte[] { }));
92+
}
93+
94+
[Test]
95+
public void TestStructWithNestedUnionTestMethod()
96+
{
97+
using (var val = new StructWithNestedUnion())
98+
{
99+
byte[] bytes = Encoding.ASCII.GetBytes("TestUnions");
100+
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
101+
102+
UnionNestedInsideStruct unionNestedInsideStruct;
103+
unionNestedInsideStruct.SzText = sbytes;
104+
105+
Assert.AreEqual(sbytes.Length, unionNestedInsideStruct.SzText.Length);
106+
Assert.AreEqual("TestUnions", unionNestedInsideStruct.SzText);
107+
108+
val.NestedUnion = unionNestedInsideStruct;
109+
110+
Assert.AreEqual(10, val.NestedUnion.SzText.Length);
111+
Assert.AreEqual("TestUnions", val.NestedUnion.SzText);
112+
113+
string ret = CLI.CLI.StructWithNestedUnionTestMethod(val);
114+
115+
Assert.AreEqual("TestUnions", ret);
116+
}
117+
}
118+
119+
[Test]
120+
public void TestStructWithNestedUnionLongerSourceArray()
121+
{
122+
using (var val = new StructWithNestedUnion())
123+
{
124+
byte[] bytes = Encoding.ASCII.GetBytes("TestStructWithNestedUnionLongerSourceArray");
125+
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
126+
127+
UnionNestedInsideStruct unionNestedInsideStruct;
128+
unionNestedInsideStruct.SzText = sbytes;
129+
130+
Assert.Throws<InvalidOperationException>(() => val.NestedUnion = unionNestedInsideStruct);
131+
}
132+
}
133+
134+
[Test]
135+
public void TestUnionWithNestedStructTestMethod()
136+
{
137+
using (var val = new StructNestedInsideUnion())
138+
{
139+
byte[] bytes = Encoding.ASCII.GetBytes("TestUnions");
140+
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
141+
val.SzText = sbytes;
142+
143+
UnionWithNestedStruct unionWithNestedStruct;
144+
unionWithNestedStruct.NestedStruct = val;
145+
146+
Assert.AreEqual(10, unionWithNestedStruct.NestedStruct.SzText.Length);
147+
Assert.AreEqual("TestUnions", unionWithNestedStruct.NestedStruct.SzText);
148+
149+
string ret = CLI.CLI.UnionWithNestedStructTestMethod(unionWithNestedStruct);
150+
151+
Assert.AreEqual("TestUnions", ret);
152+
}
153+
}
154+
155+
[Test]
156+
public void TestUnionWithNestedStructArrayTestMethod()
157+
{
158+
using (var val = new StructNestedInsideUnion())
159+
{
160+
using (var val2 = new StructNestedInsideUnion())
161+
{
162+
byte[] bytes = Encoding.ASCII.GetBytes("TestUnion1");
163+
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
164+
val.SzText = sbytes;
165+
166+
byte[] bytes2 = Encoding.ASCII.GetBytes("TestUnion2");
167+
sbyte[] sbytes2 = Array.ConvertAll(bytes2, q => Convert.ToSByte(q));
168+
val2.SzText = sbytes2;
169+
170+
UnionWithNestedStructArray unionWithNestedStructArray;
171+
unionWithNestedStructArray.NestedStructs = new StructNestedInsideUnion[] { val, val2 };
172+
173+
Assert.AreEqual(2, unionWithNestedStructArray.NestedStructs.Length);
174+
175+
string ret = CLI.CLI.UnionWithNestedStructArrayTestMethod(unionWithNestedStructArray);
176+
177+
Assert.AreEqual("TestUnion1TestUnion2", ret);
178+
}
179+
}
180+
}
69181
}

tests/CLI/CLI.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,25 @@ VectorPointerGetter::~VectorPointerGetter()
6363
std::vector<std::string>* VectorPointerGetter::GetVecPtr()
6464
{
6565
return vecPtr;
66+
}
67+
68+
std::string DLL_API MultipleConstantArraysParamsTestMethod(char arr1[9], char arr2[10])
69+
{
70+
return std::string(arr1, arr1 + 9) + std::string(arr2, arr2 + 10);
71+
}
72+
73+
std::string DLL_API StructWithNestedUnionTestMethod(StructWithNestedUnion val)
74+
{
75+
return std::string(val.nestedUnion.szText, val.nestedUnion.szText + 10);
76+
}
77+
78+
std::string DLL_API UnionWithNestedStructTestMethod(UnionWithNestedStruct val)
79+
{
80+
return std::string(val.nestedStruct.szText, val.nestedStruct.szText + 10);
81+
}
82+
83+
std::string DLL_API UnionWithNestedStructArrayTestMethod(UnionWithNestedStructArray arr)
84+
{
85+
return std::string(arr.nestedStructs[0].szText, arr.nestedStructs[0].szText + 10)
86+
+ std::string(arr.nestedStructs[1].szText, arr.nestedStructs[1].szText + 10);
6687
}

tests/CLI/CLI.h

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,43 @@ class DLL_API VectorPointerGetter
8989

9090
private:
9191
std::vector<std::string>* vecPtr;
92-
};
92+
};
93+
94+
// Previously passing multiple constant arrays was generating the same variable name for each array inside the method body.
95+
// This is fixed by using the same generation code in CLIMarshal.VisitArrayType for both when there is a return var name specified and
96+
// for when no return var name is specified.
97+
std::string DLL_API MultipleConstantArraysParamsTestMethod(char arr1[9], char arr2[10]);
98+
99+
// Ensures marshalling arrays is handled correctly for value types used within reference types.
100+
union DLL_API UnionNestedInsideStruct
101+
{
102+
char szText[10];
103+
};
104+
105+
struct DLL_API StructWithNestedUnion
106+
{
107+
UnionNestedInsideStruct nestedUnion;
108+
};
109+
110+
std::string DLL_API StructWithNestedUnionTestMethod(StructWithNestedUnion val);
111+
112+
// Ensures marshalling arrays is handled correctly for reference types used within value types.
113+
struct DLL_API StructNestedInsideUnion
114+
{
115+
char szText[10];
116+
};
117+
118+
union DLL_API UnionWithNestedStruct
119+
{
120+
StructNestedInsideUnion nestedStruct;
121+
};
122+
123+
std::string DLL_API UnionWithNestedStructTestMethod(UnionWithNestedStruct val);
124+
125+
// Ensures marshalling arrays is handled corectly for arrays of reference types used within value types.
126+
union DLL_API UnionWithNestedStructArray
127+
{
128+
StructNestedInsideUnion nestedStructs[2];
129+
};
130+
131+
std::string DLL_API UnionWithNestedStructArrayTestMethod(UnionWithNestedStructArray val);

0 commit comments

Comments
 (0)