Skip to content

Commit 3524f4a

Browse files
zhengshengningmaxiaolong001
authored andcommitted
[API compatibility] add Alias : paddle.prod & paddle.no_grad & paddle.reshape & paddle.Tensor.bitwise_or_ (PaddlePaddle#74480)
* add Alias : prod no_grad reshape bitwise_or_ * opt Decorator * opt Decorator2 * opt reshape alias * fix reshape alias * fix reshape alias * fix reshape alias2
1 parent 269e005 commit 3524f4a

File tree

9 files changed

+298
-24
lines changed

9 files changed

+298
-24
lines changed

python/paddle/base/dygraph/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from paddle.base import core, framework
3333
from paddle.base.framework import global_var
3434
from paddle.base.multiprocess_utils import CleanupFuncRegistrar
35+
from paddle.utils.decorator_utils import ParamAliasDecorator
3536

3637
from ..framework import _get_paddle_place
3738
from ..wrapped_decorator import (
@@ -323,6 +324,7 @@ def no_grad(func: None = ...) -> AbstractContextManager: ...
323324
def no_grad(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]: ...
324325

325326

327+
@ParamAliasDecorator({"func": ["orig_func"]})
326328
def no_grad(func=None):
327329
"""
328330
:api_attr: imperative

python/paddle/sparse/unary.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
)
2929
from paddle.common_ops_import import Variable
3030
from paddle.framework import LayerHelper
31+
from paddle.utils.decorator_utils import (
32+
param_one_alias,
33+
)
3134

3235
if TYPE_CHECKING:
3336
from collections.abc import Sequence
@@ -879,6 +882,7 @@ def expm1(x: Tensor, name: str | None = None) -> Tensor:
879882
return _C_ops.sparse_expm1(x)
880883

881884

885+
@param_one_alias({"x": "input"})
882886
def reshape(x: Tensor, shape: ShapeLike, name: str | None = None) -> Tensor:
883887
"""
884888
Changes the shape of ``x`` without changing its value, requiring x to be a SparseCooTensor or SparseCsrTensor.

python/paddle/tensor/logic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from paddle import _C_ops
2323
from paddle.tensor.creation import full
2424
from paddle.tensor.math import broadcast_shape
25+
from paddle.utils.decorator_utils import ParamAliasDecorator
2526
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
2627

2728
from ..base.data_feeder import check_type, check_variable_and_dtype
@@ -1329,6 +1330,7 @@ def bitwise_and_(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
13291330
return _C_ops.bitwise_and_(x, y)
13301331

13311332

1333+
@ParamAliasDecorator({"x": ["input"], "y": ["other"]})
13321334
def bitwise_or(
13331335
x: Tensor, y: Tensor, out: Tensor | None = None, name: str | None = None
13341336
) -> Tensor:
@@ -1389,6 +1391,7 @@ def __ror__(
13891391

13901392

13911393
@inplace_apis_in_dygraph_only
1394+
@ParamAliasDecorator({"x": ["input"], "y": ["other"]})
13921395
def bitwise_or_(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
13931396
r"""
13941397
Inplace version of ``bitwise_or`` API, the output Tensor will be inplaced with input ``x``.

python/paddle/tensor/manipulation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
import paddle
2525
from paddle import _C_ops
2626
from paddle.tensor import fill_constant
27-
from paddle.utils.decorator_utils import ParamAliasDecorator
27+
from paddle.utils.decorator_utils import (
28+
ParamAliasDecorator,
29+
param_one_alias,
30+
)
2831
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
2932

3033
from ..base.data_feeder import (
@@ -4976,6 +4979,7 @@ def get_attr_expand_shape(list_expand_shape):
49764979
return out
49774980

49784981

4982+
@param_one_alias({"x": "input"})
49794983
def reshape(x: Tensor, shape: ShapeLike, name: str | None = None) -> Tensor:
49804984
"""
49814985
Changes the shape of ``x`` without changing its data.

python/paddle/tensor/math.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from paddle.base.libpaddle import DataType
2626
from paddle.common_ops_import import VarDesc, dygraph_utils
2727
from paddle.pir import Value
28+
from paddle.utils.decorator_utils import ParamAliasDecorator
2829
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
2930

3031
from ..base.data_feeder import (
@@ -4949,6 +4950,7 @@ def isnan(x: Tensor, name: str | None = None) -> Tensor:
49494950
return out
49504951

49514952

4953+
@ParamAliasDecorator({"x": ["input"], "axis": ["dim"]})
49524954
def prod(
49534955
x: Tensor,
49544956
axis: int | Sequence[int] | None = None,

python/paddle/utils/decorator_utils.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class DecoratorBase:
2727
"""
2828

2929
def __init__(self, *args: Any, **kwargs: Any) -> None:
30-
"""Initialize decorator parameters"""
3130
self.args = args
3231
self.kwargs = kwargs
3332

@@ -38,25 +37,15 @@ def __call__(self, func: _F) -> _F:
3837
def wrapper(*args, **kwargs):
3938
# Pretreatment parameters
4039
processed_args, processed_kwargs = self.process(args, kwargs)
41-
# Call the original function
4240
return func(*processed_args, **processed_kwargs)
4341

44-
# Keep original signature
4542
wrapper.__signature__ = inspect.signature(func)
4643
return cast("_F", wrapper)
4744

4845
def process(
4946
self, args: tuple[Any, ...], kwargs: dict[str, Any]
5047
) -> tuple[tuple[Any, ...], dict[str, Any]]:
51-
"""Core processing methods that subclasses must implement.
52-
53-
Args:
54-
args: positional parameter
55-
kwargs: Keyword Argument
56-
57-
Returns:
58-
Processed tuples (args, kwargs)
59-
"""
48+
"""To be implemented by subclass"""
6049
raise NotImplementedError("Subclasses must implement this method")
6150

6251

@@ -66,31 +55,58 @@ class ParamAliasDecorator(DecoratorBase):
6655

6756
def __init__(self, alias_mapping: dict[str, Iterable[str]]) -> None:
6857
super().__init__()
58+
# Check alias_mapping types
6959
if not isinstance(alias_mapping, dict):
7060
raise TypeError("alias_mapping must be a dictionary")
7161
for k, v in alias_mapping.items():
7262
if not isinstance(v, (list, tuple, set)):
7363
raise TypeError(f"Aliases for '{k}' must be iterable")
74-
self.alias_mapping = alias_mapping
64+
65+
# Build a reverse alias map for faster lookup
66+
self.alias_mapping = {}
67+
for original, aliases in alias_mapping.items():
68+
for alias in aliases:
69+
self.alias_mapping[alias] = original
7570

7671
def process(
7772
self, args: tuple[Any, ...], kwargs: dict[str, Any]
7873
) -> tuple[tuple[Any, ...], dict[str, Any]]:
74+
"""Process parameters to handle alias mapping"""
7975
if not kwargs:
8076
return args, kwargs
81-
processed_kwargs = kwargs.copy()
82-
for original, aliases in self.alias_mapping.items():
83-
for alias in aliases:
84-
if alias in processed_kwargs:
85-
if original not in processed_kwargs:
86-
processed_kwargs[original] = processed_kwargs.pop(alias)
87-
else:
88-
raise ValueError(
89-
f"Cannot specify both '{original}' and its alias '{alias}'"
90-
)
77+
78+
processed_kwargs = kwargs
79+
alias_mapping = self.alias_mapping
80+
81+
# Directly modify kwargs based on alias mapping (only modify if necessary)
82+
for alias, original in alias_mapping.items():
83+
if alias in processed_kwargs:
84+
if original not in processed_kwargs:
85+
# Only modify the dictionary if necessary
86+
processed_kwargs[original] = processed_kwargs.pop(alias)
87+
else:
88+
raise ValueError(
89+
f"Cannot specify both '{original}' and its alias '{alias}'"
90+
)
91+
9192
return args, processed_kwargs
9293

9394

95+
def param_one_alias(alias_mapping):
96+
def decorator(func):
97+
def wrapper(*args, **kwargs):
98+
if not kwargs:
99+
return func(*args, **kwargs)
100+
if ("input" in kwargs) and ("x" not in kwargs):
101+
kwargs["x"] = kwargs.pop("input")
102+
return func(*args, **kwargs)
103+
104+
wrapper.__signature__ = inspect.signature(func)
105+
return wrapper
106+
107+
return decorator
108+
109+
94110
# *size => shape decorator
95111
class SizeArgsDecorator(DecoratorBase):
96112
"""

test/legacy_test/test_inplace.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1529,6 +1529,22 @@ def non_inplace_api_processing(self, var):
15291529
return paddle.bitwise_or(var, self.y)
15301530

15311531

1532+
class TestDygraphInplacBitwiseOrAlias1(TestDygraphInplacBitwiseAnd):
1533+
def inplace_api_processing(self, var):
1534+
return paddle.bitwise_or_(var, other=self.y)
1535+
1536+
def non_inplace_api_processing(self, var):
1537+
return paddle.bitwise_or(var, other=self.y)
1538+
1539+
1540+
class TestDygraphInplacBitwiseOrAlias2(TestDygraphInplacBitwiseAnd):
1541+
def inplace_api_processing(self, var):
1542+
return paddle.bitwise_or_(input=var, other=self.y)
1543+
1544+
def non_inplace_api_processing(self, var):
1545+
return paddle.bitwise_or(input=var, other=self.y)
1546+
1547+
15321548
class TestDygraphInplacBitwiseXor(TestDygraphInplacBitwiseAnd):
15331549
def inplace_api_processing(self, var):
15341550
return paddle.bitwise_xor_(var, self.y)

test/legacy_test/test_prod_op.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,5 +311,118 @@ def run_imperative(self, place):
311311
np.testing.assert_allclose(out.numpy(), input.numpy())
312312

313313

314+
class TestProdAliasOp(unittest.TestCase):
315+
def setUp(self):
316+
self.input = np.random.random(size=(10, 10, 5)).astype(np.float32)
317+
318+
def run_imperative(self, place):
319+
input = paddle.to_tensor(self.input, place=place)
320+
out = paddle.prod(input=input)
321+
expected_result = np.prod(self.input)
322+
np.testing.assert_allclose(out.numpy(), expected_result, rtol=1e-05)
323+
324+
out = paddle.prod(input, dim=1)
325+
expected_result = np.prod(self.input, axis=1)
326+
np.testing.assert_allclose(out.numpy(), expected_result, rtol=1e-05)
327+
328+
out = paddle.prod(input=input, dim=-1)
329+
expected_result = np.prod(self.input, axis=-1)
330+
np.testing.assert_allclose(out.numpy(), expected_result, rtol=1e-05)
331+
332+
out = paddle.prod(input, dim=[0, 1])
333+
expected_result = np.prod(self.input, axis=(0, 1))
334+
np.testing.assert_allclose(
335+
out.numpy(), expected_result, rtol=1e-05, atol=1e-8
336+
)
337+
338+
out = paddle.prod(input, dim=1, keepdim=True)
339+
expected_result = np.prod(self.input, axis=1, keepdims=True)
340+
np.testing.assert_allclose(out.numpy(), expected_result, rtol=1e-05)
341+
342+
out = paddle.prod(input=input, dim=1, dtype='int64')
343+
expected_result = np.prod(self.input, axis=1, dtype=np.int64)
344+
np.testing.assert_allclose(out.numpy(), expected_result, rtol=1e-05)
345+
346+
out = paddle.prod(input=input, dim=1, keepdim=True, dtype='int64')
347+
expected_result = np.prod(
348+
self.input, axis=1, keepdims=True, dtype=np.int64
349+
)
350+
np.testing.assert_allclose(out.numpy(), expected_result, rtol=1e-05)
351+
352+
def run_static(self, use_gpu=False):
353+
with paddle.static.program_guard(paddle.static.Program()):
354+
input = paddle.static.data(
355+
name='input', shape=[10, 10, 5], dtype='float32'
356+
)
357+
result0 = paddle.prod(input=input)
358+
result1 = paddle.prod(input, dim=1)
359+
result2 = paddle.prod(input=input, dim=-1)
360+
result3 = paddle.prod(input, dim=[0, 1])
361+
result4 = paddle.prod(input, dim=1, keepdim=True)
362+
result5 = paddle.prod(input=input, dim=1, dtype='int64')
363+
result6 = paddle.prod(input, dim=1, keepdim=True, dtype='int64')
364+
365+
place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace()
366+
exe = paddle.static.Executor(place)
367+
exe.run(paddle.static.default_startup_program())
368+
static_result = exe.run(
369+
feed={"input": self.input},
370+
fetch_list=[
371+
result0,
372+
result1,
373+
result2,
374+
result3,
375+
result4,
376+
result5,
377+
result6,
378+
],
379+
)
380+
381+
expected_result = np.prod(self.input)
382+
np.testing.assert_allclose(
383+
static_result[0], expected_result, rtol=1e-05
384+
)
385+
expected_result = np.prod(self.input, axis=1)
386+
np.testing.assert_allclose(
387+
static_result[1], expected_result, rtol=1e-05
388+
)
389+
expected_result = np.prod(self.input, axis=-1)
390+
np.testing.assert_allclose(
391+
static_result[2], expected_result, rtol=1e-05
392+
)
393+
expected_result = np.prod(self.input, axis=(0, 1))
394+
np.testing.assert_allclose(
395+
static_result[3], expected_result, rtol=1e-05, atol=1e-8
396+
)
397+
expected_result = np.prod(self.input, axis=1, keepdims=True)
398+
np.testing.assert_allclose(
399+
static_result[4], expected_result, rtol=1e-05
400+
)
401+
expected_result = np.prod(self.input, axis=1, dtype=np.int64)
402+
np.testing.assert_allclose(
403+
static_result[5], expected_result, rtol=1e-05
404+
)
405+
expected_result = np.prod(
406+
self.input, axis=1, keepdims=True, dtype=np.int64
407+
)
408+
np.testing.assert_allclose(
409+
static_result[6], expected_result, rtol=1e-05
410+
)
411+
412+
def test_cpu(self):
413+
with dygraph_guard():
414+
self.run_imperative(place=paddle.CPUPlace())
415+
with static_guard():
416+
self.run_static()
417+
418+
def test_gpu(self):
419+
if not paddle.base.core.is_compiled_with_cuda():
420+
return
421+
with dygraph_guard():
422+
self.run_imperative(place=paddle.CUDAPlace(0))
423+
with static_guard():
424+
self.run_static()
425+
426+
314427
if __name__ == "__main__":
315428
unittest.main()

0 commit comments

Comments
 (0)