Skip to content

Commit 009e9cf

Browse files
committed
Use single cached reference for C# enum values in Python
Make C# enums work as singletons in Python so that the `is` identity comparison operator works for C# enums as well.
1 parent 84a1be3 commit 009e9cf

File tree

2 files changed

+94
-1
lines changed

2 files changed

+94
-1
lines changed

src/embed_tests/EnumTests.cs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,82 @@ def operation2():
418418

419419
Assert.AreEqual(expectedResult, module.InvokeMethod("operation1").As<bool>());
420420
Assert.AreEqual(invertedOperationExpectedResult, module.InvokeMethod("operation2").As<bool>());
421+
}
422+
423+
private static IEnumerable<TestCaseData> IdentityComparisonTestCases
424+
{
425+
get
426+
{
427+
var enumValues = Enum.GetValues<VerticalDirection>();
428+
foreach (var enumValue1 in enumValues)
429+
{
430+
foreach (var enumValue2 in enumValues)
431+
{
432+
if (enumValue2 != enumValue1)
433+
{
434+
yield return new TestCaseData(enumValue1, enumValue2);
435+
}
436+
}
437+
}
438+
}
439+
}
440+
441+
[TestCaseSource(nameof(IdentityComparisonTestCases))]
442+
public void CSharpEnumsAreSingletonsInPthonAndIdentityComparisonWorks(VerticalDirection enumValue1, VerticalDirection enumValue2)
443+
{
444+
var enumValue1Str = $"{nameof(EnumTests)}.{nameof(VerticalDirection)}.{enumValue1}";
445+
var enumValue2Str = $"{nameof(EnumTests)}.{nameof(VerticalDirection)}.{enumValue2}";
446+
using var _ = Py.GIL();
447+
var module = PyModule.FromString("TESTTTT", $@"
448+
from clr import AddReference
449+
AddReference(""Python.EmbeddingTest"")
450+
451+
from Python.EmbeddingTest import *
452+
453+
def are_same1():
454+
return {enumValue1Str} is {enumValue1Str}
455+
456+
def are_same2():
457+
enum_value = {enumValue1Str}
458+
return enum_value is {enumValue1Str}
459+
460+
def are_same3():
461+
enum_value = {enumValue1Str}
462+
return {enumValue1Str} is enum_value
463+
464+
def are_same4():
465+
enum_value1 = {enumValue1Str}
466+
enum_value2 = {enumValue1Str}
467+
return enum_value1 is enum_value2
468+
469+
def are_not_same1():
470+
return {enumValue1Str} is not {enumValue2Str}
471+
472+
def are_not_same2():
473+
enum_value = {enumValue1Str}
474+
return enum_value is not {enumValue2Str}
475+
476+
def are_not_same3():
477+
enum_value = {enumValue2Str}
478+
return {enumValue1Str} is not enum_value
479+
480+
def are_not_same4():
481+
enum_value1 = {enumValue1Str}
482+
enum_value2 = {enumValue2Str}
483+
return enum_value1 is not enum_value2
484+
485+
486+
");
487+
488+
Assert.IsTrue(module.InvokeMethod("are_same1").As<bool>());
489+
Assert.IsTrue(module.InvokeMethod("are_same2").As<bool>());
490+
Assert.IsTrue(module.InvokeMethod("are_same3").As<bool>());
491+
Assert.IsTrue(module.InvokeMethod("are_same4").As<bool>());
421492

493+
Assert.IsTrue(module.InvokeMethod("are_not_same1").As<bool>());
494+
Assert.IsTrue(module.InvokeMethod("are_not_same2").As<bool>());
495+
Assert.IsTrue(module.InvokeMethod("are_not_same3").As<bool>());
496+
Assert.IsTrue(module.InvokeMethod("are_not_same4").As<bool>());
422497
}
423498
}
424499
}

src/runtime/Converter.cs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ namespace Python.Runtime
1818
[SuppressUnmanagedCodeSecurity]
1919
internal class Converter
2020
{
21+
/// <summary>
22+
/// We use a cache of the enum values references so that we treat them as singletons in Python.
23+
/// We just try to mimic Python enums behavior, since Python enum values are singletons,
24+
/// so the `is` identity comparison operator works for C# enums as well.
25+
/// </summary>
26+
27+
private static readonly Dictionary<Type, Dictionary<object, PyObject>> _enumCache = new();
2128
private Converter()
2229
{
2330
}
@@ -228,7 +235,18 @@ internal static NewReference ToPython(object? value, Type type)
228235

229236
if (type.IsEnum)
230237
{
231-
return CLRObject.GetReference(value, type);
238+
if (!_enumCache.TryGetValue(type, out var cache))
239+
{
240+
cache = new();
241+
_enumCache[type] = cache;
242+
}
243+
244+
if (!cache.TryGetValue(value, out var cachedValue))
245+
{
246+
cache[value] = cachedValue = CLRObject.GetReference(value, type).MoveToPyObject();
247+
}
248+
249+
return cachedValue.NewReferenceOrNull();
232250
}
233251

234252
// it the type is a python subclass of a managed type then return the

0 commit comments

Comments
 (0)