
Commit ff2ded8

[API-Compat] Restored removed unittests
1 parent d4d9e5c commit ff2ded8

File tree

2 files changed: +361 −0 lines changed

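For orientation, here is a minimal sketch of the API difference these tests exercise, inferred from the calls and assertions in the test code below rather than from Paddle's documentation: paddle.compat.split takes split_size_or_sections and dim (an integer means the size of each chunk, with the last chunk absorbing any remainder), while paddle.split takes num_or_sections and axis (an integer means the number of equal sections).

import paddle
from paddle.compat import split

z = paddle.rand([8, 12])

# compat.split: an int is the per-chunk size along dim; the last chunk
# is smaller when the dim is not divisible (here shapes [8, 7] and [8, 5]).
z1, z2 = split(z, 7, dim=1)

# paddle.split: an int is the number of equal sections along axis
# (here three tensors of shape [8, 4]).
a, b, c = paddle.split(z, num_or_sections=3, axis=1)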

test/legacy_test/test_compat_split.py

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.compat import split


class TestCompatSplit(unittest.TestCase):
    def _compare_with_origin(self, input_tensor, size, axis=0):
        pd_results = split(input_tensor, size, dim=axis)

        if isinstance(size, int):
            shape_on_axis = input_tensor.shape[axis]
            remaining_num = shape_on_axis % size
            num_sections = shape_on_axis // size
            if remaining_num == 0:
                size = num_sections
            else:
                size = [size for _ in range(num_sections)]
                size.append(remaining_num)

        origin_results = paddle.split(
            input_tensor, num_or_sections=size, axis=axis
        )

        self.assertEqual(len(origin_results), len(pd_results))

        # check shape and output section size of the output
        for origin_ts, pd_ts in zip(origin_results, pd_results):
            np.testing.assert_allclose(origin_ts.numpy(), pd_ts.numpy())

    def test_basic_split(self):
        """Test basic splitting with integer size"""
        data = paddle.arange(12).reshape([3, 4]).astype('float32')
        self._compare_with_origin(data, 1, 0)
        self._compare_with_origin(data, 2, 1)

    def test_split_with_list_sections(self):
        """Test splitting with list of section sizes"""
        data = paddle.rand([10, 5])
        self._compare_with_origin(data, [3, 2, 5], 0)
        self._compare_with_origin(data, [1, 4], -1)

    def test_chained_operations(self):
        """Test split with complex operation chain"""
        x = paddle.rand([8, 12])
        y = paddle.sin(x) * 2.0 + paddle.exp(x) / 3.0
        z = paddle.nn.functional.relu(y)

        z1, z2 = split(z, 7, dim=1)

        self.assertEqual(z1.shape, [8, 7])
        self.assertEqual(z2.shape, [8, 5])

        z_np = z.numpy()
        np.testing.assert_allclose(z_np[:, :7], z1.numpy())
        np.testing.assert_allclose(z_np[:, 7:], z2.numpy())

    def test_split_grad(self):
        """Test backprop for split; in1 and in2 go through compat.split
        and the original paddle.split, respectively"""

        def get_tensors():
            np.random.seed(114514)
            np_arr = np.random.normal(0, 1, [2, 3, 4, 5])
            return paddle.to_tensor(np_arr), paddle.to_tensor(np_arr)

        in1, in2 = get_tensors()
        in1.stop_gradient = False
        in2.stop_gradient = False

        def computation_graph(in_tensor):
            y = in_tensor * 2.3 + 3.0
            y = paddle.maximum(y, paddle.to_tensor([0], dtype=paddle.float32))
            return y.mean(axis=0)

        out1 = computation_graph(in1)
        out2 = computation_graph(in2)

        packs1 = paddle.compat.split(out1, 2, dim=2)
        packs2 = paddle.split(out2, [2, 2, 1], axis=2)

        res1 = packs1[0] + packs1[1] + packs1[2]
        res2 = packs2[0] + packs2[1] + packs2[2]
        res1.backward()
        res2.backward()
        np.testing.assert_allclose(in1.grad.numpy(), in2.grad.numpy())

    def test_empty_dim(self):
        """Split with empty dim"""
        in_tensor = paddle.arange(72, dtype=paddle.int64).reshape([3, 12, 2])
        self._compare_with_origin(in_tensor, [5, 0, 7], axis=1)

    def test_split_with_one_block(self):
        """Resulting tuple should be of length 1"""
        in_tensor = paddle.arange(60, dtype=paddle.float32).reshape([3, 4, 5])
        self._compare_with_origin(in_tensor, 5, paddle.to_tensor([-1]))
        self._compare_with_origin(in_tensor, [5], paddle.to_tensor(2))

    def test_edge_cases(self):
        """Test edge cases and error handling"""
        x = paddle.arange(5)
        s1, s2 = split(x, [3, 2])
        np.testing.assert_allclose(s1.numpy(), [0, 1, 2])
        np.testing.assert_allclose(s2.numpy(), [3, 4])

        x = paddle.rand([2, 2, 2])
        a, b = split(x, 1, 2)
        self.assertEqual(a.shape, [2, 2, 1])

        # invalid split sections
        with self.assertRaises(ValueError):
            split(x, [3, 1], 1)

        # invalid split axis
        with self.assertRaises(ValueError):
            split(x, 2, 3)

    def test_error_hint(self):
        """Test that the correct exception is raised when users pass paddle.split kwargs to paddle.compat.split, and vice versa."""
        x = paddle.randn([3, 9, 5])

        msg_gt_1 = (
            "paddle.split() received unexpected keyword arguments 'dim', 'split_size_or_sections', 'tensor'. "
            "\nDid you mean to use paddle.compat.split() instead?"
        )
        msg_gt_2 = (
            "paddle.compat.split() received unexpected keyword argument 'num_or_sections'. "
            "\nDid you mean to use paddle.split() instead?"
        )
        msg_gt_3 = "(InvalidArgument) The dim is expected to be in range of [-3, 3), but got 3"
        msg_gt_4 = "paddle.compat.split expects split_sizes have only non-negative entries, but got size = -5 on dim 2"

        split_size = paddle.to_tensor([3])
        msg_gt_5 = (
            "The type of 'split_size_or_sections' in split must be int, list or tuple in imperative mode, but "
            f"received {type(split_size)}."
        )

        with self.assertRaises(TypeError) as cm:
            tensors = paddle.split(tensor=x, split_size_or_sections=3, dim=0)
        self.assertEqual(str(cm.exception), msg_gt_1)

        with self.assertRaises(TypeError) as cm:
            tensors = split(x, num_or_sections=3, dim=0)
        self.assertEqual(str(cm.exception), msg_gt_2)

        with self.assertRaises(ValueError) as cm:
            tensors = split(x, 3, dim=3)
        self.assertEqual(str(cm.exception), msg_gt_3)

        with self.assertRaises(ValueError) as cm:
            tensors = split(x, [3, 3, -5], -2)
        self.assertEqual(str(cm.exception), msg_gt_4)

        with self.assertRaises(TypeError) as cm:
            tensors = split(x, split_size, 1)
        self.assertEqual(str(cm.exception), msg_gt_5)


if __name__ == '__main__':
    unittest.main()
Second file, lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.compat import split


class TestCompatSplitStatic(unittest.TestCase):
    def _compare_with_origin_static(
        self, input_shape, size, axis=0, dim_rank=-1
    ):
        """dim_rank: -1 means the dim is passed as an int, 0 means a 0-D tensor, 1 means a tensor with shape [1]"""
        numel = 1
        for v in input_shape:
            numel *= v
        input_axis = axis
        if dim_rank == 0:
            input_axis = paddle.to_tensor(axis)
        elif dim_rank == 1:
            input_axis = paddle.to_tensor([axis])
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            input_tensor = paddle.arange(numel, dtype=paddle.float32).reshape(
                input_shape
            )
            pd_results = split(input_tensor, size, dim=input_axis)

            if isinstance(size, int):
                shape_on_axis = input_tensor.shape[axis]
                remaining_num = shape_on_axis % size
                num_sections = shape_on_axis // size
                if remaining_num == 0:
                    size = num_sections
                else:
                    size = [size for _ in range(num_sections)]
                    size.append(remaining_num)

            origin_results = paddle.split(
                input_tensor, num_or_sections=size, axis=axis
            )
            assert len(pd_results) == len(origin_results), "length mismatched"
            place = (
                paddle.CUDAPlace(0)
                if paddle.is_compiled_with_cuda()
                else paddle.CPUPlace()
            )
            exe = paddle.static.Executor(place)
            results = exe.run(fetch_list=[*origin_results, *pd_results])
            length_needed = len(results) // 2
            for i in range(length_needed):
                np.testing.assert_allclose(
                    results[i], results[i + length_needed]
                )
        paddle.disable_static()

    def test_split_composite_static(self):
        paddle.seed(114514)

        def get_tensors():
            np.random.seed(114514)
            np_arr = np.random.normal(0, 1, [2, 3, 4, 5])
            return paddle.to_tensor(np_arr), paddle.to_tensor(np_arr)

        in1, in2 = get_tensors()
        in1.stop_gradient = False
        in2.stop_gradient = False

        @paddle.jit.to_static
        def computation_graph(in1: paddle.Tensor, in2: paddle.Tensor):
            y1 = in1 * 1.5 + 1.0
            y1 = paddle.minimum(y1, paddle.to_tensor([0], dtype=paddle.float32))
            out1 = y1.mean(axis=0)

            y2 = in2 * 1.5 + 1.0
            y2 = paddle.minimum(y2, paddle.to_tensor([0], dtype=paddle.float32))
            out2 = y2.mean(axis=0)

            packs1 = paddle.compat.split(out1, 2, dim=2)
            packs2 = paddle.split(out2, [2, 2, 1], axis=2)

            res1 = packs1[0] + packs1[1] + packs1[2]
            res2 = packs2[0] + packs2[1] + packs2[2]

            return res1, res2

        res1, res2 = computation_graph(in1, in2)
        np.testing.assert_allclose(res1.numpy(), res2.numpy())

    def test_static_graph(self):
        """Test static graph execution"""
        # fixed random seed for reproducibility
        np.random.seed(114514)
        # old static graph mode
        paddle.enable_static()

        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.static.data(name='x', shape=[None, 6], dtype='float32')
            result0, result1 = split(x, split_size_or_sections=[3, 3], dim=1)
            output = result0 * 2.0 + paddle.sin(result1)

            place = (
                paddle.CUDAPlace(0)
                if paddle.is_compiled_with_cuda()
                else paddle.CPUPlace()
            )
            exe = paddle.static.Executor(place)

            input_data = np.random.rand(3, 6).astype('float32')
            feed = {'x': input_data}

            results = exe.run(feed=feed, fetch_list=[result0, result1, output])

            pd_result0, pd_result1 = results[0], results[1]
            np.testing.assert_allclose(input_data[:, :3], pd_result0)
            np.testing.assert_allclose(input_data[:, 3:], pd_result1)

            expected_output = input_data[:, :3] * 2.0 + np.sin(
                input_data[:, 3:]
            )
            np.testing.assert_allclose(
                expected_output, results[2], rtol=1e-4, atol=1e-4
            )

        paddle.disable_static()

    def test_error_hint(self):
        """Test that the correct exception is raised when users pass paddle.split kwargs to paddle.compat.split, and vice versa."""

        msg_gt_1 = "split_size_or_sections must be greater than 0."
        msg_gt_2 = "len(split_size_or_sections) must not be more than input.shape[dim]."
        msg_gt_3 = "The type of 'split_size_or_sections' in split must be int, list or tuple in imperative mode."
        msg_gt_4 = (
            "'dim' is not allowed to be a pir.Value in a static graph: "
            "\npir.Value can not be used for indexing python lists/tuples."
        )

        paddle.enable_static()
        with self.assertRaises(AssertionError) as cm:
            x = paddle.randn([3, 4, 5])
            tensors = split(x, -2, dim=0)
        self.assertEqual(str(cm.exception), msg_gt_1)

        with self.assertRaises(AssertionError) as cm:
            x = paddle.randn([3, 4, 5])
            tensors = split(x, (1, 1, 1, 1, 2, 2), dim=-1)
        self.assertEqual(str(cm.exception), msg_gt_2)

        with self.assertRaises(TypeError) as cm:
            x = paddle.randn([3, 4, 5])
            tensors = split(x, paddle.to_tensor(2), dim=2)
        self.assertEqual(str(cm.exception), msg_gt_3)

        with self.assertRaises(TypeError) as cm:
            x = paddle.randn([3, 4, 5])
            tensors = split(x, 2, dim=paddle.to_tensor(2))
        paddle.disable_static()
        self.assertEqual(str(cm.exception), msg_gt_4)

    def test_basic_split(self):
        """Test basic splitting with integer size"""
        input_shape = [3, 6]
        self._compare_with_origin_static(input_shape, 1, 0)
        self._compare_with_origin_static(input_shape, 3, -1)
        self._compare_with_origin_static(input_shape, 4, dim_rank=0)
        self._compare_with_origin_static(input_shape, 3, dim_rank=1)


if __name__ == '__main__':
    unittest.main()
