Skip to content

Commit e0064ea

Browse files
committed
Support pythonic manipulation of managed enums.
Add support for 'len' method, 'in' operator and iteration of enum types.
1 parent dff82a1 commit e0064ea

File tree

2 files changed

+112
-0
lines changed

2 files changed

+112
-0
lines changed

src/embed_tests/ClassManagerTests.cs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,59 @@ def call(instance):
10031003
}
10041004

10051005
#endregion
1006+
1007+
public enum TestEnum
1008+
{
1009+
FirstEnumValue,
1010+
SecondEnumValue,
1011+
ThirdEnumValue
1012+
}
1013+
1014+
[Test]
1015+
public void EnumPythonOperationsCanBePerformedOnManagedEnum()
1016+
{
1017+
using (Py.GIL())
1018+
{
1019+
var module = PyModule.FromString("EnumPythonOperationsCanBePerformedOnManagedEnum", $@"
1020+
from clr import AddReference
1021+
AddReference(""Python.EmbeddingTest"")
1022+
1023+
from Python.EmbeddingTest import *
1024+
1025+
def get_enum_values():
1026+
return [x for x in ClassManagerTests.TestEnum]
1027+
1028+
def count_enum_values():
1029+
return len(ClassManagerTests.TestEnum)
1030+
1031+
def is_enum_value_defined(value):
1032+
return value in ClassManagerTests.TestEnum
1033+
");
1034+
1035+
using var pyEnumValues = module.InvokeMethod("get_enum_values");
1036+
var enumValues = pyEnumValues.As<List<TestEnum>>();
1037+
1038+
var expectedEnumValues = Enum.GetValues<TestEnum>();
1039+
CollectionAssert.AreEquivalent(expectedEnumValues, enumValues);
1040+
1041+
using var pyEnumCount = module.InvokeMethod("count_enum_values");
1042+
var enumCount = pyEnumCount.As<int>();
1043+
Assert.AreEqual(expectedEnumValues.Length, enumCount);
1044+
1045+
var validEnumValues = expectedEnumValues
1046+
.SelectMany(x => new object[] { x, (int)x, Enum.GetName(x.GetType(), x) })
1047+
.Select(x => (x, true));
1048+
var invalidEnumValues = new object[] { 5, "INVALID_ENUM_VALUE" }.Select(x => (x, false));
1049+
1050+
foreach (var (enumValue, isValid) in validEnumValues.Concat(invalidEnumValues))
1051+
{
1052+
using var pyEnumValue = enumValue.ToPython();
1053+
using var pyIsDefined = module.InvokeMethod("is_enum_value_defined", pyEnumValue);
1054+
var isDefined = pyIsDefined.As<bool>();
1055+
Assert.AreEqual(isValid, isDefined, $"Failed for {enumValue} ({enumValue.GetType()})");
1056+
}
1057+
}
1058+
}
10061059
}
10071060

10081061
public class NestedTestParent

src/runtime/Types/MetaType.cs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,5 +359,64 @@ public static NewReference __subclasscheck__(BorrowedReference tp, BorrowedRefer
359359
{
360360
return DoInstanceCheck(tp, args, true);
361361
}
362+
363+
/// <summary>
364+
/// Standard iteration support Enums. This allows natural interation
365+
/// over the available values an Enum defines.
366+
/// </summary>
367+
public static NewReference tp_iter(BorrowedReference tp)
368+
{
369+
var type = CheckAndGetEnumType(tp);
370+
var values = Enum.GetValues(type);
371+
return new Iterator(values.GetEnumerator(), type).Alloc();
372+
}
373+
374+
/// <summary>
375+
/// Implements __len__ for Enum types.
376+
/// </summary>
377+
public static int mp_length(BorrowedReference tp)
378+
{
379+
var type = CheckAndGetEnumType(tp);
380+
return Enum.GetValues(type).Length;
381+
}
382+
383+
/// <summary>
384+
/// Implements __contains__ for Enum types.
385+
/// </summary>
386+
public static int sq_contains(BorrowedReference tp, BorrowedReference v)
387+
{
388+
var type = CheckAndGetEnumType(tp);
389+
390+
if (!Converter.ToManaged(v, type, out var enumValue, false) &&
391+
!Converter.ToManaged(v, typeof(int), out enumValue, false) &&
392+
!Converter.ToManaged(v, typeof(string), out enumValue, false))
393+
{
394+
Exceptions.SetError(Exceptions.TypeError,
395+
$"invalid parameter type for sq_contains: should be {Converter.GetTypeByAlias(v)}, found {type}");
396+
}
397+
return Enum.IsDefined(type, enumValue) ? 1 : 0;
398+
}
399+
400+
private static Type CheckAndGetEnumType(BorrowedReference tp)
401+
{
402+
var cb = GetManagedObject(tp) as ClassBase;
403+
if (cb == null)
404+
{
405+
Exceptions.SetError(Exceptions.TypeError, "invalid object");
406+
}
407+
408+
if (!cb.type.Valid)
409+
{
410+
Exceptions.SetError(Exceptions.TypeError, "invalid type");
411+
}
412+
413+
var type = cb.type.Value;
414+
if (!type.IsEnum)
415+
{
416+
Exceptions.SetError(Exceptions.TypeError, "uniterable object");
417+
}
418+
419+
return type;
420+
}
362421
}
363422
}

0 commit comments

Comments
 (0)