Skip to content

Commit ff5b73d

Browse files
fix reshape alias2
1 parent 6eb073c commit ff5b73d

File tree

3 files changed

+7
-12
lines changed

3 files changed

+7
-12
lines changed

python/paddle/sparse/unary.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from paddle.common_ops_import import Variable
3030
from paddle.framework import LayerHelper
3131
from paddle.utils.decorator_utils import (
32-
reshape_param_alias,
32+
param_one_alias,
3333
)
3434

3535
if TYPE_CHECKING:
@@ -882,7 +882,7 @@ def expm1(x: Tensor, name: str | None = None) -> Tensor:
882882
return _C_ops.sparse_expm1(x)
883883

884884

885-
@reshape_param_alias({"x": "input"})
885+
@param_one_alias({"x": "input"})
886886
def reshape(x: Tensor, shape: ShapeLike, name: str | None = None) -> Tensor:
887887
"""
888888
Changes the shape of ``x`` without changing its value, requiring x to be a SparseCooTensor or SparseCsrTensor.

python/paddle/tensor/manipulation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from paddle.tensor import fill_constant
2727
from paddle.utils.decorator_utils import (
2828
ParamAliasDecorator,
29-
reshape_param_alias,
29+
param_one_alias,
3030
)
3131
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
3232

@@ -4979,7 +4979,7 @@ def get_attr_expand_shape(list_expand_shape):
49794979
return out
49804980

49814981

4982-
@reshape_param_alias({"x": "input"})
4982+
@param_one_alias({"x": "input"})
49834983
def reshape(x: Tensor, shape: ShapeLike, name: str | None = None) -> Tensor:
49844984
"""
49854985
Changes the shape of ``x`` without changing its data.

python/paddle/utils/decorator_utils.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,18 +92,13 @@ def process(
9292
return args, processed_kwargs
9393

9494

95-
def reshape_param_alias(alias_mapping):
95+
def param_one_alias(alias_mapping):
9696
def decorator(func):
9797
def wrapper(*args, **kwargs):
9898
if not kwargs:
9999
return func(*args, **kwargs)
100-
for original, alias in alias_mapping.items():
101-
if alias in kwargs:
102-
if original not in kwargs:
103-
kwargs[original] = kwargs.pop(alias)
104-
# if "input" in kwargs:
105-
# if "x" not in kwargs:
106-
# kwargs["x"] = kwargs.pop("input")
100+
if ("input" in kwargs) and ("x" not in kwargs):
101+
kwargs["x"] = kwargs.pop("input")
107102
return func(*args, **kwargs)
108103

109104
wrapper.__signature__ = inspect.signature(func)

0 commit comments

Comments
 (0)