Skip to content

Commit c75344c

Browse files
huangjiyiSigureMo
andauthored
[API Compatibility] Fix tensor __eq__ and __ne__ for unsupported type (#76118)
* [API Compatibility] Support tensor equals None usage * update test case * support __ne__ for unsupported type * support pir.Value * update * update * update --------- Co-authored-by: SigureMo <[email protected]>
1 parent 158ddaa commit c75344c

File tree

4 files changed

+40
-4
lines changed

4 files changed

+40
-4
lines changed

paddle/fluid/pybind/eager_math_op_patch.cc

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2204,8 +2204,16 @@ static PyObject* tensor__ne__method(TensorObject* self,
22042204
other_tensor = paddle::empty({}, phi::DataType::FLOAT32, place);
22052205
InitTensorWithNumpyValue(numpy_value, place, &other_tensor);
22062206
} else {
2207-
paddle::experimental::Scalar value =
2208-
CastPyArg2Scalar(other_obj, "__ne__", 0);
2207+
paddle::experimental::Scalar value;
2208+
2209+
// return True if other_obj is unsupported type
2210+
try {
2211+
value = CastPyArg2Scalar(other_obj, "__ne__", 0);
2212+
} catch (const ::common::enforce::EnforceNotMet& e) {
2213+
Py_INCREF(Py_True);
2214+
return Py_True;
2215+
}
2216+
22092217
if (PyComplex_Check(other_obj)) {
22102218
eager_gil_scoped_release guard;
22112219
other_tensor =
@@ -2297,8 +2305,16 @@ static PyObject* tensor__eq__method(TensorObject* self,
22972305
other_tensor = paddle::empty({}, phi::DataType::FLOAT32, place);
22982306
InitTensorWithNumpyValue(numpy_value, place, &other_tensor);
22992307
} else {
2300-
paddle::experimental::Scalar value =
2301-
CastPyArg2Scalar(other_obj, "__eq__", 0);
2308+
paddle::experimental::Scalar value;
2309+
2310+
// return False if other_obj is unsupported type
2311+
try {
2312+
value = CastPyArg2Scalar(other_obj, "__eq__", 0);
2313+
} catch (const ::common::enforce::EnforceNotMet& e) {
2314+
Py_INCREF(Py_False);
2315+
return Py_False;
2316+
}
2317+
23022318
if (PyComplex_Check(other_obj)) {
23032319
eager_gil_scoped_release guard;
23042320
other_tensor =

python/paddle/pir/math_op_patch.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,13 @@ def __impl__(self, other_var):
528528
# but only +, -, *, / can use this method
529529
if scalar_method is not None:
530530
return scalar_method(self, other_var)
531+
elif other_var is None:
532+
if method_name == "__eq__":
533+
return False
534+
elif method_name == "__ne__":
535+
return True
536+
else:
537+
pass
531538
else:
532539
# do nothing
533540
pass

test/legacy_test/test_eager_tensor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,6 +1592,17 @@ def test_tensor__format__(self):
15921592
paddle_scalar = paddle.uniform([], min=-100, max=100)
15931593
self.assertRaises(ValueError, paddle_scalar.__format__, "3d")
15941594

1595+
def test_tensor_eq_unsupported_type(self):
1596+
a = paddle.empty([2])
1597+
1598+
# Compare with None
1599+
self.assertFalse(a == None) # noqa: E711
1600+
self.assertTrue(a != None) # noqa: E711
1601+
1602+
# Compare with other obj
1603+
self.assertFalse(a == object())
1604+
self.assertTrue(a != object())
1605+
15951606

15961607
class TestEagerTensorSetitem(unittest.TestCase):
15971608
def func_setUp(self):

test/legacy_test/test_math_op_patch_pir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,8 @@ def _test(x_np, y_np, input_dtype):
451451
e = x != y
452452
f = x.not_equal(y)
453453
g = x.__ne__(y)
454+
self.assertFalse(x == None) # noqa: E711
455+
self.assertTrue(x != None) # noqa: E711
454456
(e_np, f_np, g_np) = exe.run(
455457
main_program,
456458
feed={"x": x_np, "y": y_np},

0 commit comments

Comments
 (0)