Skip to content

Commit 57c2299

Browse files
committed
More tests and cleanup
1 parent cb232a0 commit 57c2299

File tree

2 files changed

+132
-68
lines changed

2 files changed

+132
-68
lines changed

src/embed_tests/EnumTests.cs

Lines changed: 93 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public enum HorizontalDirection
4242
public void CSharpEnumsBehaveAsEnumsInPython()
4343
{
4444
using var _ = Py.GIL();
45-
var module = PyModule.FromString("CSharpEnumsBehaveAsEnumsInPython", $@"
45+
using var module = PyModule.FromString("CSharpEnumsBehaveAsEnumsInPython", $@"
4646
from clr import AddReference
4747
AddReference(""Python.EmbeddingTest"")
4848
@@ -103,7 +103,7 @@ def operation2():
103103
public void ArithmeticOperatorsWorkWithoutExplicitCast(string @operator, VerticalDirection operand1, double operand2, double expectedResult, double invertedOperationExpectedResult)
104104
{
105105
using var _ = Py.GIL();
106-
var module = GetTestOperatorsModule(@operator, operand1, operand2);
106+
using var module = GetTestOperatorsModule(@operator, operand1, operand2);
107107

108108
Assert.AreEqual(expectedResult, module.InvokeMethod("operation1").As<double>());
109109

@@ -178,7 +178,7 @@ public void ArithmeticOperatorsWorkWithoutExplicitCast(string @operator, Vertica
178178
public void IntComparisonOperatorsWorkWithoutExplicitCast(string @operator, VerticalDirection operand1, int operand2, bool expectedResult)
179179
{
180180
using var _ = Py.GIL();
181-
var module = GetTestOperatorsModule(@operator, operand1, operand2);
181+
using var module = GetTestOperatorsModule(@operator, operand1, operand2);
182182

183183
Assert.AreEqual(expectedResult, module.InvokeMethod("operation1").As<bool>());
184184

@@ -289,7 +289,7 @@ public void IntComparisonOperatorsWorkWithoutExplicitCast(string @operator, Vert
289289
public void FloatComparisonOperatorsWorkWithoutExplicitCast(string @operator, VerticalDirection operand1, double operand2, bool expectedResult)
290290
{
291291
using var _ = Py.GIL();
292-
var module = GetTestOperatorsModule(@operator, operand1, operand2);
292+
using var module = GetTestOperatorsModule(@operator, operand1, operand2);
293293

294294
Assert.AreEqual(expectedResult, module.InvokeMethod("operation1").As<bool>());
295295

@@ -324,7 +324,7 @@ public static IEnumerable<TestCaseData> SameEnumTypeComparisonOperatorsTestCases
324324
public void SameEnumTypeComparisonOperatorsWorkWithoutExplicitCast(string @operator, VerticalDirection operand1, VerticalDirection operand2, bool expectedResult)
325325
{
326326
using var _ = Py.GIL();
327-
var module = PyModule.FromString("SameEnumTypeComparisonOperatorsWorkWithoutExplicitCast", $@"
327+
using var module = PyModule.FromString("SameEnumTypeComparisonOperatorsWorkWithoutExplicitCast", $@"
328328
from clr import AddReference
329329
AddReference(""Python.EmbeddingTest"")
330330
@@ -358,7 +358,7 @@ def operation():
358358
public void EnumComparisonOperatorsWorkWithString(string @operator, VerticalDirection operand1, string operand2, bool expectedResult)
359359
{
360360
using var _ = Py.GIL();
361-
var module = PyModule.FromString("EnumComparisonOperatorsWorkWithString", $@"
361+
using var module = PyModule.FromString("EnumComparisonOperatorsWorkWithString", $@"
362362
from clr import AddReference
363363
AddReference(""Python.EmbeddingTest"")
364364
@@ -403,7 +403,7 @@ public static IEnumerable<TestCaseData> OtherEnumsComparisonOperatorsTestCases
403403
public void OtherEnumsComparisonOperatorsWorkWithoutExplicitCast(string @operator, VerticalDirection operand1, HorizontalDirection operand2, bool expectedResult, bool invertedOperationExpectedResult)
404404
{
405405
using var _ = Py.GIL();
406-
var module = PyModule.FromString("OtherEnumsComparisonOperatorsWorkWithoutExplicitCast", $@"
406+
using var module = PyModule.FromString("OtherEnumsComparisonOperatorsWorkWithoutExplicitCast", $@"
407407
from clr import AddReference
408408
AddReference(""Python.EmbeddingTest"")
409409
@@ -444,7 +444,7 @@ public void CSharpEnumsAreSingletonsInPthonAndIdentityComparisonWorks(VerticalDi
444444
var enumValue2Str = $"{nameof(EnumTests)}.{nameof(VerticalDirection)}.{enumValue2}";
445445

446446
using var _ = Py.GIL();
447-
var module = PyModule.FromString("CSharpEnumsAreSingletonsInPthonAndIdentityComparisonWorks", $@"
447+
using var module = PyModule.FromString("CSharpEnumsAreSingletonsInPthonAndIdentityComparisonWorks", $@"
448448
from clr import AddReference
449449
AddReference(""Python.EmbeddingTest"")
450450
@@ -505,7 +505,7 @@ public void IdentityComparisonBetweenDifferentEnumTypesIsNeverTrue(
505505
var enumValue2Str = $"{nameof(EnumTests)}.{nameof(HorizontalDirection)}.{enumValue2}";
506506

507507
using var _ = Py.GIL();
508-
var module = PyModule.FromString("IdentityComparisonBetweenDifferentEnumTypesIsNeverTrue", $@"
508+
using var module = PyModule.FromString("IdentityComparisonBetweenDifferentEnumTypesIsNeverTrue", $@"
509509
from clr import AddReference
510510
AddReference(""Python.EmbeddingTest"")
511511
@@ -544,5 +544,89 @@ def are_same7():
544544
Assert.IsFalse(module.InvokeMethod("are_same6").As<bool>());
545545
Assert.IsFalse(module.InvokeMethod("are_same7").As<bool>());
546546
}
547+
548+
private PyModule GetCSharpObjectsComparisonTestModule(string @operator)
549+
{
550+
return PyModule.FromString("GetCSharpObjectsComparisonTestModule", $@"
551+
from clr import AddReference
552+
AddReference(""Python.EmbeddingTest"")
553+
554+
from Python.EmbeddingTest import *
555+
556+
enum_value = {nameof(EnumTests)}.{nameof(VerticalDirection)}.{VerticalDirection.Up}
557+
558+
def compare_with_none1():
559+
return enum_value {@operator} None
560+
561+
def compare_with_none2():
562+
return None {@operator} enum_value
563+
564+
def compare_with_csharp_object1(csharp_object):
565+
return enum_value {@operator} csharp_object
566+
567+
def compare_with_csharp_object2(csharp_object):
568+
return csharp_object {@operator} enum_value
569+
");
570+
}
571+
572+
[TestCase("==", false)]
573+
[TestCase("!=", true)]
574+
public void EqualityComparisonWithNull(string @operator, bool expectedResult)
575+
{
576+
using var _ = Py.GIL();
577+
using var module = GetCSharpObjectsComparisonTestModule(@operator);
578+
579+
Assert.AreEqual(expectedResult, module.InvokeMethod("compare_with_none1").As<bool>());
580+
Assert.AreEqual(expectedResult, module.InvokeMethod("compare_with_none2").As<bool>());
581+
582+
using var pyNull = ((TestClass)null).ToPython();
583+
Assert.AreEqual(expectedResult, module.InvokeMethod("compare_with_csharp_object1", pyNull).As<bool>());
584+
Assert.AreEqual(expectedResult, module.InvokeMethod("compare_with_csharp_object2", pyNull).As<bool>());
585+
}
586+
587+
[Test]
588+
public void SortingComparisonWithNullThrows([Values("<", "<=", ">", ">=")] string @operator)
589+
{
590+
using var _ = Py.GIL();
591+
using var module = GetCSharpObjectsComparisonTestModule(@operator);
592+
593+
using var pyNull = ((TestClass)null).ToPython();
594+
595+
var exception = Assert.Throws<PythonException>(() => module.InvokeMethod("compare_with_csharp_object1", pyNull));
596+
Assert.IsTrue(exception.Message.Contains("Cannot compare"));
597+
Assert.IsTrue(exception.Message.Contains("with null"));
598+
}
599+
600+
private static IEnumerable<TestCaseData> ComparisonWithNonEnumObjectsTestCases
601+
{
602+
get
603+
{
604+
foreach (var op in new[] { "==", "!=" })
605+
{
606+
yield return new TestCaseData(op, new[] { "No method matched to compare" });
607+
}
608+
609+
foreach (var op in new[] { "<", "<=", ">", ">=" })
610+
{
611+
yield return new TestCaseData(op, new[] { "Cannot compare", "with null" });
612+
}
613+
}
614+
}
615+
616+
[Test]
617+
public void ComparisonOperatorsWithNonEnumObjectsThrows([Values("==", "!=", "<", "<=", ">", ">=")] string @operator)
618+
{
619+
using var _ = Py.GIL();
620+
using var module = GetCSharpObjectsComparisonTestModule(@operator);
621+
622+
using var pyCSharpObject = new TestClass().ToPython();
623+
624+
var exception = Assert.Throws<PythonException>(() => module.InvokeMethod("compare_with_csharp_object1", pyCSharpObject));
625+
Assert.IsTrue(exception.Message.Contains("No method matched"), $"Expected exception message to contain 'No method matched' but got: {exception.Message}");
626+
}
627+
628+
public class TestClass
629+
{
630+
}
547631
}
548632
}

src/runtime/Util/OpsHelper.cs

Lines changed: 39 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -492,26 +492,43 @@ public static bool op_Inequality(string a, T b)
492492

493493
#region Other Enum comparison operators
494494

495-
private static bool IsEnum(object b, out Type type)
496-
{
495+
private static bool IsEnum(object b, out Type type, out Type underlyingType, bool throwOnNull = false)
496+
{
497+
type = null;
498+
underlyingType = null;
499+
if (b == null)
500+
{
501+
if (throwOnNull)
502+
{
503+
using (Py.GIL())
504+
{
505+
Exceptions.RaiseTypeError($"Cannot compare {typeof(T).Name} with null");
506+
PythonException.ThrowLastAsClrException();
507+
}
508+
}
509+
return false;
510+
}
511+
497512
type = b.GetType();
498513
if (type.IsEnum)
499514
{
515+
underlyingType = type.GetEnumUnderlyingType();
500516
return true;
501517
}
502518
using var _ = Py.GIL();
503519
Exceptions.RaiseTypeError($"No method matched to compare {typeof(T).Name} and {type.Name}");
520+
PythonException.ThrowLastAsClrException();
504521

505522
return false;
506523
}
507524

508525
public static bool op_Equality(T a, object b)
509526
{
510-
if (!IsEnum(b, out var bType))
527+
if (!IsEnum(b, out var bType, out var underlyingType))
511528
{
512529
return false;
513530
}
514-
if (bType.GetEnumUnderlyingType() == typeof(UInt64))
531+
if (underlyingType == typeof(UInt64))
515532
{
516533
return op_Equality(a, Convert.ToUInt64(b));
517534
}
@@ -520,16 +537,7 @@ public static bool op_Equality(T a, object b)
520537

521538
public static bool op_Equality(object a, T b)
522539
{
523-
if (!IsEnum(a, out var aType))
524-
{
525-
return false;
526-
}
527-
528-
if (aType.GetEnumUnderlyingType() == typeof(UInt64))
529-
{
530-
return op_Equality(b, Convert.ToUInt64(a));
531-
}
532-
return op_Equality(b, Convert.ToInt64(a));
540+
return op_Equality(b, a);
533541
}
534542

535543
public static bool op_Inequality(T a, object b)
@@ -539,16 +547,17 @@ public static bool op_Inequality(T a, object b)
539547

540548
public static bool op_Inequality(object a, T b)
541549
{
542-
return !op_Equality(a, b);
550+
return !op_Equality(b, a);
543551
}
544552

545553
public static bool op_LessThan(T a, object b)
546554
{
547-
if (!IsEnum(b, out var bType))
555+
if (!IsEnum(b, out var bType, out var underlyingType, throwOnNull: true))
548556
{
557+
// False although it means nothing: an exception will be raised
549558
return false;
550559
}
551-
if (bType.GetEnumUnderlyingType() == typeof(UInt64))
560+
if (underlyingType == typeof(UInt64))
552561
{
553562
return op_LessThan(a, Convert.ToUInt64(b));
554563
}
@@ -557,24 +566,17 @@ public static bool op_LessThan(T a, object b)
557566

558567
public static bool op_LessThan(object a, T b)
559568
{
560-
if (!IsEnum(a, out var aType))
561-
{
562-
return false;
563-
}
564-
if (aType.GetEnumUnderlyingType() == typeof(UInt64))
565-
{
566-
return op_LessThan(Convert.ToUInt64(a), b);
567-
}
568-
return op_LessThan(Convert.ToInt64(a), b);
569+
return op_GreaterThan(b, a);
569570
}
570571

571572
public static bool op_GreaterThan(T a, object b)
572573
{
573-
if (!IsEnum(b, out var bType))
574+
if (!IsEnum(b, out var bType, out var underlyingType, throwOnNull: true))
574575
{
576+
// False although it means nothing: an exception will be raised
575577
return false;
576578
}
577-
if (bType.GetEnumUnderlyingType() == typeof(UInt64))
579+
if (underlyingType == typeof(UInt64))
578580
{
579581
return op_GreaterThan(a, Convert.ToUInt64(b));
580582
}
@@ -583,24 +585,17 @@ public static bool op_GreaterThan(T a, object b)
583585

584586
public static bool op_GreaterThan(object a, T b)
585587
{
586-
if (!IsEnum(a, out var aType))
587-
{
588-
return false;
589-
}
590-
if (aType.GetEnumUnderlyingType() == typeof(UInt64))
591-
{
592-
return op_GreaterThan(Convert.ToUInt64(a), b);
593-
}
594-
return op_GreaterThan(Convert.ToInt64(a), b);
588+
return op_LessThan(b, a);
595589
}
596590

597591
public static bool op_LessThanOrEqual(T a, object b)
598592
{
599-
if (!IsEnum(b, out var bType))
593+
if (!IsEnum(b, out var bType, out var underlyingType, throwOnNull: true))
600594
{
595+
// False although it means nothing: an exception will be raised
601596
return false;
602597
}
603-
if (bType.GetEnumUnderlyingType() == typeof(UInt64))
598+
if (underlyingType == typeof(UInt64))
604599
{
605600
return op_LessThanOrEqual(a, Convert.ToUInt64(b));
606601
}
@@ -609,24 +604,17 @@ public static bool op_LessThanOrEqual(T a, object b)
609604

610605
public static bool op_LessThanOrEqual(object a, T b)
611606
{
612-
if (!IsEnum(a, out var aType))
613-
{
614-
return false;
615-
}
616-
if (aType.GetEnumUnderlyingType() == typeof(UInt64))
617-
{
618-
return op_LessThanOrEqual(Convert.ToUInt64(a), b);
619-
}
620-
return op_LessThanOrEqual(Convert.ToInt64(a), b);
607+
return op_GreaterThanOrEqual(b, a);
621608
}
622609

623610
public static bool op_GreaterThanOrEqual(T a, object b)
624611
{
625-
if (!IsEnum(b, out var bType))
612+
if (!IsEnum(b, out var bType, out var underlyingType, throwOnNull: true))
626613
{
614+
// False although it means nothing: an exception will be raised
627615
return false;
628616
}
629-
if (bType.GetEnumUnderlyingType() == typeof(UInt64))
617+
if (underlyingType == typeof(UInt64))
630618
{
631619
return op_GreaterThanOrEqual(a, Convert.ToUInt64(b));
632620
}
@@ -635,15 +623,7 @@ public static bool op_GreaterThanOrEqual(T a, object b)
635623

636624
public static bool op_GreaterThanOrEqual(object a, T b)
637625
{
638-
if (!IsEnum(a, out var aType))
639-
{
640-
return false;
641-
}
642-
if (aType.GetEnumUnderlyingType() == typeof(UInt64))
643-
{
644-
return op_GreaterThanOrEqual(Convert.ToUInt64(a), b);
645-
}
646-
return op_GreaterThanOrEqual(Convert.ToInt64(a), b);
626+
return op_LessThanOrEqual(b, a);
647627
}
648628

649629
#endregion

0 commit comments

Comments
 (0)