|
| 1 | +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import unittest |
| 16 | + |
| 17 | +import numpy as np |
| 18 | + |
| 19 | +import paddle |
| 20 | +from paddle.compat import split |
| 21 | + |
| 22 | + |
class TestCompatSplit(unittest.TestCase):
    """Tests for ``paddle.compat.split``.

    Results are compared against ``paddle.split`` and against direct NumPy
    slicing; error types and messages are checked as well.
    """

    @staticmethod
    def _to_num_or_sections(length, chunk_size):
        """Translate a per-chunk size into a ``paddle.split`` argument.

        Args:
            length: Extent of the tensor along the split axis.
            chunk_size: Requested size of every chunk (int).

        Returns:
            The number of equal chunks when ``chunk_size`` divides ``length``
            evenly; otherwise a list of section sizes whose last entry is
            the smaller remainder chunk.
        """
        full_chunks, remainder = divmod(length, chunk_size)
        if remainder == 0:
            return full_chunks
        return [chunk_size] * full_chunks + [remainder]

    def _compare_with_origin(self, input_tensor, size, axis=0):
        """Assert that ``split`` and ``paddle.split`` yield the same chunks.

        Args:
            input_tensor: Tensor to split.
            size: Chunk size (int) or list of section sizes; passed to
                ``paddle.compat.split`` unchanged.
            axis: Dimension to split along; forwarded as ``dim`` to the
                compat API and as ``axis`` to ``paddle.split``.
        """
        pd_results = split(input_tensor, size, dim=axis)

        # An int means "size of each chunk" for the compat API, so it has to
        # be converted before it can be handed to paddle.split.
        if isinstance(size, int):
            sections = self._to_num_or_sections(input_tensor.shape[axis], size)
        else:
            sections = size

        origin_results = paddle.split(
            input_tensor, num_or_sections=sections, axis=axis
        )

        self.assertEqual(len(origin_results), len(pd_results))

        # check both the shape and the values of every output section
        for origin_ts, pd_ts in zip(origin_results, pd_results):
            self.assertEqual(origin_ts.shape, pd_ts.shape)
            np.testing.assert_allclose(origin_ts.numpy(), pd_ts.numpy())

    def test_basic_split(self):
        """Test basic splitting with integer size"""
        data = paddle.arange(12).reshape([3, 4]).astype('float32')
        self._compare_with_origin(data, 1, 0)
        self._compare_with_origin(data, 2, 1)

    def test_split_with_list_sections(self):
        """Test splitting with list of section sizes"""
        data = paddle.rand([10, 5])
        self._compare_with_origin(data, [3, 2, 5], 0)
        self._compare_with_origin(data, [1, 4], -1)

    def test_chained_operations(self):
        """Test split with complex operation chain"""
        x = paddle.rand([8, 12])
        y = paddle.sin(x) * 2.0 + paddle.exp(x) / 3.0
        z = paddle.nn.functional.relu(y)

        # 12 is not a multiple of 7, so the trailing chunk is smaller.
        z1, z2 = split(z, 7, dim=1)

        self.assertEqual(z1.shape, [8, 7])
        self.assertEqual(z2.shape, [8, 5])

        z_np = z.numpy()
        np.testing.assert_allclose(z_np[:, :7], z1.numpy())
        np.testing.assert_allclose(z_np[:, 7:], z2.numpy())

    def test_split_grad(self):
        """Test backprop for split, in1 and in2 are computed by
        compat.split and original split"""

        def get_tensors():
            # Both tensors wrap the same array so their gradients are
            # directly comparable.
            np.random.seed(114514)
            np_arr = np.random.normal(0, 1, [2, 3, 4, 5])
            return paddle.to_tensor(np_arr), paddle.to_tensor(np_arr)

        in1, in2 = get_tensors()
        in1.stop_gradient = False
        in2.stop_gradient = False

        def computation_graph(in_tensor):
            y = in_tensor * 2.3 + 3.0
            y = paddle.maximum(y, paddle.to_tensor([0], dtype=paddle.float32))
            return y.mean(axis=0)

        out1 = computation_graph(in1)
        out2 = computation_graph(in2)

        # Chunk size 2 along an axis of extent 5 corresponds to the explicit
        # sections [2, 2, 1] given to paddle.split.
        packs1 = paddle.compat.split(out1, 2, dim=2)
        packs2 = paddle.split(out2, [2, 2, 1], axis=2)

        res1 = packs1[0] + packs1[1] + packs1[2]
        res2 = packs2[0] + packs2[1] + packs2[2]
        res1.backward()
        res2.backward()
        np.testing.assert_allclose(in1.grad.numpy(), in2.grad.numpy())

    def test_empty_dim(self):
        """Split with empty dim"""
        in_tensor = paddle.arange(72, dtype=paddle.int64).reshape([3, 12, 2])
        self._compare_with_origin(in_tensor, [5, 0, 7], axis=1)

    def test_split_with_one_block(self):
        """Resulting tuple should be of length 1"""
        in_tensor = paddle.arange(60, dtype=paddle.float32).reshape([3, 4, 5])
        # The axis is deliberately passed as a Tensor in both calls.
        self._compare_with_origin(in_tensor, 5, paddle.to_tensor([-1]))
        self._compare_with_origin(in_tensor, [5], paddle.to_tensor(2))

    def test_edge_cases(self):
        """Test edge cases and error handling"""
        x = paddle.arange(5)
        s1, s2 = split(x, [3, 2])
        np.testing.assert_allclose(s1.numpy(), [0, 1, 2])
        np.testing.assert_allclose(s2.numpy(), [3, 4])

        x = paddle.rand([2, 2, 2])
        a, b = split(x, 1, 2)
        self.assertEqual(a.shape, [2, 2, 1])
        self.assertEqual(b.shape, [2, 2, 1])

        # invalid split sections
        with self.assertRaises(ValueError):
            split(x, [3, 1], 1)

        # invalid split axis
        with self.assertRaises(ValueError):
            split(x, 2, 3)

    def test_error_hint(self):
        """Test whether there will be correct exception when users pass
        paddle.split kwargs in paddle.compat.split, vice versa."""
        x = paddle.randn([3, 9, 5])

        msg_gt_1 = (
            "paddle.split() received unexpected keyword arguments 'dim', 'split_size_or_sections', 'tensor'. "
            "\nDid you mean to use paddle.compat.split() instead?"
        )
        msg_gt_2 = (
            "paddle.compat.split() received unexpected keyword argument 'num_or_sections'. "
            "\nDid you mean to use paddle.split() instead?"
        )
        msg_gt_3 = "(InvalidArgument) The dim is expected to be in range of [-3, 3), but got 3"
        msg_gt_4 = "paddle.compat.split expects split_sizes have only non-negative entries, but got size = -5 on dim 2"

        split_size = paddle.to_tensor([3])
        msg_gt_5 = (
            "The type of 'split_size_or_sections' in split must be int, list or tuple in imperative mode, but "
            f"received {type(split_size)}."
        )

        # Each message is asserted after its ``with`` block: the call inside
        # raises, so nothing placed after it in the block would execute.
        with self.assertRaises(TypeError) as cm:
            paddle.split(tensor=x, split_size_or_sections=3, dim=0)
        self.assertEqual(str(cm.exception), msg_gt_1)

        with self.assertRaises(TypeError) as cm:
            split(x, num_or_sections=3, dim=0)
        self.assertEqual(str(cm.exception), msg_gt_2)

        with self.assertRaises(ValueError) as cm:
            split(x, 3, dim=3)
        self.assertEqual(str(cm.exception), msg_gt_3)

        with self.assertRaises(ValueError) as cm:
            split(x, [3, 3, -5], -2)
        self.assertEqual(str(cm.exception), msg_gt_4)

        with self.assertRaises(TypeError) as cm:
            split(x, split_size, 1)
        self.assertEqual(str(cm.exception), msg_gt_5)
| 174 | + |
| 175 | + |
# Run the unittest runner when this module is executed as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments