Skip to content

Commit ab55c08

Browse files
authored
[API compatibility] torch.Tensor.prod torch.Tensor.reshape (#74559)
* [API compatibility] torch.Tensor.prod torch.Tensor.reshape * fix reshape timeout * fix reshape timeout * fix reshape
1 parent 066b3a0 commit ab55c08

File tree

4 files changed

+83
-4
lines changed

4 files changed

+83
-4
lines changed

python/paddle/tensor/manipulation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
from paddle.tensor import fill_constant
2727
from paddle.utils.decorator_utils import (
2828
ParamAliasDecorator,
29-
param_one_alias,
3029
param_two_alias,
30+
reshape_decorator,
3131
view_decorator,
3232
)
3333
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
@@ -4989,7 +4989,7 @@ def get_attr_expand_shape(list_expand_shape):
49894989
return out
49904990

49914991

4992-
@param_one_alias(["x", "input"])
4992+
@reshape_decorator()
49934993
def reshape(x: Tensor, shape: ShapeLike, name: str | None = None) -> Tensor:
49944994
"""
49954995
Changes the shape of ``x`` without changing its data.

python/paddle/utils/decorator_utils.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,8 @@ def process(
244244
"""
245245
Usage Example:
246246
paddle.view(x=tensor_x, shape_or_dtype=[-1, 1, 3], name=None)
247-
248247
tensor_x.view(paddle.float32) -> paddle.view(tensor_x, paddle.float32)
249248
tensor_x.view(dtype=paddle.float32) -> paddle.view(tensor_x, dtype=paddle.float32)
250-
251249
tensor_x.view([-1, 1, 3]) -> paddle.view(tensor_x, [-1, 1, 3])
252250
tensor_x.view(-1, 1, 3) -> paddle.view(tensor_x, -1, 1, 3)
253251
tensor_x.view(size=[-1, 1, 3]) -> paddle.view(tensor_x, size=[-1, 1, 3])
@@ -273,3 +271,30 @@ def wrapper(*args, **kwargs):
273271
return wrapper
274272

275273
return decorator
274+
275+
276+
def reshape_decorator():
277+
"""
278+
Usage Example:
279+
paddle.reshape(x=tensor_x, shape=[-1, 1, 3], name=None)
280+
paddle.reshape(input=tensor_x, shape=[-1, 1, 3], name=None)
281+
tensor_x.reshape([-1, 1, 3]) -> paddle.reshape(tensor_x, [-1, 1, 3])
282+
tensor_x.reshape(-1, 1, 3) -> paddle.reshape(tensor_x, -1, 1, 3])
283+
"""
284+
285+
def decorator(func):
286+
@functools.wraps(func)
287+
def wrapper(*args, **kwargs):
288+
if ("input" in kwargs) and ("x" not in kwargs):
289+
kwargs["x"] = kwargs.pop("input")
290+
elif len(args) >= 2 and type(args[1]) is int:
291+
if all(type(arg) is int for arg in args[1:]):
292+
kwargs["x"] = args[0]
293+
kwargs['shape'] = list(args[1:])
294+
args = ()
295+
return func(*args, **kwargs)
296+
297+
wrapper.__signature__ = inspect.signature(func)
298+
return wrapper
299+
300+
return decorator

test/legacy_test/test_prod_op.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,35 @@ def test_gpu(self):
423423
with static_guard():
424424
self.run_static()
425425

426+
def test_tensor_prod(self):
427+
"""x.prod(axis=1) is equivalent to x.prod(dim=1)"""
428+
axis_cases = [0, 1, -1]
429+
430+
def run_test_cases(place):
431+
"""Helper function to run test cases on specified device."""
432+
for param_alias in ["axis", "dim"]:
433+
for axis in axis_cases:
434+
input_tensor = paddle.to_tensor(self.input, place=place)
435+
kwargs = {param_alias: axis}
436+
437+
result = input_tensor.prod(**kwargs)
438+
expected = np.prod(self.input, axis=axis)
439+
np.testing.assert_allclose(
440+
(
441+
result.numpy()
442+
if place.is_cpu_place()
443+
else result.cpu().numpy()
444+
),
445+
expected,
446+
rtol=1e-05,
447+
)
448+
449+
with dygraph_guard():
450+
run_test_cases(paddle.CPUPlace())
451+
452+
if paddle.base.core.is_compiled_with_cuda():
453+
run_test_cases(paddle.CUDAPlace(0))
454+
426455

427456
if __name__ == "__main__":
428457
unittest.main()

test/legacy_test/test_reshape_op.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,31 @@ def test_imperative(self):
915915
np.testing.assert_array_equal(out_2.numpy(), input.reshape([5, 10]))
916916
np.testing.assert_array_equal(out_3.numpy(), input.reshape(shape))
917917

918+
def test_tensor_reshape(self):
919+
"""The `shape` parameter accepts either variable arguments or a list/tuple.
920+
For example, x.reshape(2, 5, 5) is equivalent to x.reshape([2, 5, 5]).
921+
"""
922+
923+
def run_test_cases(place):
924+
"""Helper function to run test cases on specified device."""
925+
input = np.random.random([2, 25]).astype("float32")
926+
input_tensor = paddle.to_tensor(input, place=place)
927+
928+
out_1 = input_tensor.reshape([2, 5, 5])
929+
out_2 = input_tensor.reshape(2, 5, 5)
930+
931+
np.testing.assert_array_equal(
932+
out_1.numpy(), input.reshape([2, 5, 5])
933+
)
934+
np.testing.assert_array_equal(
935+
out_2.numpy(), input.reshape([2, 5, 5])
936+
)
937+
938+
with base.dygraph.guard():
939+
run_test_cases(paddle.CPUPlace())
940+
if paddle.base.core.is_compiled_with_cuda():
941+
run_test_cases(paddle.CUDAPlace(0))
942+
918943

919944
if __name__ == "__main__":
920945
paddle.enable_static()

0 commit comments

Comments
 (0)