Skip to content

Commit b4a5484

Browse files
authored
[API] Support tensor shape in reshape with compatible API (PaddlePaddle#76025)
1 parent 915abce commit b4a5484

File tree

2 files changed

+145
-6
lines changed

2 files changed

+145
-6
lines changed

python/paddle/utils/decorator_utils.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,23 @@
2121

2222
from typing_extensions import ParamSpec
2323

24+
import paddle
25+
2426
if TYPE_CHECKING:
2527
from collections.abc import Iterable
2628

2729
_InputT = ParamSpec("_InputT")
2830
_RetT = TypeVar("_RetT")
2931

3032

33+
def _is_in_or_scalar_tensor(x):
34+
if isinstance(x, int):
35+
return True
36+
if isinstance(x, (paddle.Tensor, paddle.pir.Value)):
37+
return x.ndim == 0
38+
return False
39+
40+
3141
class DecoratorBase:
3242
"""Decorative base class, providing a universal decorative framework.
3343
@@ -410,8 +420,8 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
410420
kwargs["shape_or_dtype"] = kwargs.pop("dtype")
411421
elif ("size" in kwargs) and ("shape_or_dtype" not in kwargs):
412422
kwargs["shape_or_dtype"] = kwargs.pop("size")
413-
elif len(args) >= 2 and type(args[1]) is int:
414-
if all(type(arg) is int for arg in args[1:]):
423+
elif len(args) >= 2 and _is_in_or_scalar_tensor(args[1]):
424+
if all(_is_in_or_scalar_tensor(arg) for arg in args[1:]):
415425
kwargs["x"] = args[0]
416426
kwargs['shape_or_dtype'] = list(args[1:])
417427
args = ()
@@ -542,8 +552,8 @@ def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
542552
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
543553
if ("input" in kwargs) and ("x" not in kwargs):
544554
kwargs["x"] = kwargs.pop("input")
545-
elif len(args) >= 2 and type(args[1]) is int:
546-
if all(type(arg) is int for arg in args[1:]):
555+
elif len(args) >= 2 and _is_in_or_scalar_tensor(args[1]):
556+
if all(_is_in_or_scalar_tensor(arg) for arg in args[1:]):
547557
kwargs["x"] = args[0]
548558
kwargs['shape'] = list(args[1:])
549559
args = ()
@@ -614,8 +624,8 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
614624
kwargs["x"] = kwargs.pop("input")
615625
if ("size" in kwargs) and ("shape" not in kwargs):
616626
kwargs["shape"] = kwargs.pop("size")
617-
elif len(args) >= 2 and type(args[1]) is int:
618-
if all(type(arg) is int for arg in args[1:]):
627+
elif len(args) >= 2 and _is_in_or_scalar_tensor(args[1]):
628+
if all(_is_in_or_scalar_tensor(arg) for arg in args[1:]):
619629
kwargs["x"] = args[0]
620630
kwargs['shape'] = list(args[1:])
621631
args = ()

test/legacy_test/test_reshape_op.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
is_custom_device,
2424
skip_check_grad_ci,
2525
)
26+
from utils import dygraph_guard, static_guard
2627

2728
import paddle
2829
from paddle import base
@@ -942,6 +943,134 @@ def run_test_cases(place):
942943
run_test_cases(get_device_place())
943944

944945

946+
# NOTE(review): this span is a scraped GitHub diff view; the bare numeric
# lines ("947+", "956+", ...) are the page's line-number artifacts, not code.
class TestReshapeWithTensorShape(unittest.TestCase):
947+
"""
948+
reshape supports shape like:
949+
paddle.reshape(x, shape=[1, 2, 3])
950+
paddle.reshape(x, shape=[1, Tensor(2), 3])
951+
paddle.reshape(x, shape=Tensor([1, 2, 3]))
952+
paddle.reshape(x, 1, 2, 3) # Compatible usage
953+
paddle.reshape(x, 1, Tensor(2), 3) # Compatible usage
954+
"""
955+
956+
# Helper: trace ``fn`` in a static-graph Program, check the statically
# inferred shape (dims listed in ``dynamic_dims`` are expected to infer
# as -1), then execute the program and check the concrete output shape.
# NOTE(review): ``dynamic_dims=[]`` is a mutable default -- it is only
# read here, so harmless in practice, but a None sentinel is the safer idiom.
@static_guard()
def check_reshape_static(
957+
self, fn, x_shape, expected_out_shape, dynamic_dims=[]
958+
):
959+
main_program = Program()
960+
with program_guard(main_program):
961+
x = paddle.static.data('x', shape=x_shape, dtype='float32')
962+
out = fn(x)
963+
if dynamic_dims:
964+
expected_out_shape_with_dynamic = list(expected_out_shape)
965+
for dim in dynamic_dims:
966+
expected_out_shape_with_dynamic[dim] = -1
967+
self.assertEqual(out.shape, expected_out_shape_with_dynamic)
968+
else:
969+
self.assertEqual(out.shape, expected_out_shape)
970+
971+
exe = paddle.static.Executor()
972+
(out_np,) = exe.run(
973+
main_program,
974+
# NOTE(review): np.random.random yields float64 while the placeholder is
# declared float32 -- presumably the feed is cast by the executor; confirm.
feed={'x': np.random.random(x_shape)},
975+
fetch_list=[out],
976+
)
977+
self.assertEqual(list(out_np.shape), expected_out_shape)
978+
979+
# Helper: eager-mode counterpart -- run ``fn`` on a real tensor and
# check the resulting shape.
@dygraph_guard()
def check_reshape_dygraph(self, fn, x_shape, expected_out_shape):
980+
x = paddle.to_tensor(np.random.random(x_shape).astype('float32'))
981+
out = fn(x)
982+
self.assertEqual(list(out.shape), expected_out_shape)
983+
984+
# Helper: run the same check in both static and dygraph modes.
def check_reshape(self, fn, x_shape, expected_out_shape):
985+
self.check_reshape_static(fn, x_shape, expected_out_shape)
986+
self.check_reshape_dygraph(fn, x_shape, expected_out_shape)
987+
988+
def test_reshape_with_list_int(self):
989+
def reshape_fn(x):
990+
return paddle.reshape(x, shape=[2, 3, 4])
991+
992+
self.check_reshape(reshape_fn, [2, 12], [2, 3, 4])
993+
994+
def test_reshape_with_list_scalar_tensor(self):
995+
def reshape_fn(x):
996+
dim0 = paddle.full([], 2, dtype='int64')
997+
dim1 = paddle.full([], 3, dtype='int64')
998+
dim2 = paddle.full([], 4, dtype='int64')
999+
return paddle.reshape(x, shape=[dim0, dim1, dim2])
1000+
1001+
self.check_reshape(reshape_fn, [2, 12], [2, 3, 4])
1002+
1003+
def test_reshape_with_list_scalar_tensor_dynamic_dim(self):
1004+
def reshape_fn(x):
1005+
dim0 = paddle.full([], 1, dtype='int64') + 1 # dynamic dim
1006+
dim1 = paddle.full([], 3, dtype='int64')
1007+
dim2 = paddle.full([], 4, dtype='int64')
1008+
return paddle.reshape(x, shape=[dim0, dim1, dim2])
1009+
1010+
self.check_reshape_static(
1011+
reshape_fn,
1012+
x_shape=[2, 12],
1013+
expected_out_shape=[2, 3, 4],
1014+
dynamic_dims=[0],
1015+
)
1016+
1017+
def test_reshape_with_list_mix_int_tensor(self):
1018+
def reshape_fn(x):
1019+
dim1 = paddle.full([], 3, dtype='int64')
1020+
return paddle.reshape(x, shape=[2, dim1, 4])
1021+
1022+
self.check_reshape(reshape_fn, [2, 12], [2, 3, 4])
1023+
1024+
def test_reshape_with_tensor_dynamic_dim(self):
1025+
def reshape_fn(x):
1026+
shape_tensor = paddle.to_tensor([1, 2, 3]) + 1 # all dynamic dims
1027+
return paddle.reshape(x, shape=shape_tensor)
1028+
1029+
self.check_reshape_static(
1030+
reshape_fn,
1031+
x_shape=[2, 12],
1032+
expected_out_shape=[2, 3, 4],
1033+
dynamic_dims=[0, 1, 2],
1034+
)
1035+
1036+
def test_reshape_with_tensor(self):
1037+
def reshape_fn(x):
1038+
shape_tensor = paddle.stack(
1039+
[
1040+
paddle.full([], 2, dtype='int64'),
1041+
paddle.full([], 3, dtype='int64'),
1042+
paddle.full([], 4, dtype='int64'),
1043+
]
1044+
)
1045+
return paddle.reshape(x, shape=shape_tensor)
1046+
1047+
self.check_reshape(reshape_fn, [2, 12], [2, 3, 4])
1048+
1049+
# The *_compatible tests exercise the positional form
# ``paddle.reshape(x, d0, d1, d2)`` described in the class docstring.
def test_reshape_with_list_int_compatible(self):
1050+
def reshape_fn(x):
1051+
return paddle.reshape(x, 2, 3, 4)
1052+
1053+
self.check_reshape(reshape_fn, [2, 12], [2, 3, 4])
1054+
1055+
def test_reshape_with_list_scalar_tensor_compatible(self):
1056+
def reshape_fn(x):
1057+
dim0 = paddle.full([], 2, dtype='int64')
1058+
dim1 = paddle.full([], 3, dtype='int64')
1059+
dim2 = paddle.full([], 4, dtype='int64')
1060+
return paddle.reshape(x, dim0, dim1, dim2)
1061+
1062+
self.check_reshape(reshape_fn, [2, 12], [2, 3, 4])
1063+
1064+
def test_reshape_with_list_mix_int_tensor_compatible(self):
1065+
def reshape_fn(x):
1066+
dim1 = paddle.full([], 3, dtype='int64')
1067+
return paddle.reshape(x, 2, dim1, 4)
1068+
1069+
self.check_reshape(reshape_fn, [2, 12], [2, 3, 4])
1070+
1071+
9451074
if __name__ == "__main__":
9461075
# These op tests run in static graph mode by default; individual tests
# switch modes via the static_guard/dygraph_guard decorators.
paddle.enable_static()
9471076
unittest.main()

0 commit comments

Comments
 (0)