Commit d4d9e5c (1 parent: aa1cfd6)

[API-Compat] Resolve merge conflicts, add symbolic shape test.

5 files changed: +174 −42 lines

python/paddle/tensor/compat.py

Lines changed: 49 additions & 27 deletions
```diff
@@ -25,6 +25,7 @@
 )
 
 if TYPE_CHECKING:
+    from collections.abc import Sequence
 
     from paddle import Tensor
     from paddle._typing import (
@@ -639,24 +640,35 @@ def min(
         ret = paddle.min(input)
     elif isinstance(dim_or_other, int):
         _check_out_status(out, True)
-        if in_dynamic_mode() and not input.place.is_gpu_place():
-            _min_max_allow_cpu_composite(input)
-            # CPUPlace and other placements are implemented by composition
-            indices = paddle.argmin(input, axis=dim_or_other, keepdim=True)
-            values = paddle.take_along_axis(input, indices, axis=dim_or_other)
-            if keepdim:
-                ret = MinMaxRetType(values=values, indices=indices)
+        if input.ndim:
+            if in_dynamic_mode() and not input.place.is_gpu_place():
+                _min_max_allow_cpu_composite(input)
+                # CPUPlace and other placements are implemented by composition
+
+                indices = paddle.argmin(input, axis=dim_or_other, keepdim=True)
+                values = paddle.take_along_axis(
+                    input, indices, axis=dim_or_other
+                )
+                if keepdim:
+                    ret = MinMaxRetType(values=values, indices=indices)
+                else:
+                    ret = MinMaxRetType(
+                        values=values.squeeze_(axis=dim_or_other),
+                        indices=indices.squeeze_(axis=dim_or_other),
+                    )
             else:
-                ret = MinMaxRetType(
-                    values=values.squeeze_(axis=dim_or_other),
-                    indices=indices.squeeze_(axis=dim_or_other),
+                vals, inds = _C_ops.min_with_index(
+                    input, dim_or_other, keepdim, False
                 )
+                inds.stop_gradient = True
+                ret = MinMaxRetType(values=vals, indices=inds)
         else:
-            vals, inds = _C_ops.min_with_index(
-                input, dim_or_other, keepdim, False
+            ret = MinMaxRetType(
+                values=input,
+                indices=paddle.zeros(
+                    [], dtype=paddle.int64, device=input.place
+                ),
             )
-            inds.stop_gradient = True
-            ret = MinMaxRetType(values=vals, indices=inds)
     else:
         _check_out_status(out, False)
         ret = _C_ops.minimum(input, dim_or_other)
@@ -780,23 +792,33 @@ def max(
         ret = paddle.max(input)
     elif isinstance(dim_or_other, int):
         _check_out_status(out, True)
-        if in_dynamic_mode() and not input.place.is_gpu_place():
-            _min_max_allow_cpu_composite(input)
-            indices = paddle.argmax(input, axis=dim_or_other, keepdim=True)
-            values = paddle.take_along_axis(input, indices, axis=dim_or_other)
-            if keepdim:
-                ret = MinMaxRetType(values=values, indices=indices)
+        if input.ndim:
+            if in_dynamic_mode() and not input.place.is_gpu_place():
+                _min_max_allow_cpu_composite(input)
+                indices = paddle.argmax(input, axis=dim_or_other, keepdim=True)
+                values = paddle.take_along_axis(
+                    input, indices, axis=dim_or_other
+                )
+                if keepdim:
+                    ret = MinMaxRetType(values=values, indices=indices)
+                else:
+                    ret = MinMaxRetType(
+                        values=values.squeeze_(axis=dim_or_other),
+                        indices=indices.squeeze_(axis=dim_or_other),
+                    )
             else:
-                ret = MinMaxRetType(
-                    values=values.squeeze_(axis=dim_or_other),
-                    indices=indices.squeeze_(axis=dim_or_other),
+                vals, inds = _C_ops.max_with_index(
+                    input, dim_or_other, keepdim, False
                 )
+                inds.stop_gradient = True
+                ret = MinMaxRetType(values=vals, indices=inds)
         else:
-            vals, inds = _C_ops.max_with_index(
-                input, dim_or_other, keepdim, False
+            ret = MinMaxRetType(
+                values=input,
+                indices=paddle.zeros(
+                    [], dtype=paddle.int64, device=input.place
+                ),
             )
-            inds.stop_gradient = True
-            ret = MinMaxRetType(values=vals, indices=inds)
     else:
         _check_out_status(out, False)
         ret = _C_ops.maximum(input, dim_or_other)
```
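
Taken together, the `min` and `max` changes route 0-D inputs around `min_with_index`/`max_with_index` entirely. A minimal usage sketch of the resulting behavior (our illustration, not part of the commit):

```python
import paddle

# 0-D input: the new `if input.ndim:` guard returns the input itself as the
# value and a fresh 0-D int64 zero as the index, on any placement.
x = paddle.rand([])
val, ind = paddle.compat.min(x, 0)  # val == x, ind == 0, both 0-D

# ND input: dispatch is unchanged; an argmax/take_along_axis composition runs
# on CPU, _C_ops.max_with_index on GPU, with indices marked stop_gradient.
y = paddle.rand([3, 5])
vals, inds = paddle.compat.max(y, dim=1, keepdim=True)  # both shaped [3, 1]
```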

python/paddle/utils/decorator_utils.py

Lines changed: 26 additions & 15 deletions
```diff
@@ -123,21 +123,6 @@ def __init__(
         self.default_params = default_params
         warnings.simplefilter("always", category=Warning)
 
-
-# *size => shape decorator
-class SizeArgsDecorator(DecoratorBase):
-    """
-    Usage Example:
-
-        paddle.ones(1, dtype=paddle.float32)
-        paddle.ones(1, 2, 3, dtype=paddle.float32)
-        paddle.ones([1, 2, 3], dtype=paddle.float32)
-        paddle.ones(size=[1, 2, 3], dtype=paddle.float32)
-
-        paddle.ones([1, 2, 3], paddle.float32)
-        paddle.ones(shape=[1, 2, 3], dtype=paddle.float32)
-    """
-
     def process(
         self, args: tuple[Any, ...], kwargs: dict[str, Any]
     ) -> tuple[tuple[Any, ...], dict[str, Any]]:
@@ -235,6 +220,32 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
     return decorator
 
 
+# *size => shape decorator
+class SizeArgsDecorator(DecoratorBase):
+    """
+    Usage Example:
+
+        paddle.ones(1, dtype=paddle.float32)
+        paddle.ones(1, 2, 3, dtype=paddle.float32)
+        paddle.ones([1, 2, 3], dtype=paddle.float32)
+        paddle.ones(size=[1, 2, 3], dtype=paddle.float32)
+
+        paddle.ones([1, 2, 3], paddle.float32)
+        paddle.ones(shape=[1, 2, 3], dtype=paddle.float32)
+    """
+
+    def process(
+        self, args: tuple[Any, ...], kwargs: dict[str, Any]
+    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
+        if 'size' in kwargs:
+            kwargs['shape'] = kwargs.pop('size')
+        elif len(args) >= 1 and isinstance(args[0], int):
+            kwargs['shape'] = list(args)
+            args = ()
+
+        return args, kwargs
+
+
 class VariableArgsDecorator(DecoratorBase):
     def __init__(self, var: str) -> None:
         super().__init__()
```
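
The relocated `SizeArgsDecorator.process` normalizes PyTorch-style `*size` calls into Paddle's `shape=` keyword. A standalone restatement of that logic for illustration (the helper name `_normalize_size_args` is ours, not the module's):

```python
from typing import Any


def _normalize_size_args(
    args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[tuple[Any, ...], dict[str, Any]]:
    # Mirrors SizeArgsDecorator.process from the diff above.
    if 'size' in kwargs:
        kwargs['shape'] = kwargs.pop('size')
    elif len(args) >= 1 and isinstance(args[0], int):
        kwargs['shape'] = list(args)
        args = ()
    return args, kwargs


# paddle.ones(1, 2, 3)        -> shape=[1, 2, 3]
assert _normalize_size_args((1, 2, 3), {}) == ((), {'shape': [1, 2, 3]})
# paddle.ones(size=[1, 2, 3]) -> shape=[1, 2, 3]
assert _normalize_size_args((), {'size': [1, 2, 3]}) == ((), {'shape': [1, 2, 3]})
# paddle.ones([1, 2, 3]) passes through untouched: args[0] is not an int.
assert _normalize_size_args(([1, 2, 3],), {}) == (([1, 2, 3],), {})
```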

test/ir/pir/cinn/symbolic/test_infer_sym_shape_unary_op.py

Lines changed: 39 additions & 0 deletions
```diff
@@ -73,6 +73,45 @@ def test_eval_symbolic(self):
         return True
 
 
+class MaxMinWithIndexNet(paddle.nn.Layer):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        min_vals, min_inds = paddle.compat.min(x, dim=-1, keepdim=False)
+        max_vals, max_inds = paddle.compat.max(x, dim=-1, keepdim=True)
+        return min_vals + max_vals.squeeze(axis=-1), min_inds + max_inds
+
+
+class MinMaxWithIndexOpInferSymbolicShapeTest(TestBase):
+    def prepare_data(self):
+        self.cases = [np.random.rand(3, 4, 5, 6), np.random.rand(257)]
+        self.expected = [
+            [
+                'shape[S0, S1, S2], data[NULL]',
+                'shape[S0, Broadcast(S0, S1), Broadcast(S1, S2), S2], data[NULL]',
+            ],
+            ['shape[], data[NULL]', 'shape[1], data[NULL]'],
+        ]
+
+    def test_eval_symbolic(self):
+        net = MaxMinWithIndexNet()
+
+        for i in range(len(self.cases)):
+            x = self.cases[i]
+            x_spec = InputSpec(
+                shape=[None for index in range(len(x.shape))], dtype='float32'
+            )
+            input_spec = [x_spec]
+            net = apply_to_static(net, False, input_spec)
+            net.eval()
+            check_infer_results(
+                net, input_spec, 'builtin.shadow_output', self.expected[i]
+            )
+
+        return True
+
+
 class AsComplexAsRealNet(paddle.nn.Layer):
     def __init__(self):
         super().__init__()
```
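
The expected strings follow from broadcasting the two outputs of `MaxMinWithIndexNet.forward`. A commented walk-through (our annotation, assuming the usual right-aligned broadcast rules):

```python
# 4-D case: the input spec is fully dynamic, i.e. [S0, S1, S2, S3].
#   compat.min(x, dim=-1, keepdim=False) -> min_vals/min_inds: [S0, S1, S2]
#   compat.max(x, dim=-1, keepdim=True)  -> max_vals/max_inds: [S0, S1, S2, 1]
# Output 1: min_vals + max_vals.squeeze(axis=-1)         -> [S0, S1, S2]
# Output 2: min_inds + max_inds aligns rank 3 to rank 4 from the right,
#   [-, S0, S1, S2] vs [S0, S1, S2, 1]
#   -> [S0, Broadcast(S0, S1), Broadcast(S1, S2), S2]
# 1-D case ([S0]): min (keepdim=False) yields shape [], max (keepdim=True) [1].
```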

test/legacy_test/test_minmax_with_index_op.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -232,6 +232,11 @@ def test_check_grad(self):
 
 
 class TestMinMaxWithIndexPlace(unittest.TestCase):
+    """min/max_with_index has no CPU version, so when CUDA is not available
+    we skip all the tests above. A runtime error will be emitted if
+    min/max_with_index is called on CPU; this unit test tries to capture it.
+    """
+
     def init(self):
         self.input_shape = [30, 10, 10]
         self.data = np.random.randn(30, 10, 10)
```
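
Because `min/max_with_index` ships CUDA kernels only, GPU-dependent cases need a skip guard on CPU-only builds. A minimal sketch of such a guard (the class and test names here are hypothetical, not the file's):

```python
import unittest

import paddle


@unittest.skipIf(
    not paddle.is_compiled_with_cuda(),
    "min/max_with_index has no CPU kernel; skipping on CPU-only builds",
)
class TestMinMaxWithIndexGPUOnly(unittest.TestCase):
    def test_runs_on_gpu(self):
        # On a CUDA build this exercises the _C_ops.min_with_index path.
        x = paddle.rand([30, 10, 10])
        vals, inds = paddle.compat.min(x, dim=1)
        self.assertEqual(vals.shape, [30, 10])
        self.assertEqual(inds.shape, [30, 10])
```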

test/legacy_test/test_zero_dim_sundry_dygraph_api.py

Lines changed: 55 additions & 0 deletions
```diff
@@ -551,6 +551,61 @@ def test_argmax(self):
         out = paddle.argmax(x, keepdim=True)
         self.assertEqual(out.shape, [1, 1])
 
+    def _make_compat_minmax_test(self, func_name):
+        # 1) x is 0D
+        x = paddle.rand([])
+        val1, ind1 = func_name(x, 0)
+        val2, ind2 = func_name(x, -1)
+        val3 = func_name(x)
+
+        self.assertEqual(val1.shape, [])
+        self.assertEqual(ind1.shape, [])
+        np.testing.assert_allclose(val1, x)
+        np.testing.assert_allclose(ind1, 0)
+
+        self.assertEqual(val2.shape, [])
+        self.assertEqual(ind2.shape, [])
+        np.testing.assert_allclose(val2, x)
+        np.testing.assert_allclose(ind2, 0)
+
+        self.assertEqual(val3.shape, [])
+        np.testing.assert_allclose(val3, x)
+
+        # 2) x is 1D
+        x = paddle.rand([5])
+        val, ind = func_name(x, 0)
+        self.assertEqual(val.shape, [])
+        self.assertEqual(ind.shape, [])
+
+        # 3) x is ND
+        x = paddle.rand([3, 5])
+        val, ind = func_name(x, dim=1)
+        self.assertEqual(val.shape, [3])
+        self.assertEqual(ind.shape, [3])
+
+        val = func_name(x)
+        self.assertEqual(val.shape, [])
+
+        # 4) x is ND, keepdim=True
+        x = paddle.rand([3, 5])
+        val, ind = func_name(x, dim=0, keepdim=True)
+        self.assertEqual(val.shape, [1, 5])
+        self.assertEqual(ind.shape, [1, 5])
+
+        # 5) test backward
+        x = paddle.randn([4, 5])
+        x.stop_gradient = False
+
+        val, ind = func_name(x, dim=0)
+        val.backward()
+        self.assertEqual(x.grad.shape, [4, 5])
+
+    def test_compat_min(self):
+        self._make_compat_minmax_test(paddle.compat.min)
+
+    def test_compat_max(self):
+        self._make_compat_minmax_test(paddle.compat.max)
+
     def test_kthvalue(self):
         # 1) x is 0D
         x = paddle.randn([])
```
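
Case 5 passes because the compat.py change sets `stop_gradient = True` on the returned indices, so backward flows only through the values path. A small dygraph sketch of the same check (our illustration):

```python
import paddle

x = paddle.randn([4, 5])
x.stop_gradient = False

val, ind = paddle.compat.min(x, dim=0)
val.backward()

# The gradient has the input's shape; only the per-column argmin positions
# receive 1.0, and nothing tries to differentiate through the int64 indices.
print(x.grad.shape)  # [4, 5]
```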
