Skip to content

Commit 00ba93d

Browse files
LittleHeroZZZXmaxiaolong001
authored andcommitted
[API Compatibility] Add pp.Tensor.mul_, pp.autograd.Function, pp.argwhere (PaddlePaddle#74493)
* Add pp.Tensor.mul_, pp.autograd.Function, pp.argwhere * Remove scalar support for mul and mul_
1 parent e6f38ee commit 00ba93d

File tree

7 files changed

+1048
-0
lines changed

7 files changed

+1048
-0
lines changed

python/paddle/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,7 @@
580580
argmax,
581581
argmin,
582582
argsort,
583+
argwhere,
583584
bucketize,
584585
index_sample,
585586
index_select,
@@ -1131,6 +1132,7 @@
11311132
'atleast_3d',
11321133
'reverse',
11331134
'nonzero',
1135+
'argwhere',
11341136
'CUDAPinnedPlace',
11351137
'XPUPinnedPlace',
11361138
'logical_not',

python/paddle/autograd/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,14 @@
2828
from .py_layer import PyLayer, PyLayerContext
2929
from .saved_tensors_hooks import saved_tensors_hooks
3030

31+
Function = PyLayer
32+
3133
__all__ = [
3234
'jacobian',
3335
'hessian',
3436
'backward',
3537
'PyLayer',
38+
'Function',
3639
'PyLayerContext',
3740
'saved_tensors_hooks',
3841
]

python/paddle/tensor/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,7 @@
454454
argmax,
455455
argmin,
456456
argsort,
457+
argwhere,
457458
bucketize,
458459
index_sample,
459460
index_select,
@@ -609,6 +610,8 @@
609610
'floor_mod_',
610611
'multiply',
611612
'multiply_',
613+
'mul',
614+
'mul_',
612615
'add',
613616
'add_',
614617
'subtract',
@@ -880,8 +883,12 @@
880883
'log_normal_',
881884
'set_',
882885
'resize_',
886+
'argwhere',
883887
]
884888

889+
mul = multiply
890+
mul_ = multiply_
891+
885892
# this list used in math_op_patch.py for magic_method bind
886893
magic_method_func = [
887894
('__and__', 'bitwise_and'),

python/paddle/tensor/search.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,38 @@ def nonzero(x: Tensor, as_tuple=False):
561561
return tuple(list_out)
562562

563563

564+
def argwhere(input: Tensor) -> Tensor:
565+
"""
566+
Return a tensor containing the indices of all non-zero elements of the `input`
567+
tensor. The returned tensor has shape [z, n], where `z` is the number of all non-zero
568+
elements in the `input` tensor, and `n` is the number of dimensions in the `input`
569+
tensor.
570+
571+
Args:
572+
input (Tensor): The input tensor variable.
573+
574+
Returns:
575+
Tensor, The data type is int64.
576+
577+
Examples:
578+
579+
.. code-block:: python
580+
581+
>>> import paddle
582+
583+
>>> x = paddle.to_tensor([[1.0, 0.0, 0.0],
584+
... [0.0, 2.0, 0.0],
585+
... [0.0, 0.0, 3.0]])
586+
>>> out = paddle.tensor.search.argwhere(x)
587+
>>> print(out)
588+
Tensor(shape=[3, 2], dtype=int64, place=Place(cpu), stop_gradient=True,
589+
[[0, 0],
590+
[1, 1],
591+
[2, 2]])
592+
"""
593+
return nonzero(input, as_tuple=False)
594+
595+
564596
def _restrict_nonzero(condition: Tensor, total_true_num: int) -> Tensor:
565597
"""
566598
Return a tensor containing the indices of all non-zero elements of the `input`

test/legacy_test/test_argwhere_api.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
from op_test import OpTest, convert_float_to_uint16
19+
20+
import paddle
21+
from paddle import base
22+
from paddle.base import Program, program_guard
23+
24+
25+
def call_argwhere(x):
26+
input = paddle.to_tensor(x)
27+
return paddle.argwhere(input)
28+
29+
30+
class TestArgwhereAPI(unittest.TestCase):
31+
def test_argwhere_api(self):
32+
paddle.enable_static()
33+
data = np.array([[1, 0], [0, 1]], dtype="float32")
34+
with program_guard(Program(), Program()):
35+
x = paddle.static.data(name='x', shape=[-1, 2], dtype='float32')
36+
if not paddle.framework.use_pir_api():
37+
x.desc.set_need_check_feed(False)
38+
y = paddle.argwhere(x)
39+
exe = base.Executor(base.CPUPlace())
40+
(res,) = exe.run(
41+
feed={'x': data}, fetch_list=[y], return_numpy=False
42+
)
43+
expect_out = np.array([[0, 0], [1, 1]])
44+
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
45+
46+
data = np.array([1, 1, 0], dtype="float32")
47+
with program_guard(Program(), Program()):
48+
x = paddle.static.data(name='x', shape=[-1], dtype='float32')
49+
if not paddle.framework.use_pir_api():
50+
x.desc.set_need_check_feed(False)
51+
y = paddle.argwhere(x)
52+
exe = base.Executor(base.CPUPlace())
53+
(res,) = exe.run(
54+
feed={'x': data}, fetch_list=[y], return_numpy=False
55+
)
56+
expect_out = np.array([[0], [1]])
57+
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
58+
59+
def test_dygraph_api(self):
60+
data_x = np.array([[True, False], [False, True]])
61+
with base.dygraph.guard():
62+
x = paddle.to_tensor(data_x)
63+
z = paddle.argwhere(x)
64+
np_z = z.numpy()
65+
expect_out = np.array([[0, 0], [1, 1]])
66+
67+
68+
# Base case
69+
class TestArgwhereOp(OpTest):
70+
def setUp(self):
71+
'''Test where_index op with random value'''
72+
np.random.seed(2023)
73+
self.op_type = "where_index"
74+
self.python_api = call_argwhere
75+
self.init_shape()
76+
self.init_dtype()
77+
78+
self.inputs = self.create_inputs()
79+
self.outputs = self.return_outputs()
80+
81+
def test_check_output(self):
82+
self.check_output(check_pir=True, check_symbol_infer=False)
83+
84+
def init_shape(self):
85+
self.shape = [8, 8]
86+
87+
def init_dtype(self):
88+
self.dtype = np.float64
89+
90+
def create_inputs(self):
91+
return {
92+
'Condition': np.random.randint(5, size=self.shape).astype(
93+
self.dtype
94+
)
95+
}
96+
97+
def return_outputs(self):
98+
return {'Out': np.argwhere(self.inputs['Condition'])}
99+
100+
101+
class TestArgwhereComplex64Op(TestArgwhereOp):
102+
def init_shape(self):
103+
self.shape = [1, 2, 3]
104+
105+
def init_dtype(self):
106+
self.dtype = np.complex64
107+
108+
109+
class TestArgwhereComplex128Op(TestArgwhereOp):
110+
def init_shape(self):
111+
self.shape = [1, 2, 3]
112+
113+
def init_dtype(self):
114+
self.dtype = np.complex128
115+
116+
117+
class TestArgwhereFP32Op(TestArgwhereOp):
118+
def init_shape(self):
119+
self.shape = [2, 10, 2]
120+
121+
def init_dtype(self):
122+
self.dtype = np.float32
123+
124+
125+
class TestArgwhereFP16Op(TestArgwhereOp):
126+
def init_shape(self):
127+
self.shape = [3, 4, 7]
128+
129+
def init_dtype(self):
130+
self.dtype = np.float16
131+
132+
133+
class TestArgwhereBF16(OpTest):
134+
def setUp(self):
135+
'''Test where_index op with bfloat16 dtype'''
136+
np.random.seed(2023)
137+
self.op_type = "where_index"
138+
self.python_api = call_argwhere
139+
self.init_shape()
140+
self.init_dtype()
141+
142+
self.inputs = self.create_inputs()
143+
self.outputs = self.return_outputs()
144+
145+
def test_check_output(self):
146+
self.check_output(check_pir=True, check_symbol_infer=False)
147+
148+
def init_shape(self):
149+
self.shape = [12, 9]
150+
151+
def init_dtype(self):
152+
self.dtype = np.uint16
153+
154+
def create_inputs(self):
155+
return {
156+
'Condition': convert_float_to_uint16(
157+
np.random.randint(5, size=self.shape).astype(np.float32)
158+
)
159+
}
160+
161+
def return_outputs(self):
162+
return {'Out': np.argwhere(self.inputs['Condition'])}
163+
164+
165+
class TestZeroSizeOp(TestArgwhereOp):
166+
167+
def init_shape(self):
168+
self.shape = [0, 10]
169+
170+
def init_dtype(self):
171+
self.dtype = np.float64
172+
173+
174+
class TestZeroSizeOpCase2(TestArgwhereOp):
175+
176+
def init_shape(self):
177+
self.shape = [0, 10]
178+
179+
def init_dtype(self):
180+
self.dtype = np.float64
181+
182+
def test_check_output(self):
183+
self.check_output(check_pir=True, check_symbol_infer=True)
184+
185+
186+
if __name__ == "__main__":
187+
unittest.main()

0 commit comments

Comments
 (0)