Skip to content

Commit 62b2f65

Browse files
authored
Enable nullable annotations in StructType and UnionType from ctypes (#1921)
1 parent 99a014f commit 62b2f65

File tree

4 files changed

+104
-95
lines changed

4 files changed

+104
-95
lines changed

src/core/IronPython.Modules/_ctypes/INativeType.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ int Alignment {
5050
/// Serializes the provided value into the specified address at the given
5151
/// offset.
5252
/// </summary>
53-
object SetValue(MemoryHolder/*!*/ address, int offset, object value);
53+
object? SetValue(MemoryHolder/*!*/ address, int offset, object value);
5454

5555
/// <summary>
5656
/// Gets the .NET type which is used when calling or returning the value
@@ -68,12 +68,12 @@ int Alignment {
6868
/// Emits marshalling of an object from Python to native code. This produces the
6969
/// native type from the Python type.
7070
/// </summary>
71-
MarshalCleanup EmitMarshalling(ILGenerator/*!*/ method, LocalOrArg/*!*/ argIndex, List<object>/*!*/ constantPool, int constantPoolArgument);
71+
MarshalCleanup? EmitMarshalling(ILGenerator/*!*/ method, LocalOrArg/*!*/ argIndex, List<object>/*!*/ constantPool, int constantPoolArgument);
7272

7373
/// <summary>
7474
/// Emits marshalling from native code to Python code This produces the python type
7575
/// from the native type. This is used for return values and parameters
76-
/// to Python callable objects that are passed back out to native code.
76+
/// to Python callable objects that are passed back out of native code.
7777
/// </summary>
7878
void EmitReverseMarshalling(ILGenerator/*!*/ method, LocalOrArg/*!*/ value, List<object>/*!*/ constantPool, int constantPoolArgument);
7979

src/core/IronPython.Modules/_ctypes/StructType.cs

Lines changed: 68 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
33
// See the LICENSE file in the project root for more information.
44

5+
#nullable enable
6+
57
#if FEATURE_CTYPES
68

79
using System;
8-
using System.Collections;
910
using System.Collections.Generic;
1011
using System.Diagnostics;
12+
using System.Diagnostics.CodeAnalysis;
1113
using System.Numerics;
1214
using System.Reflection.Emit;
13-
using System.Runtime.InteropServices;
1415
using System.Text;
1516

1617
using Microsoft.Scripting;
@@ -19,6 +20,7 @@
1920
using IronPython.Runtime.Operations;
2021
using IronPython.Runtime.Types;
2122

23+
2224
namespace IronPython.Modules {
2325
/// <summary>
2426
/// Provides support for interop with native code from Python code.
@@ -30,15 +32,16 @@ public static partial class CTypes {
3032
/// </summary>
3133
[PythonType, PythonHidden]
3234
public class StructType : PythonType, INativeType {
33-
internal Field[] _fields;
35+
[DisallowNull]
36+
internal Field[]? _fields; // not null after type construction completes
3437
private int? _size, _alignment, _pack;
35-
private static readonly Field[] _emptyFields = System.Array.Empty<Field>(); // fields were never initialized before a type was created
38+
private static readonly Field[] _emptyFields = []; // fields were never initialized before a type was created
3639

37-
public StructType(CodeContext/*!*/ context, string name, PythonTuple bases, PythonDictionary members)
40+
public StructType(CodeContext/*!*/ context, [NotNone] string name, [NotNone] PythonTuple bases, [NotNone] PythonDictionary members)
3841
: base(context, name, bases, members) {
3942

4043
foreach (PythonType pt in ResolutionOrder) {
41-
StructType st = pt as StructType;
44+
StructType? st = pt as StructType;
4245
if (st != this) {
4346
st?.EnsureFinal();
4447
}
@@ -71,11 +74,11 @@ private StructType(Type underlyingSystemType)
7174
: base(underlyingSystemType) {
7275
}
7376

74-
public static ArrayType/*!*/ operator *(StructType type, int count) {
77+
public static ArrayType/*!*/ operator *([NotNone] StructType type, int count) {
7578
return MakeArrayType(type, count);
7679
}
7780

78-
public static ArrayType/*!*/ operator *(int count, StructType type) {
81+
public static ArrayType/*!*/ operator *(int count, [NotNone] StructType type) {
7982
return MakeArrayType(type, count);
8083
}
8184

@@ -93,13 +96,13 @@ public _Structure from_address(CodeContext/*!*/ context, IntPtr ptr) {
9396
return res;
9497
}
9598

96-
public _Structure from_buffer(CodeContext/*!*/ context, object/*?*/ data, int offset = 0) {
99+
public _Structure from_buffer(CodeContext/*!*/ context, object? data, int offset = 0) {
97100
_Structure res = (_Structure)CreateInstance(context);
98101
res.InitializeFromBuffer(data, offset, ((INativeType)this).Size);
99102
return res;
100103
}
101104

102-
public _Structure from_buffer_copy(CodeContext/*!*/ context, object/*?*/ data, int offset = 0) {
105+
public _Structure from_buffer_copy(CodeContext/*!*/ context, object? data, int offset = 0) {
103106
_Structure res = (_Structure)CreateInstance(context);
104107
res.InitializeFromBufferCopy(data, offset, ((INativeType)this).Size);
105108
return res;
@@ -110,19 +113,19 @@ public _Structure from_buffer_copy(CodeContext/*!*/ context, object/*?*/ data, i
110113
///
111114
/// Structures just return themselves.
112115
/// </summary>
113-
public object from_param(object obj) {
116+
public object from_param(object? obj) {
114117
if (!Builtin.isinstance(obj, this)) {
115-
throw PythonOps.TypeError("expected {0} instance got {1}", Name, PythonOps.GetPythonTypeName(obj));
118+
throw PythonOps.TypeError("expected {0} instance, got {1}", Name, PythonOps.GetPythonTypeName(obj));
116119
}
117120

118-
return obj;
121+
return obj!;
119122
}
120123

121-
public object in_dll(object library, string name) {
124+
public object in_dll(object? library, [NotNone] string name) {
122125
throw new NotImplementedException("in dll");
123126
}
124127

125-
public new virtual void __setattr__(CodeContext/*!*/ context, string name, object value) {
128+
public new virtual void __setattr__(CodeContext/*!*/ context, [NotNone] string name, object? value) {
126129
if (name == "_fields_") {
127130
lock (this) {
128131
if (_fields != null) {
@@ -160,7 +163,7 @@ object INativeType.GetValue(MemoryHolder/*!*/ owner, object readingFrom, int off
160163
return res;
161164
}
162165

163-
object INativeType.SetValue(MemoryHolder/*!*/ address, int offset, object value) {
166+
object? INativeType.SetValue(MemoryHolder/*!*/ address, int offset, object value) {
164167
try {
165168
return SetValueInternal(address, offset, value);
166169
} catch (ArgumentTypeException e) {
@@ -174,24 +177,21 @@ object INativeType.SetValue(MemoryHolder/*!*/ address, int offset, object value)
174177
}
175178
}
176179

177-
internal object SetValueInternal(MemoryHolder address, int offset, object value) {
178-
IList<object> init = value as IList<object>;
179-
if (init != null) {
180+
internal object? SetValueInternal(MemoryHolder address, int offset, object value) {
181+
if (value is IList<object> init) {
182+
EnsureFinal();
180183
if (init.Count > _fields.Length) {
181184
throw PythonOps.TypeError("too many initializers");
182185
}
183186

184187
for (int i = 0; i < init.Count; i++) {
185188
_fields[i].SetValue(address, offset, init[i]);
186189
}
190+
} else if (value is CData data) {
191+
data.MemHolder.CopyTo(address, offset, data.Size);
192+
return data.MemHolder.EnsureObjects();
187193
} else {
188-
CData data = value as CData;
189-
if (data != null) {
190-
data.MemHolder.CopyTo(address, offset, data.Size);
191-
return data.MemHolder.EnsureObjects();
192-
} else {
193-
throw new NotImplementedException("set value");
194-
}
194+
throw new NotImplementedException("set value");
195195
}
196196
return null;
197197
}
@@ -202,7 +202,7 @@ internal object SetValueInternal(MemoryHolder address, int offset, object value)
202202
return GetMarshalTypeFromSize(_size.Value);
203203
}
204204

205-
MarshalCleanup INativeType.EmitMarshalling(ILGenerator/*!*/ method, LocalOrArg argIndex, List<object>/*!*/ constantPool, int constantPoolArgument) {
205+
MarshalCleanup? INativeType.EmitMarshalling(ILGenerator/*!*/ method, LocalOrArg argIndex, List<object>/*!*/ constantPool, int constantPoolArgument) {
206206
Type argumentType = argIndex.Type;
207207
argIndex.Emit(method);
208208
if (argumentType.IsValueType) {
@@ -212,8 +212,8 @@ MarshalCleanup INativeType.EmitMarshalling(ILGenerator/*!*/ method, LocalOrArg a
212212
method.Emit(OpCodes.Ldarg, constantPoolArgument);
213213
method.Emit(OpCodes.Ldc_I4, constantPool.Count - 1);
214214
method.Emit(OpCodes.Ldelem_Ref);
215-
method.Emit(OpCodes.Call, typeof(ModuleOps).GetMethod(nameof(ModuleOps.CheckCDataType)));
216-
method.Emit(OpCodes.Call, typeof(CData).GetProperty(nameof(CData.UnsafeAddress)).GetGetMethod());
215+
method.Emit(OpCodes.Call, typeof(ModuleOps).GetMethod(nameof(ModuleOps.CheckCDataType))!);
216+
method.Emit(OpCodes.Call, typeof(CData).GetProperty(nameof(CData.UnsafeAddress))!.GetGetMethod()!);
217217
method.Emit(OpCodes.Ldobj, ((INativeType)this).GetNativeType());
218218
return null;
219219
}
@@ -251,24 +251,24 @@ internal static PythonType MakeSystemType(Type underlyingSystemType) {
251251
return PythonType.SetPythonType(underlyingSystemType, new StructType(underlyingSystemType));
252252
}
253253

254-
private void SetFields(object fields) {
254+
[MemberNotNull(nameof(_fields), nameof(_size), nameof(_alignment))]
255+
private void SetFields(object? fields) {
255256
lock (this) {
256-
IList<object> list = GetFieldsList(fields);
257+
IList<object> fieldDefList = GetFieldsList(fields);
257258

258259
int? bitCount = null;
259260
int? curBitCount = null;
260-
INativeType lastType = null;
261+
INativeType? lastType = null;
261262
List<Field> allFields = GetBaseSizeAlignmentAndFields(out int size, out int alignment);
262263

263-
IList<object> anonFields = GetAnonymousFields(this);
264+
IList<object>? anonFields = GetAnonymousFields(this);
264265

265-
for (int fieldIndex = 0; fieldIndex < list.Count; fieldIndex++) {
266-
object o = list[fieldIndex];
267-
GetFieldInfo(this, o, out string fieldName, out INativeType cdata, out bitCount);
266+
foreach (object fieldDef in fieldDefList) {
267+
GetFieldInfo(this, fieldDef, out string fieldName, out INativeType cdata, out bitCount);
268268

269269
int prevSize = UpdateSizeAndAlignment(cdata, bitCount, lastType, ref size, ref alignment, ref curBitCount);
270270

271-
Field newField = new Field(fieldName, cdata, prevSize, allFields.Count, bitCount, curBitCount - bitCount);
271+
var newField = new Field(fieldName, cdata, prevSize, allFields.Count, bitCount, curBitCount - bitCount);
272272
allFields.Add(newField);
273273
AddSlot(fieldName, newField);
274274

@@ -282,16 +282,18 @@ private void SetFields(object fields) {
282282
CheckAnonymousFields(allFields, anonFields);
283283

284284
if (bitCount != null) {
285-
size += lastType.Size;
285+
// incomplete last bitfield
286+
// bitCount not null implies at least one bitfield, so at least one iteration of the loop above
287+
size += lastType!.Size;
286288
}
287289

288-
_fields = allFields.ToArray();
290+
_fields = [..allFields];
289291
_size = PythonStruct.Align(size, alignment);
290292
_alignment = alignment;
291293
}
292294
}
293295

294-
internal static void CheckAnonymousFields(List<Field> allFields, IList<object> anonFields) {
296+
internal static void CheckAnonymousFields(List<Field> allFields, IList<object>? anonFields) {
295297
if (anonFields != null) {
296298
foreach (string s in anonFields) {
297299
bool found = false;
@@ -309,9 +311,9 @@ internal static void CheckAnonymousFields(List<Field> allFields, IList<object> a
309311
}
310312
}
311313

312-
internal static IList<object> GetAnonymousFields(PythonType type) {
314+
internal static IList<object>? GetAnonymousFields(PythonType type) {
313315
object anonymous;
314-
IList<object> anonFields = null;
316+
IList<object>? anonFields = null;
315317
if (type.TryGetBoundAttr(type.Context.SharedContext, type, "_anonymous_", out anonymous)) {
316318
anonFields = anonymous as IList<object>;
317319
if (anonFields == null) {
@@ -323,16 +325,18 @@ internal static IList<object> GetAnonymousFields(PythonType type) {
323325

324326
internal static void AddAnonymousFields(PythonType type, List<Field> allFields, INativeType cdata, Field newField) {
325327
Field[] childFields;
326-
if (cdata is StructType) {
327-
childFields = ((StructType)cdata)._fields;
328-
} else if (cdata is UnionType) {
329-
childFields = ((UnionType)cdata)._fields;
328+
if (cdata is StructType st) {
329+
st.EnsureFinal();
330+
childFields = st._fields;
331+
} else if (cdata is UnionType un) {
332+
un.EnsureFinal();
333+
childFields = un._fields;
330334
} else {
331335
throw PythonOps.TypeError("anonymous field must be struct or union");
332336
}
333337

334338
foreach (Field existingField in childFields) {
335-
Field anonField = new Field(
339+
var anonField = new Field(
336340
existingField.FieldName,
337341
existingField.NativeType,
338342
checked(existingField.offset + newField.offset),
@@ -347,12 +351,12 @@ internal static void AddAnonymousFields(PythonType type, List<Field> allFields,
347351
private List<Field> GetBaseSizeAlignmentAndFields(out int size, out int alignment) {
348352
size = 0;
349353
alignment = 1;
350-
List<Field> allFields = new List<Field>();
351-
INativeType lastType = null;
354+
List<Field> allFields = [];
355+
INativeType? lastType = null;
352356
int? totalBitCount = null;
353357
foreach (PythonType pt in BaseTypes) {
354-
StructType st = pt as StructType;
355-
if (st != null) {
358+
if (pt is StructType st) {
359+
st.EnsureFinal();
356360
foreach (Field f in st._fields) {
357361
allFields.Add(f);
358362
UpdateSizeAndAlignment(f.NativeType, f.BitCount, lastType, ref size, ref alignment, ref totalBitCount);
@@ -368,7 +372,8 @@ private List<Field> GetBaseSizeAlignmentAndFields(out int size, out int alignmen
368372
return allFields;
369373
}
370374

371-
private int UpdateSizeAndAlignment(INativeType cdata, int? bitCount, INativeType lastType, ref int size, ref int alignment, ref int? totalBitCount) {
375+
private int UpdateSizeAndAlignment(INativeType cdata, int? bitCount, INativeType? lastType, ref int size, ref int alignment, ref int? totalBitCount) {
376+
Debug.Assert(totalBitCount == null || lastType != null); // lastType is null only on the first iteration, when totalBitCount is null as well
372377
int prevSize = size;
373378
if (bitCount != null) {
374379
if (lastType != null && lastType.Size != cdata.Size) {
@@ -382,7 +387,7 @@ private int UpdateSizeAndAlignment(INativeType cdata, int? bitCount, INativeType
382387
if ((bitCount + totalBitCount + 7) / 8 <= cdata.Size) {
383388
totalBitCount = bitCount + totalBitCount;
384389
} else {
385-
size += lastType.Size;
390+
size += lastType!.Size;
386391
prevSize = size;
387392
totalBitCount = bitCount;
388393
}
@@ -391,7 +396,7 @@ private int UpdateSizeAndAlignment(INativeType cdata, int? bitCount, INativeType
391396
}
392397
} else {
393398
if (totalBitCount != null) {
394-
size += lastType.Size;
399+
size += lastType!.Size;
395400
prevSize = size;
396401
totalBitCount = null;
397402
}
@@ -411,6 +416,7 @@ private int UpdateSizeAndAlignment(INativeType cdata, int? bitCount, INativeType
411416
return prevSize;
412417
}
413418

419+
[MemberNotNull(nameof(_fields), nameof(_size), nameof(_alignment))]
414420
internal void EnsureFinal() {
415421
if (_fields == null) {
416422
SetFields(PythonTuple.EMPTY);
@@ -419,6 +425,8 @@ internal void EnsureFinal() {
419425
// track that we were initialized w/o fields.
420426
_fields = _emptyFields;
421427
}
428+
} else if (_size == null || _alignment == null) {
429+
throw new InvalidOperationException("fields initialized w/o size or alignment");
422430
}
423431
}
424432

@@ -427,19 +435,22 @@ internal void EnsureFinal() {
427435
/// from all of our base classes. If later new _fields_ are added we'll be
428436
/// initialized and these values will be replaced.
429437
/// </summary>
438+
[MemberNotNull(nameof(_size), nameof(_alignment))]
430439
private void EnsureSizeAndAlignment() {
431440
Debug.Assert(_size.HasValue == _alignment.HasValue);
432-
// these are always iniitalized together
441+
// these are always initialized together
433442
if (_size == null) {
434443
lock (this) {
435444
if (_size == null) {
436-
int size, alignment;
437-
GetBaseSizeAlignmentAndFields(out size, out alignment);
445+
GetBaseSizeAlignmentAndFields(out int size, out int alignment);
438446
_size = size;
439447
_alignment = alignment;
440448
}
441449
}
442450
}
451+
if (_alignment == null) {
452+
throw new InvalidOperationException("size and alignment should always be initialized together");
453+
}
443454
}
444455
}
445456
}

0 commit comments

Comments
 (0)