Skip to content

Commit b8478c6

Browse files
committed
[API-Compat] More unittest & static graph check & updated decorator
1 parent f4350f3 commit b8478c6

File tree

7 files changed

+342
-120
lines changed

7 files changed

+342
-120
lines changed

python/paddle/tensor/compat.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,16 @@
2929

3030
from paddle import Tensor
3131

32-
from paddle.utils.compat_kwarg_check import forbid_keywords
32+
from paddle.utils.decorator_utils import ForbidKeywordsDecorator
3333

3434
__all__ = []
3535

3636

37-
@forbid_keywords(["x", "num_or_sections", "axis", "name"], "paddle.split")
37+
@ForbidKeywordsDecorator(
38+
illegal_keys=["x", "num_or_sections", "axis", "name"],
39+
func_name="paddle.compat.split",
40+
correct_name="paddle.split",
41+
)
3842
def split(
3943
tensor: Tensor, split_size_or_sections: int | Sequence[int], dim: int = 0
4044
) -> tuple[Tensor, ...]:
@@ -105,12 +109,13 @@ def GetSplitSize(split_size, shape_on_dim):
105109
sections.append(remaining_num)
106110
return sections
107111

108-
def SaveGetShapeOnDim(shape, dim: int) -> int:
112+
def GetShapeOnDimInRange(shape, dim: int) -> int:
109113
shape_range = len(shape)
110-
if dim < -shape_range or dim >= shape_range:
111-
raise ValueError(
112-
f"(InvalidArgument) The dim is expected to be in range of [-{shape_range}, {shape_range}), but got {dim}"
113-
)
114+
if isinstance(dim, int):
115+
if dim < -shape_range or dim >= shape_range:
116+
raise ValueError(
117+
f"(InvalidArgument) The dim is expected to be in range of [-{shape_range}, {shape_range}), but got {dim}"
118+
)
114119
return shape[dim]
115120

116121
if isinstance(split_size_or_sections, (list, tuple)):
@@ -151,7 +156,7 @@ def SaveGetShapeOnDim(shape, dim: int) -> int:
151156
), 'split_size_or_sections must be greater than 0.'
152157

153158
split_size_or_sections = GetSplitSize(
154-
split_size_or_sections, SaveGetShapeOnDim(tensor.shape, dim)
159+
split_size_or_sections, GetShapeOnDimInRange(tensor.shape, dim)
155160
)
156161

157162
if isinstance(split_size_or_sections, list):
@@ -164,7 +169,10 @@ def SaveGetShapeOnDim(shape, dim: int) -> int:
164169
return tuple(_C_ops.split(tensor, split_size_or_sections, dim))
165170
else:
166171
if isinstance(dim, paddle.pir.Value):
167-
dim.stop_gradient = True
172+
raise TypeError(
173+
"'dim' is not allowed to be a pir.Value in a static graph: "
174+
"\npir.Value can not be used for indexing python lists/tuples."
175+
)
168176
if isinstance(dim, int):
169177
assert len(tensor.shape) + dim >= 0, "(rank(x) + dim) must >= 0"
170178
dim = (len(tensor.shape) + dim) if dim < 0 else dim
@@ -173,16 +181,15 @@ def SaveGetShapeOnDim(shape, dim: int) -> int:
173181

174182
if not isinstance(split_size_or_sections, (int, list, tuple)):
175183
raise TypeError(
176-
"The type of 'split_size_or_sections' in split must be int, list or tuple in imperative mode, but "
177-
f"received {type(split_size_or_sections)}."
184+
"The type of 'split_size_or_sections' in split must be int, list or tuple in imperative mode."
178185
)
179186
if isinstance(split_size_or_sections, int):
180187
assert (
181188
split_size_or_sections > 0
182189
), 'split_size_or_sections must be greater than 0.'
183190

184191
split_size_or_sections = GetSplitSize(
185-
split_size_or_sections, SaveGetShapeOnDim(tensor.shape, dim)
192+
split_size_or_sections, GetShapeOnDimInRange(tensor.shape, dim)
186193
)
187194
if isinstance(split_size_or_sections, list):
188195
if paddle.utils._contain_var(split_size_or_sections):

python/paddle/tensor/manipulation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
TensorOrTensors,
5959
)
6060

61-
from paddle.utils.compat_kwarg_check import forbid_keywords
61+
from paddle.utils.decorator_utils import ForbidKeywordsDecorator
6262

6363
__all__ = []
6464

@@ -2725,8 +2725,10 @@ def row_stack(x: Sequence[Tensor], name: str | None = None) -> Tensor:
27252725
return paddle.vstack(x, name=name)
27262726

27272727

2728-
@forbid_keywords(
2729-
["tensor", "split_size_or_sections", "dim"], "paddle.compat.split"
2728+
@ForbidKeywordsDecorator(
2729+
illegal_keys=["tensor", "split_size_or_sections", "dim"],
2730+
func_name="paddle.split",
2731+
correct_name="paddle.compat.split",
27302732
)
27312733
def split(
27322734
x: Tensor,

python/paddle/utils/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from ..base.framework import require_version
1616
from . import ( # noqa: F401
17-
compat_kwarg_check,
1817
cpp_extension,
1918
decorator_utils,
2019
dlpack,

python/paddle/utils/compat_kwarg_check.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

python/paddle/utils/decorator_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,32 @@ def process(
105105
f"Cannot specify both '{original}' and its alias '{alias}'"
106106
)
107107
return args, processed_kwargs
108+
109+
110+
class ForbidKeywordsDecorator(DecoratorBase[_P, _R]):
111+
"""A decorator that hints users to use the correct `compat` functions, when erroneous keyword arguments are detected"""
112+
113+
def __init__(
114+
self, illegal_keys: list[str] | str, func_name: str, correct_name: str
115+
) -> None:
116+
super().__init__()
117+
self.illegal_keys = (
118+
[illegal_keys] if isinstance(illegal_keys, str) else illegal_keys
119+
)
120+
self.func_name = func_name
121+
self.correct_name = correct_name
122+
123+
def process(
124+
self, args: tuple[Any, ...], kwargs: dict[str, Any]
125+
) -> tuple[tuple[Any, ...], dict[str, Any]]:
126+
found_keys = [key for key in self.illegal_keys if key in kwargs]
127+
128+
if found_keys:
129+
keys_str = ", ".join(f"'{key}'" for key in found_keys)
130+
plural = "s" if len(found_keys) > 1 else ""
131+
132+
raise TypeError(
133+
f"{self.func_name}() received unexpected keyword argument{plural} {keys_str}. "
134+
f"\nDid you mean to use {self.correct_name}() instead?"
135+
)
136+
return args, kwargs

test/legacy_test/test_compat_split.py

Lines changed: 68 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -71,42 +71,46 @@ def test_chained_operations(self):
7171
np.testing.assert_allclose(z_np[:, :7], z1.numpy())
7272
np.testing.assert_allclose(z_np[:, 7:], z2.numpy())
7373

74-
def test_static_graph(self):
75-
"""Test static graph execution"""
76-
# fixed random seed for reproducibility
77-
np.random.seed(114514)
78-
# old static graph mode
79-
paddle.enable_static()
80-
81-
with paddle.static.program_guard(paddle.static.Program()):
82-
x = paddle.static.data(name='x', shape=[None, 6], dtype='float32')
83-
result0, result1 = split(x, split_size_or_sections=[3, 3], dim=1)
84-
output = result0 * 2.0 + paddle.sin(result1)
85-
86-
place = (
87-
paddle.CUDAPlace(0)
88-
if paddle.is_compiled_with_cuda()
89-
else paddle.CPUPlace()
90-
)
91-
exe = paddle.static.Executor(place)
92-
93-
input_data = np.random.rand(3, 6).astype('float32')
94-
feed = {'x': input_data}
95-
96-
results = exe.run(feed=feed, fetch_list=[result0, result1, output])
97-
98-
pd_result0, pd_result1 = results[0], results[1]
99-
np.testing.assert_allclose(input_data[:, :3], pd_result0)
100-
np.testing.assert_allclose(input_data[:, 3:], pd_result1)
101-
102-
expected_output = input_data[:, :3] * 2.0 + np.sin(
103-
input_data[:, 3:]
104-
)
105-
np.testing.assert_allclose(
106-
expected_output, results[2], rtol=1e-3, atol=1e-3
107-
)
108-
109-
paddle.disable_static()
74+
def test_split_grad(self):
75+
"""Test backprop for split, in1 and in2 are computed by
76+
compat.split and original split"""
77+
78+
def get_tensors():
79+
np.random.seed(114514)
80+
np_arr = np.random.normal(0, 1, [2, 3, 4, 5])
81+
return paddle.to_tensor(np_arr), paddle.to_tensor(np_arr)
82+
83+
in1, in2 = get_tensors()
84+
in1.stop_gradient = False
85+
in2.stop_gradient = False
86+
87+
def computation_graph(in_tensor):
88+
y = in_tensor * 2.3 + 3.0
89+
y = paddle.maximum(y, paddle.to_tensor([0], dtype=paddle.float32))
90+
return y.mean(axis=0)
91+
92+
out1 = computation_graph(in1)
93+
out2 = computation_graph(in2)
94+
95+
packs1 = paddle.compat.split(out1, 2, dim=2)
96+
packs2 = paddle.split(out2, [2, 2, 1], axis=2)
97+
98+
res1 = packs1[0] + packs1[1] + packs1[2]
99+
res2 = packs2[0] + packs2[1] + packs2[2]
100+
res1.backward()
101+
res2.backward()
102+
np.testing.assert_allclose(in1.grad.numpy(), in2.grad.numpy())
103+
104+
def test_empty_dim(self):
105+
"""Split with empty dim"""
106+
in_tensor = paddle.arange(72, dtype=paddle.int64).reshape([3, 12, 2])
107+
self._compare_with_origin(in_tensor, [5, 0, 7], axis=1)
108+
109+
def test_split_with_one_block(self):
110+
"""Resulting tuple should be of length 1"""
111+
in_tensor = paddle.arange(60, dtype=paddle.float32).reshape([3, 4, 5])
112+
self._compare_with_origin(in_tensor, 5, paddle.to_tensor([-1]))
113+
self._compare_with_origin(in_tensor, [5], paddle.to_tensor(2))
110114

111115
def test_edge_cases(self):
112116
"""Test edge cases and error handling"""
@@ -131,8 +135,22 @@ def test_error_hint(self):
131135
"""Test whether there will be correct exception when users pass paddle.split kwargs in paddle.compat.split, vice versa."""
132136
x = paddle.randn([3, 9, 5])
133137

134-
msg_gt_1 = "split() received unexpected keyword arguments 'tensor', 'split_size_or_sections', 'dim'. \nDid you mean to use paddle.compat.split() instead?"
135-
msg_gt_2 = "split() received unexpected keyword argument 'num_or_sections'. \nDid you mean to use paddle.split() instead?"
138+
msg_gt_1 = (
139+
"paddle.split() received unexpected keyword arguments 'tensor', 'split_size_or_sections', 'dim'. "
140+
"\nDid you mean to use paddle.compat.split() instead?"
141+
)
142+
msg_gt_2 = (
143+
"paddle.compat.split() received unexpected keyword argument 'num_or_sections'. "
144+
"\nDid you mean to use paddle.split() instead?"
145+
)
146+
msg_gt_3 = "(InvalidArgument) The dim is expected to be in range of [-3, 3), but got 3"
147+
msg_gt_4 = "paddle.compat.split expects split_sizes have only non-negative entries, but got size = -5 on dim 2"
148+
149+
split_size = paddle.to_tensor([3])
150+
msg_gt_5 = (
151+
"The type of 'split_size_or_sections' in split must be int, list or tuple in imperative mode, but "
152+
f"received {type(split_size)}."
153+
)
136154

137155
with self.assertRaises(TypeError) as cm:
138156
tensors = paddle.split(tensor=x, split_size_or_sections=3, dim=0)
@@ -142,6 +160,18 @@ def test_error_hint(self):
142160
tensors = split(x, num_or_sections=3, dim=0)
143161
self.assertEqual(str(cm.exception), msg_gt_2)
144162

163+
with self.assertRaises(ValueError) as cm:
164+
tensors = split(x, 3, dim=3)
165+
self.assertEqual(str(cm.exception), msg_gt_3)
166+
167+
with self.assertRaises(ValueError) as cm:
168+
tensors = split(x, [3, 3, -5], -2)
169+
self.assertEqual(str(cm.exception), msg_gt_4)
170+
171+
with self.assertRaises(TypeError) as cm:
172+
tensors = split(x, split_size, 1)
173+
self.assertEqual(str(cm.exception), msg_gt_5)
174+
145175

146176
if __name__ == '__main__':
147177
unittest.main()

0 commit comments

Comments
 (0)