Skip to content

Commit e826c7c

Browse files
committed
[WIP][API-Compat] Fixed CPU failure
1 parent a98232c commit e826c7c

File tree

3 files changed

+34
-8
lines changed

3 files changed

+34
-8
lines changed

python/paddle/tensor/compat.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -409,10 +409,8 @@ def min(input: Tensor, *args: Any, **kwargs: Any) -> Tensor | MinMaxRetType:
409409
return MinMaxRetType(values=vals, indices=inds)
410410
else:
411411
# CPUPlace and other placements are implemented by composition
412-
indices = _C_ops.argmin(
413-
input, dim_or_other, True, False, paddle.int64
414-
)
415-
values = _C_ops.take_along_axis(input, indices, dim_or_other)
412+
indices = paddle.argmin(input, axis=dim_or_other, keepdim=True)
413+
values = paddle.take_along_axis(input, indices, axis=dim_or_other)
416414
if keepdim:
417415
return MinMaxRetType(values=values, indices=indices)
418416
return MinMaxRetType(
@@ -524,10 +522,8 @@ def max(input: Tensor, *args: Any, **kwargs: Any) -> Tensor | MinMaxRetType:
524522
return MinMaxRetType(values=vals, indices=inds)
525523
else:
526524
# CPUPlace and other placements are implemented by composition
527-
indices = _C_ops.argmax(
528-
input, dim_or_other, True, False, paddle.int64
529-
)
530-
values = _C_ops.take_along_axis(input, indices, dim_or_other)
525+
indices = paddle.argmax(input, axis=dim_or_other, keepdim=True)
526+
values = paddle.take_along_axis(input, indices, axis=dim_or_other)
531527
if keepdim:
532528
return MinMaxRetType(values=values, indices=indices)
533529
return MinMaxRetType(

python/paddle/tensor/math.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@
100100
from paddle import Tensor
101101
from paddle._typing import DTypeLike
102102

103+
from paddle.utils.decorator_utils import ForbidKeywordsDecorator
104+
103105
__all__ = []
104106

105107
_supported_int_dtype_ = [
@@ -3131,6 +3133,11 @@ def _check_input(x):
31313133
return out
31323134

31333135

3136+
@ForbidKeywordsDecorator(
3137+
illegal_keys=["input", "dim", "other"],
3138+
func_name="paddle.max",
3139+
correct_name="paddle.compat.max",
3140+
)
31343141
def max(
31353142
x: Tensor,
31363143
axis: int | Sequence[int] | None = None,
@@ -3290,6 +3297,11 @@ def max(
32903297
return out
32913298

32923299

3300+
@ForbidKeywordsDecorator(
3301+
illegal_keys=["input", "dim", "other"],
3302+
func_name="paddle.min",
3303+
correct_name="paddle.compat.min",
3304+
)
32933305
def min(
32943306
x: Tensor,
32953307
axis: int | Sequence[int] | None = None,

test/legacy_test/test_compat_minmax.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,14 @@ def test_error_handling(self):
187187
err_msg1 = (
188188
"Tensors with integral type: 'paddle.int32' should stop gradient."
189189
)
190+
err_msg2 = (
191+
"paddle.min() received unexpected keyword arguments 'input', 'dim'. "
192+
"\nDid you mean to use paddle.compat.min() instead?"
193+
)
194+
err_msg3 = (
195+
"paddle.compat.max() received unexpected keyword argument 'axis'. "
196+
"\nDid you mean to use paddle.max() instead?"
197+
)
190198

191199
# empty tensor
192200
empty_tensor = paddle.to_tensor([], dtype='float32')
@@ -250,6 +258,16 @@ def test_error_handling(self):
250258
with self.assertRaises(TypeError) as cm:
251259
paddle.compat.max(input_ts, dim=0, other=0, keepdim=True)
252260

261+
# Wrong API used case 1
262+
with self.assertRaises(TypeError) as cm:
263+
paddle.min(input=input_ts, dim=0)
264+
self.assertEqual(str(cm.exception), err_msg2)
265+
266+
# Wrong API used case 2
267+
with self.assertRaises(TypeError) as cm:
268+
paddle.compat.max(input_ts, axis=0)
269+
self.assertEqual(str(cm.exception), err_msg3)
270+
253271

254272
if __name__ == '__main__':
255273
unittest.main()

0 commit comments

Comments
 (0)