diff --git a/python/paddle/compat.py b/python/paddle/compat.py index 179207174cd7bc..3f08e9dff26a89 100644 --- a/python/paddle/compat.py +++ b/python/paddle/compat.py @@ -184,9 +184,19 @@ def disable_torch_proxy(): @contextmanager -def use_torch_proxy_guard(): - enable_torch_proxy() - try: - yield - finally: +def use_torch_proxy_guard(enable: bool = True): + already_has_torch_proxy = TORCH_PROXY_FINDER in sys.meta_path + if enable == already_has_torch_proxy: + return + if enable: + enable_torch_proxy() + try: + yield + finally: + disable_torch_proxy() + else: disable_torch_proxy() + try: + yield + finally: + enable_torch_proxy() diff --git a/test/compat/test_torch_proxy.py b/test/compat/test_torch_proxy.py index 8be43c9f813a9c..80b43f20f4317a 100644 --- a/test/compat/test_torch_proxy.py +++ b/test/compat/test_torch_proxy.py @@ -64,6 +64,19 @@ def test_use_torch_proxy_guard(self): with self.assertRaises(ModuleNotFoundError): import torch + with paddle.compat.use_torch_proxy_guard(): + import torch + + self.assertIs(torch.cos, paddle.cos) + with paddle.compat.use_torch_proxy_guard(enable=False): + with self.assertRaises(ModuleNotFoundError): + import torch + with paddle.compat.use_torch_proxy_guard(enable=True): + import torch + + with self.assertRaises(ModuleNotFoundError): + import torch + @paddle.compat.use_torch_proxy_guard() def test_use_torch_inside_inner_function(self): result = use_torch_inside_inner_function()