Skip to content

Commit cb232a0

Browse files
committed
Minor fix
1 parent 009e9cf commit cb232a0

File tree

2 files changed

+63
-20
lines changed

2 files changed

+63
-20
lines changed

src/embed_tests/EnumTests.cs

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ namespace Python.EmbeddingTest
99
{
1010
public class EnumTests
1111
{
12+
private static VerticalDirection[] VerticalDirectionEnumValues = Enum.GetValues<VerticalDirection>();
13+
private static HorizontalDirection[] HorizontalDirectionEnumValues = Enum.GetValues<HorizontalDirection>();
14+
1215
[OneTimeSetUp]
1316
public void SetUp()
1417
{
@@ -301,11 +304,10 @@ public static IEnumerable<TestCaseData> SameEnumTypeComparisonOperatorsTestCases
301304
get
302305
{
303306
var operators = new[] { "==", "!=", "<", "<=", ">", ">=" };
304-
var enumValues = Enum.GetValues<VerticalDirection>();
305307

306-
foreach (var enumValue in enumValues)
308+
foreach (var enumValue in VerticalDirectionEnumValues)
307309
{
308-
foreach (var enumValue2 in enumValues)
310+
foreach (var enumValue2 in VerticalDirectionEnumValues)
309311
{
310312
yield return new TestCaseData("==", enumValue, enumValue2, enumValue == enumValue2);
311313
yield return new TestCaseData("!=", enumValue, enumValue2, enumValue != enumValue2);
@@ -378,12 +380,10 @@ public static IEnumerable<TestCaseData> OtherEnumsComparisonOperatorsTestCases
378380
get
379381
{
380382
var operators = new[] { "==", "!=", "<", "<=", ">", ">=" };
381-
var enumValues = Enum.GetValues<VerticalDirection>();
382-
var enum2Values = Enum.GetValues<HorizontalDirection>();
383383

384-
foreach (var enumValue in enumValues)
384+
foreach (var enumValue in VerticalDirectionEnumValues)
385385
{
386-
foreach (var enum2Value in enum2Values)
386+
foreach (var enum2Value in HorizontalDirectionEnumValues)
387387
{
388388
var intEnumValue = Convert.ToInt64(enumValue);
389389
var intEnum2Value = Convert.ToInt64(enum2Value);
@@ -424,10 +424,9 @@ private static IEnumerable<TestCaseData> IdentityComparisonTestCases
424424
{
425425
get
426426
{
427-
var enumValues = Enum.GetValues<VerticalDirection>();
428-
foreach (var enumValue1 in enumValues)
427+
foreach (var enumValue1 in VerticalDirectionEnumValues)
429428
{
430-
foreach (var enumValue2 in enumValues)
429+
foreach (var enumValue2 in VerticalDirectionEnumValues)
431430
{
432431
if (enumValue2 != enumValue1)
433432
{
@@ -443,8 +442,9 @@ public void CSharpEnumsAreSingletonsInPthonAndIdentityComparisonWorks(VerticalDi
443442
{
444443
var enumValue1Str = $"{nameof(EnumTests)}.{nameof(VerticalDirection)}.{enumValue1}";
445444
var enumValue2Str = $"{nameof(EnumTests)}.{nameof(VerticalDirection)}.{enumValue2}";
445+
446446
using var _ = Py.GIL();
447-
var module = PyModule.FromString("TESTTTT", $@"
447+
var module = PyModule.FromString("CSharpEnumsAreSingletonsInPthonAndIdentityComparisonWorks", $@"
448448
from clr import AddReference
449449
AddReference(""Python.EmbeddingTest"")
450450
@@ -495,5 +495,54 @@ def are_not_same4():
495495
Assert.IsTrue(module.InvokeMethod("are_not_same3").As<bool>());
496496
Assert.IsTrue(module.InvokeMethod("are_not_same4").As<bool>());
497497
}
498+
499+
[Test]
500+
public void IdentityComparisonBetweenDifferentEnumTypesIsNeverTrue(
501+
[ValueSource(nameof(VerticalDirectionEnumValues))] VerticalDirection enumValue1,
502+
[ValueSource(nameof(HorizontalDirectionEnumValues))] HorizontalDirection enumValue2)
503+
{
504+
var enumValue1Str = $"{nameof(EnumTests)}.{nameof(VerticalDirection)}.{enumValue1}";
505+
var enumValue2Str = $"{nameof(EnumTests)}.{nameof(HorizontalDirection)}.{enumValue2}";
506+
507+
using var _ = Py.GIL();
508+
var module = PyModule.FromString("IdentityComparisonBetweenDifferentEnumTypesIsNeverTrue", $@"
509+
from clr import AddReference
510+
AddReference(""Python.EmbeddingTest"")
511+
512+
from Python.EmbeddingTest import *
513+
514+
enum_value1 = {enumValue1Str}
515+
enum_value2 = {enumValue2Str}
516+
517+
def are_same1():
518+
return {enumValue1Str} is {enumValue2Str}
519+
520+
def are_same2():
521+
return enum_value1 is {enumValue2Str}
522+
523+
def are_same3():
524+
return {enumValue2Str} is enum_value1
525+
526+
def are_same4():
527+
return enum_value2 is {enumValue1Str}
528+
529+
def are_same5():
530+
return {enumValue1Str} is enum_value2
531+
532+
def are_same6():
533+
return enum_value1 is enum_value2
534+
535+
def are_same7():
536+
return enum_value2 is enum_value1
537+
");
538+
539+
Assert.IsFalse(module.InvokeMethod("are_same1").As<bool>());
540+
Assert.IsFalse(module.InvokeMethod("are_same2").As<bool>());
541+
Assert.IsFalse(module.InvokeMethod("are_same3").As<bool>());
542+
Assert.IsFalse(module.InvokeMethod("are_same4").As<bool>());
543+
Assert.IsFalse(module.InvokeMethod("are_same5").As<bool>());
544+
Assert.IsFalse(module.InvokeMethod("are_same6").As<bool>());
545+
Assert.IsFalse(module.InvokeMethod("are_same7").As<bool>());
546+
}
498547
}
499548
}

src/runtime/Converter.cs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ internal class Converter
2424
/// so the `is` identity comparison operator works for C# enums as well.
2525
/// </summary>
2626

27-
private static readonly Dictionary<Type, Dictionary<object, PyObject>> _enumCache = new();
27+
private static readonly Dictionary<object, PyObject> _enumCache = new();
2828
private Converter()
2929
{
3030
}
@@ -235,15 +235,9 @@ internal static NewReference ToPython(object? value, Type type)
235235

236236
if (type.IsEnum)
237237
{
238-
if (!_enumCache.TryGetValue(type, out var cache))
238+
if (!_enumCache.TryGetValue(value, out var cachedValue))
239239
{
240-
cache = new();
241-
_enumCache[type] = cache;
242-
}
243-
244-
if (!cache.TryGetValue(value, out var cachedValue))
245-
{
246-
cache[value] = cachedValue = CLRObject.GetReference(value, type).MoveToPyObject();
240+
_enumCache[value] = cachedValue = CLRObject.GetReference(value, type).MoveToPyObject();
247241
}
248242

249243
return cachedValue.NewReferenceOrNull();

0 commit comments

Comments
 (0)