|
23 | 23 | is_custom_device, |
24 | 24 | skip_check_grad_ci, |
25 | 25 | ) |
| 26 | +from utils import dygraph_guard, static_guard |
26 | 27 |
|
27 | 28 | import paddle |
28 | 29 | from paddle import base |
@@ -942,6 +943,134 @@ def run_test_cases(place): |
942 | 943 | run_test_cases(get_device_place()) |
943 | 944 |
|
944 | 945 |
|
| 946 | +class TestReshapeWithTensorShape(unittest.TestCase): |
| 947 | + """ |
| 948 | + reshape supports shape like: |
| 949 | + paddle.reshape(x, shape=[1, 2, 3]) |
| 950 | + paddle.reshape(x, shape=[1, Tensor(2), 3]) |
| 951 | + paddle.reshape(x, shape=Tensor([1, 2, 3])) |
| 952 | + paddle.reshape(x, 1, 2, 3) # Compatible usage |
| 953 | + paddle.reshape(x, 1, Tensor(2), 3) # Compatible usage |
| 954 | + """ |
| 955 | + |
| 956 | + @static_guard() |
| 957 | + def check_reshape_static( |
| 958 | + self, fn, x_shape, expected_out_shape, dynamic_dims=[] |
| 959 | + ): |
| 960 | + main_program = Program() |
| 961 | + with program_guard(main_program): |
| 962 | + x = paddle.static.data('x', shape=x_shape, dtype='float32') |
| 963 | + out = fn(x) |
| 964 | + if dynamic_dims: |
| 965 | + expected_out_shape_with_dynamic = list(expected_out_shape) |
| 966 | + for dim in dynamic_dims: |
| 967 | + expected_out_shape_with_dynamic[dim] = -1 |
| 968 | + self.assertEqual(out.shape, expected_out_shape_with_dynamic) |
| 969 | + else: |
| 970 | + self.assertEqual(out.shape, expected_out_shape) |
| 971 | + |
| 972 | + exe = paddle.static.Executor() |
| 973 | + (out_np,) = exe.run( |
| 974 | + main_program, |
| 975 | + feed={'x': np.random.random(x_shape)}, |
| 976 | + fetch_list=[out], |
| 977 | + ) |
| 978 | + self.assertEqual(list(out_np.shape), expected_out_shape) |
| 979 | + |
| 980 | + @dygraph_guard() |
| 981 | + def check_reshape_dygraph(self, fn, x_shape, expected_out_shape): |
| 982 | + x = paddle.to_tensor(np.random.random(x_shape).astype('float32')) |
| 983 | + out = fn(x) |
| 984 | + self.assertEqual(list(out.shape), expected_out_shape) |
| 985 | + |
| 986 | + def check_reshape(self, fn, x_shape, expected_out_shape): |
| 987 | + self.check_reshape_static(fn, x_shape, expected_out_shape) |
| 988 | + self.check_reshape_dygraph(fn, x_shape, expected_out_shape) |
| 989 | + |
| 990 | + def test_reshape_with_list_int(self): |
| 991 | + def reshape_fn(x): |
| 992 | + return paddle.reshape(x, shape=[2, 3, 4]) |
| 993 | + |
| 994 | + self.check_reshape(reshape_fn, [2, 12], [2, 3, 4]) |
| 995 | + |
| 996 | + def test_reshape_with_list_scalar_tensor(self): |
| 997 | + def reshape_fn(x): |
| 998 | + dim0 = paddle.full([], 2, dtype='int64') |
| 999 | + dim1 = paddle.full([], 3, dtype='int64') |
| 1000 | + dim2 = paddle.full([], 4, dtype='int64') |
| 1001 | + return paddle.reshape(x, shape=[dim0, dim1, dim2]) |
| 1002 | + |
| 1003 | + self.check_reshape(reshape_fn, [2, 12], [2, 3, 4]) |
| 1004 | + |
| 1005 | + def test_reshape_with_list_scalar_tensor_dynamic_dim(self): |
| 1006 | + def reshape_fn(x): |
| 1007 | + dim0 = paddle.full([], 1, dtype='int64') + 1 # dynamic dim |
| 1008 | + dim1 = paddle.full([], 3, dtype='int64') |
| 1009 | + dim2 = paddle.full([], 4, dtype='int64') |
| 1010 | + return paddle.reshape(x, shape=[dim0, dim1, dim2]) |
| 1011 | + |
| 1012 | + self.check_reshape_static( |
| 1013 | + reshape_fn, |
| 1014 | + x_shape=[2, 12], |
| 1015 | + expected_out_shape=[2, 3, 4], |
| 1016 | + dynamic_dims=[0], |
| 1017 | + ) |
| 1018 | + |
| 1019 | + def test_reshape_with_list_mix_int_tensor(self): |
| 1020 | + def reshape_fn(x): |
| 1021 | + dim1 = paddle.full([], 3, dtype='int64') |
| 1022 | + return paddle.reshape(x, shape=[2, dim1, 4]) |
| 1023 | + |
| 1024 | + self.check_reshape(reshape_fn, [2, 12], [2, 3, 4]) |
| 1025 | + |
| 1026 | + def test_reshape_with_tensor_dynamic_dim(self): |
| 1027 | + def reshape_fn(x): |
| 1028 | + shape_tensor = paddle.to_tensor([1, 2, 3]) + 1 # all dynamic dims |
| 1029 | + return paddle.reshape(x, shape=shape_tensor) |
| 1030 | + |
| 1031 | + self.check_reshape_static( |
| 1032 | + reshape_fn, |
| 1033 | + x_shape=[2, 12], |
| 1034 | + expected_out_shape=[2, 3, 4], |
| 1035 | + dynamic_dims=[0, 1, 2], |
| 1036 | + ) |
| 1037 | + |
| 1038 | + def test_reshape_with_tensor(self): |
| 1039 | + def reshape_fn(x): |
| 1040 | + shape_tensor = paddle.stack( |
| 1041 | + [ |
| 1042 | + paddle.full([], 2, dtype='int64'), |
| 1043 | + paddle.full([], 3, dtype='int64'), |
| 1044 | + paddle.full([], 4, dtype='int64'), |
| 1045 | + ] |
| 1046 | + ) |
| 1047 | + return paddle.reshape(x, shape=shape_tensor) |
| 1048 | + |
| 1049 | + self.check_reshape(reshape_fn, [2, 12], [2, 3, 4]) |
| 1050 | + |
| 1051 | + def test_reshape_with_list_int_compatible(self): |
| 1052 | + def reshape_fn(x): |
| 1053 | + return paddle.reshape(x, 2, 3, 4) |
| 1054 | + |
| 1055 | + self.check_reshape(reshape_fn, [2, 12], [2, 3, 4]) |
| 1056 | + |
| 1057 | + def test_reshape_with_list_scalar_tensor_compatible(self): |
| 1058 | + def reshape_fn(x): |
| 1059 | + dim0 = paddle.full([], 2, dtype='int64') |
| 1060 | + dim1 = paddle.full([], 3, dtype='int64') |
| 1061 | + dim2 = paddle.full([], 4, dtype='int64') |
| 1062 | + return paddle.reshape(x, dim0, dim1, dim2) |
| 1063 | + |
| 1064 | + self.check_reshape(reshape_fn, [2, 12], [2, 3, 4]) |
| 1065 | + |
| 1066 | + def test_reshape_with_list_mix_int_tensor_compatible(self): |
| 1067 | + def reshape_fn(x): |
| 1068 | + dim1 = paddle.full([], 3, dtype='int64') |
| 1069 | + return paddle.reshape(x, 2, dim1, 4) |
| 1070 | + |
| 1071 | + self.check_reshape(reshape_fn, [2, 12], [2, 3, 4]) |
| 1072 | + |
| 1073 | + |
945 | 1074 | if __name__ == "__main__": |
946 | 1075 | paddle.enable_static() |
947 | 1076 | unittest.main() |
0 commit comments