Skip to content

Commit 636f559

Browse files
Use (Un)ManagedValuesSource.Length as length condition in marshalling and unmarshalling loops (#118190)
Instead of assuming the Span returned from Get(Un)ManagedValuesSource is the same length as the length parameter, use the length of the span returned from the marshaller methods in the for loop condition. This allows the marshaller to recognize native values that have a different length than expected. In particular, if an array parameter is null, allow the ArrayMarshaller to return empty spans. Also, updates the Array, Span, ReadOnlySpan, and PointerArray marshallers to check for null unmanaged values when creating unmanaged source/destination spans. Co-authored-by: Aaron Robinson <[email protected]>
1 parent 53aa2bc commit 636f559

File tree

9 files changed

+774
-13
lines changed

9 files changed

+774
-13
lines changed

src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/ArrayMarshaller.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,12 @@ public static ReadOnlySpan<T> GetManagedValuesSource(T[]? managed)
5858
/// <param name="numElements">The unmanaged element count.</param>
5959
/// <returns>The <see cref="Span{TUnmanagedElement}"/> of unmanaged elements.</returns>
6060
public static Span<TUnmanagedElement> GetUnmanagedValuesDestination(TUnmanagedElement* unmanaged, int numElements)
61-
=> new Span<TUnmanagedElement>(unmanaged, numElements);
61+
{
62+
if (unmanaged is null)
63+
return [];
64+
65+
return new Span<TUnmanagedElement>(unmanaged, numElements);
66+
}
6267

6368
/// <summary>
6469
/// Allocates memory for the managed representation of the array.
@@ -89,7 +94,12 @@ public static Span<T> GetManagedValuesDestination(T[]? managed)
8994
/// <param name="numElements">The unmanaged element count.</param>
9095
/// <returns>The <see cref="ReadOnlySpan{TUnmanagedElement}"/> containing the unmanaged elements to marshal.</returns>
9196
public static ReadOnlySpan<TUnmanagedElement> GetUnmanagedValuesSource(TUnmanagedElement* unmanagedValue, int numElements)
92-
=> new ReadOnlySpan<TUnmanagedElement>(unmanagedValue, numElements);
97+
{
98+
if (unmanagedValue is null)
99+
return [];
100+
101+
return new ReadOnlySpan<TUnmanagedElement>(unmanagedValue, numElements);
102+
}
93103

94104
/// <summary>
95105
/// Frees memory for the unmanaged array.

src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/PointerArrayMarshaller.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,12 @@ public static ReadOnlySpan<IntPtr> GetManagedValuesSource(T*[]? managed)
5959
/// <param name="numElements">The unmanaged element count.</param>
6060
/// <returns>The <see cref="Span{TUnmanagedElement}"/> of unmanaged elements.</returns>
6161
public static Span<TUnmanagedElement> GetUnmanagedValuesDestination(TUnmanagedElement* unmanaged, int numElements)
62-
=> new Span<TUnmanagedElement>(unmanaged, numElements);
62+
{
63+
if (unmanaged is null)
64+
return [];
65+
66+
return new Span<TUnmanagedElement>(unmanaged, numElements);
67+
}
6368

6469
/// <summary>
6570
/// Allocates memory for the managed representation of the array.
@@ -90,7 +95,12 @@ public static Span<IntPtr> GetManagedValuesDestination(T*[]? managed)
9095
/// <param name="numElements">The unmanaged element count.</param>
9196
/// <returns>The <see cref="ReadOnlySpan{TUnmanagedElement}"/> containing the unmanaged elements to marshal.</returns>
9297
public static ReadOnlySpan<TUnmanagedElement> GetUnmanagedValuesSource(TUnmanagedElement* unmanagedValue, int numElements)
93-
=> new ReadOnlySpan<TUnmanagedElement>(unmanagedValue, numElements);
98+
{
99+
if (unmanagedValue is null)
100+
return [];
101+
102+
return new ReadOnlySpan<TUnmanagedElement>(unmanagedValue, numElements);
103+
}
94104

95105
/// <summary>
96106
/// Frees memory for the unmanaged array.

src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/ReadOnlySpanMarshaller.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,12 @@ public static ReadOnlySpan<T> GetManagedValuesSource(ReadOnlySpan<T> managed)
6969
/// <param name="numElements">The number of elements that will be copied into the memory block.</param>
7070
/// <returns>A span over the unmanaged memory that can contain the specified number of elements.</returns>
7171
public static Span<TUnmanagedElement> GetUnmanagedValuesDestination(TUnmanagedElement* unmanaged, int numElements)
72-
=> new Span<TUnmanagedElement>(unmanaged, numElements);
72+
{
73+
if (unmanaged == null)
74+
return [];
75+
76+
return new Span<TUnmanagedElement>(unmanaged, numElements);
77+
}
7378
}
7479

7580
/// <summary>
@@ -201,6 +206,9 @@ public ReadOnlySpan<T> ToManaged()
201206
/// <returns>A span over unmanaged values of the array.</returns>
202207
public ReadOnlySpan<TUnmanagedElement> GetUnmanagedValuesSource(int numElements)
203208
{
209+
if (_unmanagedArray is null)
210+
return [];
211+
204212
return new ReadOnlySpan<TUnmanagedElement>(_unmanagedArray, numElements);
205213
}
206214

src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/SpanMarshaller.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,12 @@ public static ReadOnlySpan<T> GetManagedValuesSource(Span<T> managed)
6363
/// <param name="numElements">The number of elements that will be copied into the memory block.</param>
6464
/// <returns>A span over the unmanaged memory that can contain the specified number of elements.</returns>
6565
public static Span<TUnmanagedElement> GetUnmanagedValuesDestination(TUnmanagedElement* unmanaged, int numElements)
66-
=> new Span<TUnmanagedElement>(unmanaged, numElements);
66+
{
67+
if (unmanaged == null)
68+
return [];
69+
70+
return new Span<TUnmanagedElement>(unmanaged, numElements);
71+
}
6772

6873
/// <summary>
6974
/// Allocates space to store the managed elements.
@@ -94,7 +99,12 @@ public static Span<T> GetManagedValuesDestination(Span<T> managed)
9499
/// <param name="numElements">The number of elements in the unmanaged collection.</param>
95100
/// <returns>A span over the native collection elements.</returns>
96101
public static ReadOnlySpan<TUnmanagedElement> GetUnmanagedValuesSource(TUnmanagedElement* unmanaged, int numElements)
97-
=> new ReadOnlySpan<TUnmanagedElement>(unmanaged, numElements);
102+
{
103+
if (unmanaged == null)
104+
return [];
105+
106+
return new ReadOnlySpan<TUnmanagedElement>(unmanaged, numElements);
107+
}
98108

99109
/// <summary>
100110
/// Frees the allocated unmanaged memory.

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ protected ElementsMarshalling(IElementsMarshallingCollectionSource collectionSou
3939
/// </code>
4040
/// </summary>
4141
public StatementSyntax GenerateClearUnmanagedDestination(StubIdentifierContext context)
42-
4342
{
4443
// <GetUnmanagedValuesDestination>.Clear();
4544
return MethodInvocationStatement(
@@ -284,7 +283,7 @@ public override StatementSyntax GenerateMarshalStatement(StubIdentifierContext c
284283
statements.Add(GenerateContentsMarshallingStatement(
285284
context,
286285
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
287-
IdentifierName(MarshallerHelpers.GetManagedSpanIdentifier(CollectionSource.TypeInfo, context)),
286+
IdentifierName(managedSpanIdentifier),
288287
IdentifierName("Length")),
289288
elementMarshaller,
290289
StubIdentifierContext.Stage.Marshal));
@@ -295,7 +294,6 @@ public override StatementSyntax GenerateUnmarshalStatement(StubIdentifierContext
295294
{
296295
string managedSpanIdentifier = MarshallerHelpers.GetManagedSpanIdentifier(CollectionSource.TypeInfo, context);
297296
string nativeSpanIdentifier = MarshallerHelpers.GetNativeSpanIdentifier(CollectionSource.TypeInfo, context);
298-
string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(CollectionSource.TypeInfo, context);
299297

300298
// ReadOnlySpan<TUnmanagedElement> <nativeSpan> = <GetUnmanagedValuesSource>
301299
// Span<T> <managedSpan> = <GetManagedValuesDestination>
@@ -311,7 +309,9 @@ public override StatementSyntax GenerateUnmarshalStatement(StubIdentifierContext
311309
CollectionSource.GetManagedValuesDestination(context)),
312310
GenerateContentsMarshallingStatement(
313311
context,
314-
IdentifierName(numElementsIdentifier),
312+
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
313+
IdentifierName(nativeSpanIdentifier),
314+
IdentifierName("Length")),
315315
elementMarshaller,
316316
StubIdentifierContext.Stage.UnmarshalCapture, StubIdentifierContext.Stage.Unmarshal));
317317
}
@@ -356,7 +356,9 @@ public override StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalSta
356356
unmanagedValuesDeclaration,
357357
GenerateContentsMarshallingStatement(
358358
context,
359-
IdentifierName(numElementsIdentifier),
359+
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
360+
IdentifierName(managedSpanIdentifier),
361+
IdentifierName("Length")),
360362
elementMarshaller,
361363
StubIdentifierContext.Stage.UnmarshalCapture, StubIdentifierContext.Stage.Unmarshal));
362364
}
@@ -460,7 +462,9 @@ public override StatementSyntax GenerateUnmanagedToManagedByValueOutMarshalState
460462
managedValuesDestination,
461463
GenerateContentsMarshallingStatement(
462464
context,
463-
IdentifierName(numElementsIdentifier),
465+
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
466+
IdentifierName(nativeSpanIdentifier),
467+
IdentifierName("Length")),
464468
new FreeAlwaysOwnedOriginalValueGenerator(elementMarshaller),
465469
stagesToGenerate));
466470
}
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Runtime.InteropServices;
7+
using System.Runtime.InteropServices.Marshalling;
8+
using SharedTypes.ComInterfaces;
9+
using SharedTypes;
10+
using Xunit;
11+
using System.Runtime.CompilerServices;
12+
13+
namespace ComInterfaceGenerator.Tests
14+
{
15+
/// <summary>
16+
/// Tests for edge cases involving null arrays when their length parameters are non-zero.
17+
/// This addresses https://github.com/dotnet/runtime/issues/118135
18+
/// </summary>
19+
public unsafe partial class INullArrayTests
20+
{
21+
private static INullArrayCases CreateTestInterface()
22+
{
23+
INullArrayCases originalObject = new INullArrayCasesImpl();
24+
ComWrappers cw = new StrategyBasedComWrappers();
25+
nint ptr = cw.GetOrCreateComInterfaceForObject(originalObject, CreateComInterfaceFlags.None);
26+
object obj = cw.GetOrCreateObjectForComInstance(ptr, CreateObjectFlags.None);
27+
return (INullArrayCases)obj;
28+
}
29+
30+
[Theory]
31+
[InlineData(0)]
32+
[InlineData(10)]
33+
[InlineData(int.MaxValue)]
34+
public void SingleNullArray_DoesNotCrash(int length)
35+
{
36+
var testInterface = CreateTestInterface();
37+
38+
// Should not throw or crash
39+
testInterface.SingleNullArrayWithLength(length, null);
40+
}
41+
42+
[Fact]
43+
public void SingleNullArray_WithValidArray_WorksNormally()
44+
{
45+
var testInterface = CreateTestInterface();
46+
int[] array = new int[5];
47+
48+
testInterface.SingleNullArrayWithLength(5, array);
49+
50+
Assert.Equal(new int[] { 0, 2, 4, 6, 8 }, array);
51+
}
52+
53+
[Fact]
54+
public void MultipleArrays_SomeNull_DoesNotCrash()
55+
{
56+
var testInterface = CreateTestInterface();
57+
int[] array1 = new int[3];
58+
int[] array3 = new int[3];
59+
60+
testInterface.MultipleArraysSharedLength(3, array1, null, array3);
61+
62+
Assert.Equal(new int[] { 0, 1, 2 }, array1);
63+
Assert.Equal(new int[] { 0, 100, 200 }, array3);
64+
}
65+
66+
[Fact]
67+
public void MultipleArrays_AllNull_DoesNotCrash()
68+
{
69+
var testInterface = CreateTestInterface();
70+
71+
testInterface.MultipleArraysSharedLength(5, null, null, null);
72+
}
73+
74+
[Theory]
75+
[InlineData(0)]
76+
[InlineData(10)]
77+
[InlineData(int.MaxValue)]
78+
public void NonBlittableArray_Null_DoesNotCrash(int length)
79+
{
80+
var testInterface = CreateTestInterface();
81+
82+
testInterface.NonBlittableNullArray(length, null);
83+
}
84+
85+
[Fact]
86+
public void NonBlittableArray_ValidArray_WorksNormally()
87+
{
88+
var testInterface = CreateTestInterface();
89+
var array = new IntStructWrapper[3];
90+
91+
testInterface.NonBlittableNullArray(3, array);
92+
93+
Assert.Equal(0, array[0].Value);
94+
Assert.Equal(3, array[1].Value);
95+
Assert.Equal(6, array[2].Value);
96+
}
97+
98+
[Theory]
99+
[InlineData(0)]
100+
[InlineData(10)]
101+
[InlineData(int.MaxValue)]
102+
public void SpanNull_DoesNotCrash(int length)
103+
{
104+
var testInterface = CreateTestInterface();
105+
106+
Span<int> span = new Span<int>(null, 0);
107+
testInterface.SpanNullCase(length, ref span);
108+
Assert.True(default == span);
109+
}
110+
111+
[Fact]
112+
public void SpanValid_WorksNormally()
113+
{
114+
var testInterface = CreateTestInterface();
115+
116+
var span = new Span<int>(new int[5]);
117+
testInterface.SpanNullCase(5, ref span);
118+
119+
Assert.Equal(0, span[0]);
120+
Assert.Equal(5, span[1]);
121+
Assert.Equal(10, span[2]);
122+
Assert.Equal(15, span[3]);
123+
Assert.Equal(20, span[4]);
124+
}
125+
126+
[Theory]
127+
[InlineData(0)]
128+
[InlineData(10)]
129+
[InlineData(int.MaxValue)]
130+
public void SpanNonBlittable_Null_DoesNotCrash(int length)
131+
{
132+
var testInterface = CreateTestInterface();
133+
Span<IntStructWrapper> span = new Span<IntStructWrapper>(null, 0);
134+
testInterface.SpanNonBlittableNullCase(length, ref span);
135+
Assert.True(default == span);
136+
}
137+
138+
[Fact]
139+
public void SpanNonBlittable_Valid_WorksNormally()
140+
{
141+
var testInterface = CreateTestInterface();
142+
143+
var span = new Span<IntStructWrapper>(new IntStructWrapper[3]);
144+
testInterface.SpanNonBlittableNullCase(3, ref span);
145+
146+
Assert.Equal(0, span[0].Value);
147+
Assert.Equal(7, span[1].Value);
148+
Assert.Equal(14, span[2].Value);
149+
}
150+
151+
[Theory]
152+
[InlineData(0)]
153+
[InlineData(10)]
154+
[InlineData(int.MaxValue)]
155+
public void InputOnlyArray_Null_DoesNotCrash(int length)
156+
{
157+
var testInterface = CreateTestInterface();
158+
159+
testInterface.InOnlyNullArray(length, null);
160+
}
161+
162+
[Theory]
163+
[InlineData(0)]
164+
[InlineData(10)]
165+
[InlineData(int.MaxValue)]
166+
public void OutputOnlyArray_Null_DoesNotCrash(int length)
167+
{
168+
var testInterface = CreateTestInterface();
169+
170+
testInterface.OutOnlyNullArray(length, null);
171+
}
172+
173+
[Fact]
174+
public void OutputOnlyArray_Valid_WorksNormally()
175+
{
176+
var testInterface = CreateTestInterface();
177+
var array = new int[3];
178+
179+
testInterface.OutOnlyNullArray(3, array);
180+
181+
Assert.Equal(new int[] { 1000, 1001, 1002 }, array);
182+
}
183+
184+
[Theory]
185+
[InlineData(0)]
186+
[InlineData(10)]
187+
[InlineData(int.MaxValue)]
188+
public void ReferenceArray_Null_DoesNotCrash(int length)
189+
{
190+
var testInterface = CreateTestInterface();
191+
192+
testInterface.ReferenceArrayNullCase(length, null);
193+
}
194+
195+
[Fact]
196+
public void ReferenceArray_ValidArray_WorksNormally()
197+
{
198+
var testInterface = CreateTestInterface();
199+
string[] array = new string[3];
200+
201+
testInterface.ReferenceArrayNullCase(3, array);
202+
203+
Assert.Equal(new string[] { "Item 0", "Item 1", "Item 2" }, array);
204+
}
205+
}
206+
}

0 commit comments

Comments
 (0)