Skip to content

Commit 4b5b7a1

Browse files
committed
Fix comparison to null/None
1 parent 98828a1 commit 4b5b7a1

File tree

2 files changed

+136
-0
lines changed

2 files changed

+136
-0
lines changed

src/embed_tests/EnumTests.cs

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,5 +544,85 @@ def are_same7():
544544
Assert.IsFalse(module.InvokeMethod("are_same6").As<bool>());
545545
Assert.IsFalse(module.InvokeMethod("are_same7").As<bool>());
546546
}
547+
548+
private PyModule GetCSharpObjectsComparisonTestModule(string @operator)
549+
{
550+
return PyModule.FromString("GetCSharpObjectsComparisonTestModule", $@"
551+
from clr import AddReference
552+
AddReference(""Python.EmbeddingTest"")
553+
554+
from Python.EmbeddingTest import *
555+
556+
enum_value = {nameof(EnumTests)}.{nameof(VerticalDirection)}.{VerticalDirection.Up}
557+
558+
def compare_with_none1():
559+
return enum_value {@operator} None
560+
561+
def compare_with_none2():
562+
return None {@operator} enum_value
563+
564+
def compare_with_csharp_object1(csharp_object):
565+
return enum_value {@operator} csharp_object
566+
567+
def compare_with_csharp_object2(csharp_object):
568+
return csharp_object {@operator} enum_value
569+
");
570+
}
571+
572+
[TestCase("==", false)]
573+
[TestCase("!=", true)]
574+
public void EqualityComparisonWithNull(string @operator, bool expectedResult)
575+
{
576+
using var _ = Py.GIL();
577+
using var module = GetCSharpObjectsComparisonTestModule(@operator);
578+
579+
Assert.AreEqual(expectedResult, module.InvokeMethod("compare_with_none1").As<bool>());
580+
Assert.AreEqual(expectedResult, module.InvokeMethod("compare_with_none2").As<bool>());
581+
582+
using var pyNull = ((TestClass)null).ToPython();
583+
Assert.AreEqual(expectedResult, module.InvokeMethod("compare_with_csharp_object1", pyNull).As<bool>());
584+
Assert.AreEqual(expectedResult, module.InvokeMethod("compare_with_csharp_object2", pyNull).As<bool>());
585+
}
586+
587+
[TestCase("==", false)]
588+
[TestCase("!=", true)]
589+
public void ComparisonOperatorsWithNonEnumObjectsThrows(string @operator, bool expectedResult)
590+
{
591+
using var _ = Py.GIL();
592+
using var module = GetCSharpObjectsComparisonTestModule(@operator);
593+
594+
using var pyCSharpObject = new TestClass().ToPython();
595+
Assert.AreEqual(expectedResult, module.InvokeMethod("compare_with_csharp_object1", pyCSharpObject).As<bool>());
596+
Assert.AreEqual(expectedResult, module.InvokeMethod("compare_with_csharp_object2", pyCSharpObject).As<bool>());
597+
}
598+
599+
[Test]
600+
public void ThrowsOnObjectComparisonOperators([Values("<", "<=", ">", ">=")] string @operator)
601+
{
602+
using var _ = Py.GIL();
603+
using var module = GetCSharpObjectsComparisonTestModule(@operator);
604+
605+
using var pyCSharpObject = new TestClass().ToPython();
606+
Assert.Throws<PythonException>(() => module.InvokeMethod("compare_with_csharp_object1", pyCSharpObject));
607+
Assert.Throws<PythonException>(() => module.InvokeMethod("compare_with_csharp_object2", pyCSharpObject));
608+
}
609+
610+
[Test]
611+
public void ThrowsOnNullComparisonOperators([Values("<", "<=", ">", ">=")] string @operator)
612+
{
613+
using var _ = Py.GIL();
614+
using var module = GetCSharpObjectsComparisonTestModule(@operator);
615+
616+
Assert.Throws<PythonException>(() => module.InvokeMethod("compare_with_none1").As<bool>());
617+
Assert.Throws<PythonException>(() => module.InvokeMethod("compare_with_none2").As<bool>());
618+
619+
using var pyNull = ((TestClass)null).ToPython();
620+
Assert.Throws<PythonException>(() => module.InvokeMethod("compare_with_csharp_object1", pyNull));
621+
Assert.Throws<PythonException>(() => module.InvokeMethod("compare_with_csharp_object2", pyNull));
622+
}
623+
624+
public class TestClass
625+
{
626+
}
547627
}
548628
}

src/runtime/Util/OpsHelper.cs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using System;
22
using System.Linq.Expressions;
33
using System.Reflection;
4+
using System.Runtime.CompilerServices;
5+
46
using static Python.Runtime.OpsHelper;
57

68
namespace Python.Runtime
@@ -444,6 +446,11 @@ public static bool op_Inequality(string a, T b)
444446

445447
public static bool op_Equality(T a, Enum b)
446448
{
449+
if (b == null)
450+
{
451+
return false;
452+
}
453+
447454
if (b.GetType().GetEnumUnderlyingType() == typeof(UInt64))
448455
{
449456
return op_Equality(a, Convert.ToUInt64(b));
@@ -466,8 +473,23 @@ public static bool op_Inequality(Enum a, T b)
466473
return !op_Equality(b, a);
467474
}
468475

476+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
477+
private static void ThrowOnNull(object obj, string @operator)
478+
{
479+
if (obj == null)
480+
{
481+
using (Py.GIL())
482+
{
483+
Exceptions.RaiseTypeError($"'{@operator}' not supported between instances of '{typeof(T).Name}' and null/None");
484+
PythonException.ThrowLastAsClrException();
485+
}
486+
}
487+
}
488+
469489
public static bool op_LessThan(T a, Enum b)
470490
{
491+
ThrowOnNull(b, "<");
492+
471493
if (b.GetType().GetEnumUnderlyingType() == typeof(UInt64))
472494
{
473495
return op_LessThan(a, Convert.ToUInt64(b));
@@ -477,11 +499,14 @@ public static bool op_LessThan(T a, Enum b)
477499

478500
public static bool op_LessThan(Enum a, T b)
479501
{
502+
ThrowOnNull(a, "<");
480503
return op_GreaterThan(b, a);
481504
}
482505

483506
public static bool op_GreaterThan(T a, Enum b)
484507
{
508+
ThrowOnNull(b, ">");
509+
485510
if (b.GetType().GetEnumUnderlyingType() == typeof(UInt64))
486511
{
487512
return op_GreaterThan(a, Convert.ToUInt64(b));
@@ -491,11 +516,14 @@ public static bool op_GreaterThan(T a, Enum b)
491516

492517
public static bool op_GreaterThan(Enum a, T b)
493518
{
519+
ThrowOnNull(a, ">");
494520
return op_LessThan(b, a);
495521
}
496522

497523
public static bool op_LessThanOrEqual(T a, Enum b)
498524
{
525+
ThrowOnNull(b, "<=");
526+
499527
if (b.GetType().GetEnumUnderlyingType() == typeof(UInt64))
500528
{
501529
return op_LessThanOrEqual(a, Convert.ToUInt64(b));
@@ -505,11 +533,14 @@ public static bool op_LessThanOrEqual(T a, Enum b)
505533

506534
public static bool op_LessThanOrEqual(Enum a, T b)
507535
{
536+
ThrowOnNull(a, "<=");
508537
return op_GreaterThanOrEqual(b, a);
509538
}
510539

511540
public static bool op_GreaterThanOrEqual(T a, Enum b)
512541
{
542+
ThrowOnNull(b, ">=");
543+
513544
if (b.GetType().GetEnumUnderlyingType() == typeof(UInt64))
514545
{
515546
return op_GreaterThanOrEqual(a, Convert.ToUInt64(b));
@@ -519,9 +550,34 @@ public static bool op_GreaterThanOrEqual(T a, Enum b)
519550

520551
public static bool op_GreaterThanOrEqual(Enum a, T b)
521552
{
553+
ThrowOnNull(a, ">=");
522554
return op_LessThanOrEqual(b, a);
523555
}
524556

525557
#endregion
558+
559+
#region Object equality operators
560+
561+
public static bool op_Equality(T a, object b)
562+
{
563+
return false;
564+
}
565+
566+
public static bool op_Equality(object a, T b)
567+
{
568+
return false;
569+
}
570+
571+
public static bool op_Inequality(T a, object b)
572+
{
573+
return true;
574+
}
575+
576+
public static bool op_Inequality(object a, T b)
577+
{
578+
return true;
579+
}
580+
581+
#endregion
526582
}
527583
}

0 commit comments

Comments
 (0)