Skip to content

Commit 20d9626

Browse files
SigureMo and Copilot authored
[Compat] Auto register compat module overrides when enable torch proxy (PaddlePaddle#76522)
--------- Co-authored-by: Copilot <[email protected]>
1 parent b5efb98 commit 20d9626

File tree

2 files changed

+110
-5
lines changed

2 files changed

+110
-5
lines changed

python/paddle/compat/proxy.py

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,22 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import importlib
1618
import importlib.abc
1719
import importlib.util
1820
import inspect
21+
import pkgutil
1922
import sys
2023
import types
2124
import warnings
22-
from collections.abc import Iterable
2325
from contextlib import contextmanager
24-
from typing import Any
26+
from functools import cache
27+
from typing import TYPE_CHECKING, Any
28+
29+
if TYPE_CHECKING:
30+
from collections.abc import Iterable
2531

2632

2733
def warning_about_fake_interface(name: str):
@@ -48,12 +54,38 @@ def create_fake_function(name):
4854
return fn
4955

5056

57+
class OverriddenAttribute:
    """Abstract base for a value that replaces an attribute on a proxied
    (``torch``-aliased) module.

    Subclasses decide how the replacement value is produced — e.g. wrapping
    an already-materialized object (``RawOverriddenAttribute``) or resolving
    a dotted import path on first access (``LazyImportOverriddenAttribute``).
    """

    def get_value(self):
        """Return the replacement value for the overridden attribute.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        # Explicit message makes the failure self-explanatory when a subclass
        # forgets to implement the hook.
        raise NotImplementedError(
            "Subclasses of OverriddenAttribute must implement get_value()."
        )
60+
61+
62+
class LazyImportOverriddenAttribute(OverriddenAttribute):
    """Override whose value is resolved on demand from a dotted import path,
    e.g. ``"paddle.nn.functional.relu"``.

    Resolution is deferred until :meth:`get_value` is called, so merely
    registering an override does not trigger any import.
    """

    def __init__(self, full_name: str):
        self._full_name = full_name

    def get_value(self):
        """Import the root module, then walk the remaining attribute chain."""
        module_name, _, attr_path = self._full_name.partition(".")
        value = importlib.import_module(module_name)
        if attr_path:
            for attr in attr_path.split("."):
                value = getattr(value, attr)
        return value
73+
74+
75+
class RawOverriddenAttribute(OverriddenAttribute):
    """Override that wraps an already-materialized value.

    Unlike :class:`LazyImportOverriddenAttribute`, nothing is deferred:
    :meth:`get_value` simply hands back the object given at construction.
    """

    def __init__(self, value: Any):
        self._value = value

    def get_value(self):
        """Return the wrapped value unchanged."""
        return self._value
81+
82+
5183
class ProxyModule(types.ModuleType):
5284
def __init__(
5385
self,
5486
original_module: types.ModuleType,
5587
proxy_name: str,
56-
overrides: dict[str, Any],
88+
overrides: dict[str, OverriddenAttribute],
5789
):
5890
super().__init__(proxy_name)
5991
self._original_module = original_module
@@ -62,18 +94,56 @@ def __init__(
6294

6395
def __getattr__(self, name: str) -> Any:
6496
if name in self._overrides:
65-
return self._overrides[name]
97+
return self._overrides[name].get_value()
6698
return getattr(self._original_module, name)
6799

68100

69-
GLOBAL_OVERRIDES = {}
101+
# Attribute overrides applied when the torch proxy is enabled.
# Keys are fully qualified ``torch.*`` attribute paths; values describe how
# to produce the replacement value (lazy import avoids resolving paddle
# symbols at module-definition time).
GLOBAL_OVERRIDES: dict[str, OverriddenAttribute] = {
    "torch.relu": LazyImportOverriddenAttribute("paddle.nn.functional.relu"),
}

# Module names (including their submodules) that the torch proxy must not
# intercept.  NOTE(review): presumably these packages need the real torch
# distribution — confirm against the finder logic that consumes this set.
TORCH_PROXY_BLOCKED_MODULES = {
    "tvm_ffi",
    "transformers",
}
75109

76110

111+
def _extend_torch_proxy_overrides(
    overrides: dict[str, OverriddenAttribute],
) -> None:
    """Merge *overrides* into the module-wide ``GLOBAL_OVERRIDES`` table.

    Entries whose keys already exist are replaced by the new values.
    """
    for attr_path, attribute in overrides.items():
        GLOBAL_OVERRIDES[attr_path] = attribute
115+
116+
117+
@cache
def _register_compat_override():
    """Scan every ``paddle.compat`` submodule and register its declared
    public API as torch-proxy overrides.

    Each submodule that declares ``__all__`` contributes its non-underscore
    exports under the mirrored ``torch.*`` module path (``paddle.compat.nn``
    maps to ``torch.nn``, etc.).  ``@cache`` makes repeated calls no-ops.
    """
    import paddle.compat

    PADDLE_PREFIX = "paddle.compat"
    TORCH_PREFIX = "torch"
    PUBLIC_ATTR_DECLARATION = "__all__"

    discovered = {}
    for info in pkgutil.walk_packages(
        paddle.compat.__path__,
        paddle.compat.__name__ + ".",
    ):
        submodule = importlib.import_module(info.name)
        # Only modules that explicitly declare a public API participate.
        if not hasattr(submodule, PUBLIC_ATTR_DECLARATION):
            continue
        exported = getattr(submodule, PUBLIC_ATTR_DECLARATION)
        # Map "paddle.compat.X" onto the mirrored "torch.X" namespace.
        torch_module_name = info.name.replace(PADDLE_PREFIX, TORCH_PREFIX, 1)
        for symbol in exported:
            if symbol.startswith("_"):
                continue
            discovered[f"{torch_module_name}.{symbol}"] = (
                RawOverriddenAttribute(getattr(submodule, symbol))
            )
    _extend_torch_proxy_overrides(discovered)
145+
146+
77147
def _is_specific_module_or_its_submodule(name: str, module: str) -> bool:
78148
return name == module or name.startswith(f"{module}.")
79149

@@ -189,6 +259,18 @@ def exec_module(self, module):
189259
for k, v in self._source.__dict__.items():
190260
if k in ("__name__", "__package__", "__path__", "__spec__"):
191261
continue
262+
if k in overrides:
263+
continue
264+
if isinstance(v, types.ModuleType):
265+
v = ProxyModule(
266+
v,
267+
f"{self._target_name}.{k}",
268+
{
269+
kk.removeprefix(f"{k}."): vv
270+
for kk, vv in overrides.items()
271+
if kk.startswith(f"{k}.")
272+
},
273+
)
192274
module.__dict__[k] = v
193275

194276
# Use fullname for the spec name and mark as package when appropriate so that
@@ -223,6 +305,7 @@ def enable_torch_proxy() -> None:
223305
>>> import torch # This will import paddle as torch
224306
>>> assert torch.sin is paddle.sin
225307
"""
308+
_register_compat_override()
226309
_clear_torch_modules()
227310
sys.meta_path.insert(0, TORCH_PROXY_FINDER)
228311

test/compat/test_torch_proxy.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,28 @@ def test_blocked_module(self):
105105
torch_proxy_blocked_module.use_torch_specific_fn()
106106

107107

108+
class TestOverrideTorchModule(unittest.TestCase):
    # Verifies that, with the torch proxy enabled, compat symbols are served
    # under the ``torch`` namespace and are identical to their paddle
    # counterparts.

    @paddle.compat.use_torch_proxy_guard()
    def test_relu(self):
        """``torch.relu`` resolves to ``paddle.nn.functional.relu`` through
        the registered override table."""
        import torch

        self.assertIs(torch.relu, paddle.nn.functional.relu)

    @paddle.compat.use_torch_proxy_guard()
    def test_access_compat_functions_by_getattr(self):
        """Attribute access on proxied submodules yields compat objects."""
        import torch

        self.assertIs(torch.nn.Unfold, paddle.compat.nn.Unfold)
        self.assertIs(torch.nn.Linear, paddle.compat.nn.Linear)

    @paddle.compat.use_torch_proxy_guard()
    def test_access_compat_functions_by_import(self):
        """``from torch...`` import statements also yield compat objects."""
        from torch.nn.functional import linear, softmax

        self.assertIs(softmax, paddle.compat.nn.functional.softmax)
        self.assertIs(linear, paddle.compat.nn.functional.linear)
128+
129+
108130
class TestFakeInterface(unittest.TestCase):
109131
def test_fake_interface(self):
110132
FakeGenerator = create_fake_class(

0 commit comments

Comments
 (0)