Skip to content

Commit a412b47

Browse files
committed
CollectionsMarshal
1 parent 9940453 commit a412b47

File tree

3 files changed

+206
-0
lines changed

3 files changed

+206
-0
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
using System.Runtime.CompilerServices;
2+
using System.Runtime.InteropServices;
3+
4+
[assembly: TypeForwardedTo(typeof(CollectionsMarshal))]
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using System.Collections.Generic;
2+
using System.Runtime.CompilerServices;
3+
4+
namespace System.Runtime.InteropServices
5+
{
6+
public static class CollectionsMarshal
7+
{
8+
public static Span<T> AsSpan<T>(List<T>? list)
9+
{
10+
if (list is null)
11+
{
12+
return Span<T>.Empty;
13+
}
14+
15+
return Unsafe.As<T[]>(CollectionsMarshalEx.ListFieldHolder<T>.ItemsField.GetValue(list));
16+
}
17+
}
18+
}
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
#if NET5_0_OR_GREATER
2+
#define HAS_ASSPAN
3+
#endif
4+
#if NET6_0_OR_GREATER
5+
#define HAS_GETVALUEREF
6+
#endif
7+
#if NET8_0_OR_GREATER
8+
#define HAS_SETCOUNT
9+
#endif
10+
11+
using System.Collections.Generic;
12+
using System.Reflection;
13+
using System.Reflection.Emit;
14+
using System.Runtime.CompilerServices;
15+
16+
namespace System.Runtime.InteropServices
17+
{
18+
public static unsafe class CollectionsMarshalEx
19+
{
20+
#if !HAS_SETCOUNT
21+
internal static class ListFieldHolder<T>
22+
{
23+
#if !HAS_ASSPAN
24+
public static FieldInfo ItemsField;
25+
#endif
26+
public static FieldInfo CountField;
27+
public static FieldInfo? VersionField;
28+
29+
static ListFieldHolder()
30+
{
31+
var t = typeof(List<T>);
32+
33+
#if !HAS_ASSPAN
34+
ItemsField = t.GetField("_items", BindingFlags.Instance | BindingFlags.NonPublic)
35+
?? throw new NotSupportedException("Could not get List items field");
36+
#endif
37+
CountField = t.GetField("_count", BindingFlags.Instance | BindingFlags.NonPublic)
38+
?? throw new NotSupportedException("Could not get List count field");
39+
VersionField = t.GetField("_version", BindingFlags.Instance | BindingFlags.NonPublic);
40+
}
41+
}
42+
#endif
43+
44+
#if !HAS_GETVALUEREF
45+
private static class DictRelfectionHolder<TKey, TValue> where TKey : notnull
46+
{
47+
public delegate ref TValue? EntryValueFieldRefGetter(Dictionary<TKey, TValue> dict, TKey key);
48+
public static EntryValueFieldRefGetter GetEntryValueFieldRef;
49+
50+
static DictRelfectionHolder()
51+
{
52+
var dictType = typeof(Dictionary<TKey, TValue>);
53+
54+
var findValueMethod = dictType.GetMethod("FindValue", BindingFlags.Instance | BindingFlags.NonPublic,
55+
null, [typeof(TKey)], null);
56+
57+
if (findValueMethod is not null)
58+
{
59+
GetEntryValueFieldRef = (EntryValueFieldRefGetter)Delegate.CreateDelegate(typeof(EntryValueFieldRefGetter), findValueMethod);
60+
return;
61+
}
62+
63+
var entriesField = dictType.GetField("_entries", BindingFlags.Instance | BindingFlags.NonPublic)
64+
?? throw new NotSupportedException("Could not get dictionary entries array field");
65+
66+
var entryType = entriesField.FieldType.GetElementType()!;
67+
var entryValueField = entryType.GetField("value",
68+
BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public)!;
69+
70+
var findEntryMethod = dictType.GetMethod("FindEntry", BindingFlags.Instance | BindingFlags.NonPublic)
71+
?? throw new NotSupportedException("Could not get dictionary find entry method");
72+
73+
var dm = new DynamicMethod("GetEntryValueFieldRef", typeof(TValue).MakeByRefType(), [dictType, typeof(TKey)], typeof(CollectionExtensionsEx), true);
74+
var il = dm.GetILGenerator();
75+
76+
var entryIndex = il.DeclareLocal(typeof(int));
77+
var successLabel = il.DefineLabel();
78+
79+
il.Emit(OpCodes.Ldarg_0);
80+
il.Emit(OpCodes.Ldarg_1);
81+
il.Emit(OpCodes.Callvirt, findEntryMethod);
82+
il.Emit(OpCodes.Stloc, entryIndex);
83+
il.Emit(OpCodes.Ldloc, entryIndex);
84+
il.Emit(OpCodes.Ldc_I4_0);
85+
il.Emit(OpCodes.Bge, successLabel);
86+
il.Emit(OpCodes.Ldc_I4_0);
87+
il.Emit(OpCodes.Conv_U);
88+
// im not taking any risks here
89+
il.Emit(OpCodes.Call, typeof(Unsafe).GetMethod("AsRef", [typeof(void*)])!.MakeGenericMethod(typeof(TValue)));
90+
il.Emit(OpCodes.Ret);
91+
il.MarkLabel(successLabel);
92+
il.Emit(OpCodes.Ldarg_0);
93+
il.Emit(OpCodes.Ldfld, entriesField);
94+
il.Emit(OpCodes.Ldloc, entryIndex);
95+
il.Emit(OpCodes.Ldelema, entriesField.FieldType.GetElementType()!);
96+
il.Emit(OpCodes.Ldflda, entryValueField);
97+
il.Emit(OpCodes.Ret);
98+
99+
GetEntryValueFieldRef = (EntryValueFieldRefGetter)dm.CreateDelegate(typeof(EntryValueFieldRefGetter));
100+
}
101+
}
102+
#endif
103+
104+
extension(CollectionsMarshal)
105+
{
106+
public static void SetCount<T>(List<T> list, int count)
107+
{
108+
#if HAS_SETCOUNT
109+
CollectionsMarshal.SetCount(list, count);
110+
#else
111+
ArgumentNullException.ThrowIfNull(list);
112+
if (count < 0)
113+
{
114+
throw new ArgumentOutOfRangeException(nameof(count));
115+
}
116+
117+
// setting the version field only really needs to be best effort
118+
if (ListFieldHolder<T>.VersionField is { } versionField)
119+
{
120+
versionField.SetValue(list, (int)versionField.GetValue(list) + 1);
121+
}
122+
123+
if (count > list.Capacity)
124+
{
125+
// taken from List<T>.EnsureCapacity
126+
var newCapacity = list.Capacity == 0 ? 4 : 2 * list.Capacity;
127+
128+
// Allow the list to grow to maximum possible capacity (~2G elements) before encountering overflow.
129+
// Note that this check works even when _items.Length overflowed thanks to the (uint) cast
130+
if ((uint)newCapacity > Array.MaxLength)
131+
{
132+
newCapacity = Array.MaxLength;
133+
}
134+
135+
// If the computed capacity is still less than specified, set to the original argument.
136+
// Capacities exceeding Array.MaxLength will be surfaced as OutOfMemoryException by Array.Resize.
137+
if (newCapacity < count)
138+
{
139+
newCapacity = count;
140+
}
141+
142+
list.Capacity = newCapacity;
143+
}
144+
145+
// TODO: IsReferenceOrContainsReferences
146+
if (count < list.Count)
147+
{
148+
CollectionsMarshal.AsSpan(list).Slice(count + 1).Clear();
149+
}
150+
151+
ListFieldHolder<T>.CountField.SetValue(list, count);
152+
#endif
153+
}
154+
155+
public static ref TValue GetValueRefOrNullRef<TKey, TValue>(Dictionary<TKey, TValue> dict, TKey key)
156+
where TKey : notnull
157+
{
158+
#if HAS_GETVALUEREF
159+
return ref CollectionsMarshal.GetValueRefOrNullRef(dict, key);
160+
#else
161+
// they don't validate for null so neither will we
162+
return ref DictRelfectionHolder<TKey, TValue>.GetEntryValueFieldRef(dict, key)!;
163+
#endif
164+
}
165+
166+
public static ref TValue? GetValueRefOrAddDefault<TKey, TValue>(Dictionary<TKey, TValue> dict, TKey key)
167+
where TKey : notnull
168+
{
169+
#if HAS_GETVALUEREF
170+
return ref CollectionsMarshal.GetValueRefOrAddDefault(dict, key);
171+
#else
172+
// they don't validate for null so neither will we
173+
if (dict.ContainsKey(key))
174+
{
175+
return ref DictRelfectionHolder<TKey, TValue>.GetEntryValueFieldRef(dict, key);
176+
}
177+
178+
dict.Add(key, default!);
179+
return ref DictRelfectionHolder<TKey, TValue>.GetEntryValueFieldRef(dict, key);
180+
#endif
181+
}
182+
}
183+
}
184+
}

0 commit comments

Comments
 (0)