Skip to content

Commit 4614b93

Browse files
Parameter mapping from torch to paddle(broadcast_to) (#74449)
* add torch_to_paddle_decorator * add DecoratorBase
1 parent c62a1e3 commit 4614b93

File tree

3 files changed

+110
-0
lines changed

3 files changed

+110
-0
lines changed

python/paddle/tensor/manipulation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import paddle
2525
from paddle import _C_ops
2626
from paddle.tensor import fill_constant
27+
from paddle.utils.decorator_utils import ParamAliasDecorator
2728
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
2829

2930
from ..base.data_feeder import (
@@ -4762,6 +4763,7 @@ def expand_as(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
47624763
return out
47634764

47644765

4766+
@ParamAliasDecorator({"x": ["input"], "shape": ["size"]})
47654767
def broadcast_to(
47664768
x: Tensor, shape: ShapeLike, name: str | None = None
47674769
) -> Tensor:

python/paddle/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ..base.framework import require_version
1616
from . import ( # noqa: F401
1717
cpp_extension,
18+
decorator_utils,
1819
dlpack,
1920
download,
2021
image_util,
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import functools
18+
import inspect
19+
from typing import (
20+
TYPE_CHECKING,
21+
Any,
22+
Callable,
23+
Generic,
24+
TypeVar,
25+
cast,
26+
)
27+
28+
from typing_extensions import ParamSpec
29+
30+
if TYPE_CHECKING:
31+
from collections.abc import Iterable
32+
33+
34+
_P = ParamSpec("_P")
35+
_R = TypeVar("_R")
36+
_DecoratedFunc = Callable[_P, _R]
37+
38+
39+
class DecoratorBase(Generic[_P, _R]):
40+
"""装饰器基类,提供通用装饰器框架
41+
42+
子类只需实现 `process` 方法定义核心逻辑
43+
"""
44+
45+
def __init__(self, *args: Any, **kwargs: Any) -> None:
46+
"""初始化装饰器参数"""
47+
self.args = args
48+
self.kwargs = kwargs
49+
50+
def __call__(self, func: _DecoratedFunc[_P, _R]) -> _DecoratedFunc[_P, _R]:
51+
"""作为装饰器应用的入口点"""
52+
53+
@functools.wraps(func)
54+
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
55+
# 预处理参数
56+
processed_args, processed_kwargs = self.process(args, kwargs)
57+
# 调用原函数
58+
return func(*processed_args, **processed_kwargs)
59+
60+
# 保留原始签名
61+
wrapper.__signature__ = inspect.signature(func)
62+
return cast("_DecoratedFunc[_P, _R]", wrapper)
63+
64+
def process(
65+
self, args: tuple[Any, ...], kwargs: dict[str, Any]
66+
) -> tuple[tuple[Any, ...], dict[str, Any]]:
67+
"""子类必须实现的核心处理方法
68+
69+
Args:
70+
args: 位置参数
71+
kwargs: 关键字参数
72+
73+
Returns:
74+
处理后的 (args, kwargs) 元组
75+
"""
76+
raise NotImplementedError("Subclasses must implement this method")
77+
78+
79+
# 示例实现:参数别名装饰器
80+
class ParamAliasDecorator(DecoratorBase[_P, _R]):
81+
"""参数别名处理的装饰器实现"""
82+
83+
def __init__(self, alias_mapping: dict[str, Iterable[str]]) -> None:
84+
super().__init__()
85+
if not isinstance(alias_mapping, dict):
86+
raise TypeError("alias_mapping must be a dictionary")
87+
for k, v in alias_mapping.items():
88+
if not isinstance(v, (list, tuple, set)):
89+
raise TypeError(f"Aliases for '{k}' must be iterable")
90+
self.alias_mapping = alias_mapping
91+
92+
def process(
93+
self, args: tuple[Any, ...], kwargs: dict[str, Any]
94+
) -> tuple[tuple[Any, ...], dict[str, Any]]:
95+
if not kwargs:
96+
return args, kwargs
97+
processed_kwargs = kwargs.copy()
98+
for original, aliases in self.alias_mapping.items():
99+
for alias in aliases:
100+
if alias in processed_kwargs:
101+
if original not in processed_kwargs:
102+
processed_kwargs[original] = processed_kwargs.pop(alias)
103+
else:
104+
raise ValueError(
105+
f"Cannot specify both '{original}' and its alias '{alias}'"
106+
)
107+
return args, processed_kwargs

0 commit comments

Comments
 (0)