1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from __future__ import annotations
16+
1517import importlib
1618import importlib .abc
1719import importlib .util
1820import inspect
21+ import pkgutil
1922import sys
2023import types
2124import warnings
22- from collections .abc import Iterable
2325from contextlib import contextmanager
24- from typing import Any
26+ from functools import cache
27+ from typing import TYPE_CHECKING , Any
28+
29+ if TYPE_CHECKING :
30+ from collections .abc import Iterable
2531
2632
2733def warning_about_fake_interface (name : str ):
@@ -48,12 +54,38 @@ def create_fake_function(name):
4854 return fn
4955
5056
57+ class OverriddenAttribute :
58+ def get_value (self ):
59+ raise NotImplementedError
60+
61+
62+ class LazyImportOverriddenAttribute (OverriddenAttribute ):
63+ def __init__ (self , full_name : str ):
64+ self ._full_name = full_name
65+
66+ def get_value (self ):
67+ parts = self ._full_name .split ("." )
68+ root_module = importlib .import_module (parts [0 ])
69+ result = root_module
70+ for part in parts [1 :]:
71+ result = getattr (result , part )
72+ return result
73+
74+
75+ class RawOverriddenAttribute (OverriddenAttribute ):
76+ def __init__ (self , value : Any ):
77+ self ._value = value
78+
79+ def get_value (self ):
80+ return self ._value
81+
82+
5183class ProxyModule (types .ModuleType ):
5284 def __init__ (
5385 self ,
5486 original_module : types .ModuleType ,
5587 proxy_name : str ,
56- overrides : dict [str , Any ],
88+ overrides : dict [str , OverriddenAttribute ],
5789 ):
5890 super ().__init__ (proxy_name )
5991 self ._original_module = original_module
@@ -62,18 +94,56 @@ def __init__(
6294
6395 def __getattr__ (self , name : str ) -> Any :
6496 if name in self ._overrides :
65- return self ._overrides [name ]
97+ return self ._overrides [name ]. get_value ()
6698 return getattr (self ._original_module , name )
6799
68100
69- GLOBAL_OVERRIDES = {}
101+ GLOBAL_OVERRIDES : dict [str , OverriddenAttribute ] = {
102+ "torch.relu" : LazyImportOverriddenAttribute ("paddle.nn.functional.relu" ),
103+ }
70104
71105TORCH_PROXY_BLOCKED_MODULES = {
72106 "tvm_ffi" ,
73107 "transformers" ,
74108}
75109
76110
111+ def _extend_torch_proxy_overrides (
112+ overrides : dict [str , OverriddenAttribute ],
113+ ) -> None :
114+ GLOBAL_OVERRIDES .update (overrides )
115+
116+
117+ @cache
118+ def _register_compat_override ():
119+ import paddle .compat
120+
121+ PADDLE_PREFIX = "paddle.compat"
122+ TORCH_PREFIX = "torch"
123+ PUBLIC_ATTR_DECLARATION = "__all__"
124+
125+ compat_overrides = {}
126+ for module_info in pkgutil .walk_packages (
127+ paddle .compat .__path__ ,
128+ paddle .compat .__name__ + "." ,
129+ ):
130+ module = importlib .import_module (module_info .name )
131+ if hasattr (module , PUBLIC_ATTR_DECLARATION ):
132+ public_attrs = getattr (module , PUBLIC_ATTR_DECLARATION )
133+ torch_module_name = module_info .name .replace (
134+ PADDLE_PREFIX , TORCH_PREFIX , 1
135+ )
136+ for attr_name in public_attrs :
137+ if attr_name .startswith ("_" ):
138+ continue
139+ paddle_attr = getattr (module , attr_name )
140+ torch_attr_name = f"{ torch_module_name } .{ attr_name } "
141+ compat_overrides [torch_attr_name ] = RawOverriddenAttribute (
142+ paddle_attr
143+ )
144+ _extend_torch_proxy_overrides (compat_overrides )
145+
146+
77147def _is_specific_module_or_its_submodule (name : str , module : str ) -> bool :
78148 return name == module or name .startswith (f"{ module } ." )
79149
@@ -189,6 +259,18 @@ def exec_module(self, module):
189259 for k , v in self ._source .__dict__ .items ():
190260 if k in ("__name__" , "__package__" , "__path__" , "__spec__" ):
191261 continue
262+ if k in overrides :
263+ continue
264+ if isinstance (v , types .ModuleType ):
265+ v = ProxyModule (
266+ v ,
267+ f"{ self ._target_name } .{ k } " ,
268+ {
269+ kk .removeprefix (f"{ k } ." ): vv
270+ for kk , vv in overrides .items ()
271+ if kk .startswith (f"{ k } ." )
272+ },
273+ )
192274 module .__dict__ [k ] = v
193275
194276 # Use fullname for the spec name and mark as package when appropriate so that
@@ -223,6 +305,7 @@ def enable_torch_proxy() -> None:
223305 >>> import torch # This will import paddle as torch
224306 >>> assert torch.sin is paddle.sin
225307 """
308+ _register_compat_override ()
226309 _clear_torch_modules ()
227310 sys .meta_path .insert (0 , TORCH_PROXY_FINDER )
228311
0 commit comments