Skip to content

Commit 27c2a37

Browse files
committed
fix a bug: When x is a scalar, the dtype returned by paddle.where is fixed to float64.
1 parent 2a78ac9 commit 27c2a37

File tree

2 files changed

+142
-6
lines changed

2 files changed

+142
-6
lines changed

python/paddle/tensor/search.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -796,7 +796,7 @@ def where(
796796
name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
797797
798798
Returns:
799-
Tensor, A Tensor with the same shape as :attr:`condition` and same data type as :attr:`x` and :attr:`y`.
799+
Tensor, A Tensor with the same shape as :attr:`condition` and same data type as :attr:`x` and :attr:`y`. If :attr:`x` and :attr:`y` have different data types, type promotion rules will be applied (see `Auto Type Promotion <https://www.paddlepaddle.org.cn/documentation/docs/en/develop/guides/advanced/auto_type_promotion_en.html#introduction-to-data-type-promotion>`_).
800800
801801
Examples:
802802
@@ -814,15 +814,14 @@ def where(
814814
815815
>>> out = paddle.where(x>1)
816816
>>> print(out)
817-
(Tensor(shape=[2, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
818-
[[2],
819-
[3]]),)
817+
(Tensor(shape=[2], dtype=int64, place=Place(cpu), stop_gradient=True,
818+
[2, 3]),)
820819
"""
821820
if np.isscalar(x):
822-
x = paddle.full([1], x, np.array([x]).dtype.name)
821+
x = paddle.to_tensor(x)
823822

824823
if np.isscalar(y):
825-
y = paddle.full([1], y, np.array([y]).dtype.name)
824+
y = paddle.to_tensor(y)
826825

827826
if x is None and y is None:
828827
return nonzero(condition, as_tuple=True)

test/tensor/test_search.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import paddle
18+
19+
20+
class TestSearchAPIs(unittest.TestCase):
21+
def __init__(self, method_name='runTest'):
22+
super().__init__(method_name)
23+
self.con = None
24+
self.con_2D = None
25+
26+
def setUp(self):
27+
self.con = paddle.to_tensor([0.4, 0.3, 0.6, 0.7], dtype="float32")
28+
self.con_2D = paddle.rand([4, 4], dtype='float32')
29+
30+
def test_where_with_float16_scalar(self):
31+
# TODO(hanchoa): Do not support float16 with cpu.
32+
pass
33+
34+
def test_where_with_bfloat16_scalar(self):
35+
# TODO(hanchoa): Do not support bfloat16 with cpu.
36+
pass
37+
38+
def test_where_with_float32_scalar(self):
39+
x = paddle.to_tensor([0.0, 0.0, 0.0, 0.0], dtype="float32")
40+
y = paddle.to_tensor([0.1, 0.1, 0.1, 0.1], dtype="float32")
41+
42+
res = paddle.where(self.con > 0.5, x, y)
43+
self.assertEqual(res.dtype, paddle.float32)
44+
45+
res = paddle.where(self.con > 0.5, 0.5, y)
46+
self.assertEqual(res.dtype, paddle.float32)
47+
48+
res = paddle.where(self.con > 0.5, x, 0.6)
49+
self.assertEqual(res.dtype, paddle.float32)
50+
51+
res = paddle.where(self.con > 0.5, 0.5, 0.6)
52+
self.assertEqual(res.dtype, paddle.float32)
53+
54+
def test_where_with_float64_scalar(self):
55+
x = paddle.to_tensor([0.0, 0.0, 0.0, 0.0], dtype="float64")
56+
y = paddle.to_tensor([0.1, 0.1, 0.1, 0.1], dtype="float64")
57+
58+
res = paddle.where(self.con > 0.5, x, y)
59+
self.assertEqual(res.dtype, paddle.float64)
60+
61+
res = paddle.where(self.con > 0.5, 0.5, y)
62+
self.assertEqual(res.dtype, paddle.float64)
63+
64+
res = paddle.where(self.con > 0.5, x, 0.6)
65+
self.assertEqual(res.dtype, paddle.float64)
66+
67+
res = paddle.where(self.con > 0.5, 0.5, 0.6)
68+
self.assertEqual(res.dtype, paddle.float32)
69+
70+
def test_where_with_complex64_scalar(self):
71+
x = paddle.to_tensor([0.0, 0.0, 0.0, 0.0], dtype="complex64")
72+
y = paddle.to_tensor([0.1, 0.1, 0.1, 0.1], dtype="complex64")
73+
74+
res = paddle.where(self.con > 0.5, x, y)
75+
self.assertEqual(res.dtype, paddle.complex64)
76+
77+
res = paddle.where(self.con > 0.5, 0.5, y)
78+
self.assertEqual(res.dtype, paddle.complex64)
79+
80+
res = paddle.where(self.con > 0.5, x, 0.6)
81+
self.assertEqual(res.dtype, paddle.complex64)
82+
83+
res = paddle.where(self.con > 0.5, 0.5, 0.6)
84+
self.assertEqual(res.dtype, paddle.float32)
85+
86+
def test_where_with_complex128_scalar(self):
87+
x = paddle.to_tensor([0.0, 0.0, 0.0, 0.0], dtype="complex128")
88+
y = paddle.to_tensor([0.1, 0.1, 0.1, 0.1], dtype="complex128")
89+
90+
res = paddle.where(self.con > 0.5, x, y)
91+
self.assertEqual(res.dtype, paddle.complex128)
92+
93+
res = paddle.where(self.con > 0.5, 0.5, y)
94+
self.assertEqual(res.dtype, paddle.complex128)
95+
96+
res = paddle.where(self.con > 0.5, x, 0.6)
97+
self.assertEqual(res.dtype, paddle.complex128)
98+
99+
res = paddle.where(self.con > 0.5, 0.5, 0.6)
100+
self.assertEqual(res.dtype, paddle.float32)
101+
102+
def test_where_with_int_scalar(self):
103+
x = paddle.to_tensor([2, 2, 2, 2], dtype="int32")
104+
y = paddle.to_tensor([3, 3, 3, 3], dtype="int32")
105+
106+
res = paddle.where(self.con > 0.5, x, y)
107+
self.assertEqual(res.dtype, paddle.int32)
108+
109+
# TODO(hanchao): Do not support int type promotion yet.
110+
# res = paddle.where(self.con > 0.5, 3, y)
111+
# self.assertEqual(res.dtype, paddle.int32)
112+
113+
# res = paddle.where(self.con > 0.5, x, 4)
114+
# self.assertEqual(res.dtype, paddle.int32)
115+
#
116+
# res = paddle.where(self.con > 0.5, 3, 4)
117+
# self.assertEqual(res.dtype, paddle.int32)
118+
119+
def test_where_with_float32_scalar_2D(self):
120+
x = paddle.to_tensor([0.0, 0.0, 0.0, 0.0], dtype="float32")
121+
y = paddle.to_tensor([0.1, 0.1, 0.1, 0.1], dtype="float32")
122+
123+
res = paddle.where(self.con_2D > 0.5, x, y)
124+
self.assertEqual(res.dtype, paddle.float32)
125+
126+
res = paddle.where(self.con_2D > 0.5, 0.5, y)
127+
self.assertEqual(res.dtype, paddle.float32)
128+
129+
res = paddle.where(self.con_2D > 0.5, x, 0.6)
130+
self.assertEqual(res.dtype, paddle.float32)
131+
132+
res = paddle.where(self.con_2D > 0.5, 0.5, 0.6)
133+
self.assertEqual(res.dtype, paddle.float32)
134+
135+
136+
if __name__ == '__main__':
137+
unittest.main()

0 commit comments

Comments
 (0)