Skip to content

Commit a98232c

Browse files
committed
[WIP][API-Compat] Add dyna-graph unittests for min/max
1 parent 3fb103f commit a98232c

File tree

2 files changed

+256
-0
lines changed

2 files changed

+256
-0
lines changed

paddle/fluid/pir/dialect/op_generator/op_build_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@
134134
'KthvalueInferMeta',
135135
'MaxPoolWithIndexInferMeta',
136136
'MaxPoolV2InferMeta',
137+
'MinMaxWithIndexInferMeta',
137138
'MultinomialInferMeta',
138139
'OverlapAddInferMeta',
139140
'PadInferMeta',
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
19+
import paddle
20+
21+
22+
class TestCompatMinMax(unittest.TestCase):
23+
def setUp(self):
24+
"""Make sure we are in a dynamic graph env"""
25+
paddle.disable_static()
26+
27+
def test_case1_simple_reduce_all(self):
28+
data = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]], dtype='float32')
29+
min_val = paddle.compat.min(data)
30+
max_val = paddle.compat.max(data)
31+
32+
self.assertAlmostEqual(min_val.item(), 1.0)
33+
self.assertAlmostEqual(max_val.item(), 4.0)
34+
35+
data = paddle.to_tensor(
36+
[[1.0, 1.0], [2.0, 3.0]], dtype='float32', stop_gradient=False
37+
)
38+
min_val = paddle.compat.min(data)
39+
min_val.backward()
40+
41+
expected_grad = np.array([[0.5, 0.5], [0.0, 0.0]])
42+
np.testing.assert_allclose(data.grad.numpy(), expected_grad)
43+
44+
def test_case2_reduce_dim(self):
45+
"""Test dim/keepdim"""
46+
data = paddle.to_tensor(
47+
[[[5, 8], [2, 1]], [[7, 3], [9, 6]]], dtype='float32'
48+
)
49+
50+
min_result = paddle.compat.min(data, dim=1)
51+
self.assertEqual(min_result.values.shape, [2, 2])
52+
np.testing.assert_array_equal(
53+
min_result.values.numpy(), np.array([[2, 1], [7, 3]])
54+
)
55+
np.testing.assert_array_equal(
56+
min_result.indices.numpy(), np.array([[1, 1], [0, 0]])
57+
)
58+
59+
max_result = paddle.compat.max(data, dim=2)
60+
self.assertEqual(max_result.values.shape, [2, 2])
61+
np.testing.assert_array_equal(
62+
max_result.values.numpy(), np.array([[8, 2], [7, 9]])
63+
)
64+
np.testing.assert_array_equal(
65+
max_result.indices.numpy(), np.array([[1, 0], [0, 0]])
66+
)
67+
68+
min_result_keep = paddle.compat.min(data, dim=0, keepdim=True)
69+
self.assertEqual(min_result_keep.values.shape, [1, 2, 2])
70+
np.testing.assert_array_equal(
71+
min_result_keep.values.numpy(), np.array([[[5, 3], [2, 1]]])
72+
)
73+
74+
min_result_neg = paddle.compat.min(data, dim=-2)
75+
np.testing.assert_array_equal(
76+
min_result_neg.values.numpy(), min_result.values.numpy()
77+
)
78+
79+
def test_case2_grad(self):
80+
data = paddle.to_tensor(
81+
[[[1.0, 2.0], [1.0, 3.0]], [[4.0, 1.0], [5.0, 1.0]]],
82+
dtype='float32',
83+
stop_gradient=False,
84+
)
85+
y = data * 2
86+
87+
min_result = paddle.compat.min(y, dim=2)
88+
min_result.values.backward()
89+
90+
expected_grad = np.array(
91+
[[[2.0, 0.0], [2.0, 0.0]], [[0.0, 2.0], [0.0, 2.0]]]
92+
)
93+
np.testing.assert_allclose(data.grad.numpy(), expected_grad, atol=1e-6)
94+
95+
def test_case3_elementwise(self):
96+
"""minimum/maximum"""
97+
x = paddle.to_tensor([[1, 5], [4, 2]], dtype='float32')
98+
y = paddle.to_tensor([[3, 2], [1, 6]], dtype='float32')
99+
100+
min_result = paddle.compat.min(x, y)
101+
np.testing.assert_array_equal(
102+
min_result.numpy(), np.array([[1, 2], [1, 2]])
103+
)
104+
105+
max_result = paddle.compat.max(x, y)
106+
np.testing.assert_array_equal(
107+
max_result.numpy(), np.array([[3, 5], [4, 6]])
108+
)
109+
110+
z = paddle.to_tensor([3, 4], dtype='float32')
111+
broadcast_min = paddle.compat.min(x, z)
112+
np.testing.assert_array_equal(
113+
broadcast_min.numpy(), np.array([[1, 4], [3, 2]])
114+
)
115+
116+
def test_case3_grad(self):
117+
x = paddle.to_tensor(
118+
[[1.0, 2.0], [3.0, 4.0]], dtype=paddle.float16, stop_gradient=False
119+
)
120+
y = paddle.to_tensor(
121+
[[0.5, 2.5], [2.0, 3.5]], dtype=paddle.float16, stop_gradient=False
122+
)
123+
124+
min_val = paddle.compat.min(x, y)
125+
min_val.backward()
126+
127+
expected_x_grad = np.array([[0.0, 1.0], [0.0, 0.0]])
128+
np.testing.assert_allclose(x.grad.numpy(), expected_x_grad)
129+
130+
expected_y_grad = np.array([[1.0, 0.0], [1.0, 1.0]])
131+
np.testing.assert_allclose(y.grad.numpy(), expected_y_grad)
132+
133+
def test_edge_cases(self):
134+
"""Edge cases test"""
135+
# uniform distributed gradient
136+
uniform_data = paddle.ones([2, 3], dtype='float64')
137+
uniform_data.stop_gradient = False
138+
min_val = paddle.compat.min(uniform_data, 0)
139+
min_val.values.sum().backward()
140+
141+
expected_grad = np.full((2, 3), 0.5)
142+
np.testing.assert_allclose(uniform_data.grad.numpy(), expected_grad)
143+
144+
# 0-dim tensor
145+
dim0_tensor = paddle.to_tensor(2, dtype='float32')
146+
max_val = paddle.compat.max(dim0_tensor)
147+
np.testing.assert_allclose(
148+
max_val.numpy(), np.array(2.0, dtype=np.float32)
149+
)
150+
151+
# 1-dim tensor
152+
dim1_tensor = paddle.to_tensor([1], dtype='uint8')
153+
max_val = paddle.compat.max(dim1_tensor, dim=-1, keepdim=True)
154+
np.testing.assert_array_equal(
155+
max_val[0].numpy(), np.array([1], dtype=np.uint8)
156+
)
157+
np.testing.assert_array_equal(
158+
max_val[1].numpy(), np.array([0], dtype=np.int64)
159+
)
160+
161+
def test_compare_with_index_ops_to_origin(self):
162+
dtypes = ['float32', 'float64', 'bfloat16', 'float16', 'int32', 'int64']
163+
164+
for i, dtype in enumerate(dtypes):
165+
data = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype)
166+
min_vals_inds = paddle.compat.min(data, dim=0)
167+
self.assertEqual(min_vals_inds.values.dtype, data.dtype)
168+
self.assertEqual(min_vals_inds.indices.dtype, paddle.int64)
169+
170+
origin_values = paddle.min(data, axis=0)
171+
origin_indices = paddle.argmin(data, axis=0, dtype="int64")
172+
if i < 4: # floating point
173+
np.testing.assert_allclose(
174+
min_vals_inds.values.numpy(), origin_values.numpy()
175+
)
176+
else:
177+
np.testing.assert_array_equal(
178+
min_vals_inds.values.numpy(), origin_values.numpy()
179+
)
180+
np.testing.assert_array_equal(
181+
min_vals_inds[1].numpy(), origin_indices.numpy()
182+
)
183+
184+
def test_error_handling(self):
185+
"""Test whether correct exception will be thrown. Skip error messages (some of them are long)"""
186+
187+
err_msg1 = (
188+
"Tensors with integral type: 'paddle.int32' should stop gradient."
189+
)
190+
191+
# empty tensor
192+
empty_tensor = paddle.to_tensor([], dtype='float32')
193+
with self.assertRaises(ValueError):
194+
paddle.compat.min(empty_tensor)
195+
196+
# mixed parameters case 1
197+
input_ts = paddle.to_tensor([1, 2, 3], dtype='float32')
198+
other_ts = paddle.to_tensor([1])
199+
with self.assertRaises(TypeError):
200+
paddle.compat.min(input_ts, other=other_ts, dim=0)
201+
202+
# mixed parameters case 2
203+
with self.assertRaises(TypeError):
204+
paddle.compat.min(input_ts, 0, other=other_ts)
205+
206+
# trying to perform grad ops for integral types
207+
with self.assertRaises(TypeError) as cm:
208+
tensor = paddle.ones([2, 2], dtype=paddle.int32)
209+
tensor.stop_gradient = False
210+
tensors = paddle.compat.max(tensor, dim=0)
211+
self.assertEqual(str(cm.exception), err_msg1)
212+
213+
# explicit None case 1
214+
with self.assertRaises(TypeError) as cm:
215+
paddle.compat.min(input_ts, dim=None)
216+
217+
# explicit None case 2
218+
with self.assertRaises(TypeError) as cm:
219+
paddle.compat.min(input_ts, None, keepdim=True)
220+
221+
# keepdim specified without specifying dim
222+
with self.assertRaises(TypeError) as cm:
223+
paddle.compat.min(input_ts, keepdim=True)
224+
225+
# Wrong *args specification case 1
226+
with self.assertRaises(TypeError) as cm:
227+
paddle.compat.min(input_ts, False)
228+
229+
# Wrong *args specification case 2
230+
with self.assertRaises(TypeError) as cm:
231+
paddle.compat.min(input_ts, other_ts, True)
232+
233+
# Tensor input for dim case 1
234+
with self.assertRaises(TypeError) as cm:
235+
paddle.compat.min(input_ts, dim=paddle.to_tensor([0]))
236+
237+
# Tensor input for dim case 2
238+
with self.assertRaises(TypeError) as cm:
239+
paddle.compat.min(input_ts, dim=paddle.to_tensor(0))
240+
241+
# Duplicate Arguments case 1
242+
with self.assertRaises(TypeError) as cm:
243+
paddle.compat.max(input_ts, 0, dim=0)
244+
245+
# Duplicate Arguments case 2
246+
with self.assertRaises(TypeError) as cm:
247+
paddle.compat.max(input_ts, other_ts, other=0)
248+
249+
# Duplicate Arguments case 3
250+
with self.assertRaises(TypeError) as cm:
251+
paddle.compat.max(input_ts, dim=0, other=0, keepdim=True)
252+
253+
254+
if __name__ == '__main__':
255+
unittest.main()

0 commit comments

Comments
 (0)