Skip to content

Commit b0ec6e8

Browse files
[cherry pick]add warning message when dtypes of operator are not same (#31136) (#31175)
ATT, cherry-pick #31136
1 parent a19154c commit b0ec6e8

File tree

2 files changed

+71
-1
lines changed

2 files changed

+71
-1
lines changed

python/paddle/fluid/dygraph/math_op_patch.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import numpy as np
2323
import six
24+
import warnings
2425

2526
_supported_int_dtype_ = [
2627
core.VarDesc.VarType.UINT8,
@@ -51,6 +52,11 @@
5152
'__matmul__',
5253
]
5354

55+
_complex_dtypes = [
56+
core.VarDesc.VarType.COMPLEX64,
57+
core.VarDesc.VarType.COMPLEX128,
58+
]
59+
5460
_already_patch_varbase = False
5561

5662

@@ -214,7 +220,9 @@ def __impl__(self, other_var):
214220
# 3. promote types or unify right var type to left var
215221
rhs_dtype = other_var.dtype
216222
if lhs_dtype != rhs_dtype:
217-
if method_name in _supported_promote_complex_types_:
223+
if method_name in _supported_promote_complex_types_ and (
224+
lhs_dtype in _complex_dtypes or
225+
rhs_dtype in _complex_dtypes):
218226
# only when lhs_dtype or rhs_dtype is complex type,
219227
# the dtype will promote, in other cases, directly
220228
# use lhs_dtype, this is consistent will original rule
@@ -225,6 +233,9 @@ def __impl__(self, other_var):
225233
other_var = other_var if rhs_dtype == promote_dtype else astype(
226234
other_var, promote_dtype)
227235
else:
236+
warnings.warn(
237+
'The dtype of left and right variables are not the same, left dtype is {}, but right dtype is {}, the right dtype will convert to {}'.
238+
format(lhs_dtype, rhs_dtype, lhs_dtype))
228239
other_var = astype(other_var, lhs_dtype)
229240

230241
if reverse:
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function, division
16+
17+
import unittest
18+
import numpy as np
19+
import warnings
20+
import paddle
21+
22+
23+
class TestTensorTypePromotion(unittest.TestCase):
24+
def setUp(self):
25+
self.x = paddle.to_tensor([2, 3])
26+
self.y = paddle.to_tensor([1.0, 2.0])
27+
28+
def test_operator(self):
29+
with warnings.catch_warnings(record=True) as context:
30+
warnings.simplefilter("always")
31+
self.x + self.y
32+
self.assertTrue(
33+
"The dtype of left and right variables are not the same" in
34+
str(context[-1].message))
35+
36+
with warnings.catch_warnings(record=True) as context:
37+
warnings.simplefilter("always")
38+
self.x - self.y
39+
self.assertTrue(
40+
"The dtype of left and right variables are not the same" in
41+
str(context[-1].message))
42+
43+
with warnings.catch_warnings(record=True) as context:
44+
warnings.simplefilter("always")
45+
self.x * self.y
46+
self.assertTrue(
47+
"The dtype of left and right variables are not the same" in
48+
str(context[-1].message))
49+
50+
with warnings.catch_warnings(record=True) as context:
51+
warnings.simplefilter("always")
52+
self.x / self.y
53+
self.assertTrue(
54+
"The dtype of left and right variables are not the same" in
55+
str(context[-1].message))
56+
57+
58+
if __name__ == '__main__':
59+
unittest.main()

0 commit comments

Comments
 (0)