File tree Expand file tree Collapse file tree 1 file changed +20
-0
lines changed Expand file tree Collapse file tree 1 file changed +20
-0
lines changed Original file line number Diff line number Diff line change 1+ from importlib import reload
2+ from unittest .mock import MagicMock , patch
3+
14import jax
25import numpy as np
36import pytest
@@ -181,3 +184,20 @@ def test_wrap_as_torch_function_single_arg_autograd_check(
181184 x_torch = torch .tensor (x_numpy , requires_grad = True )
182185 torch_function = torch_wrapper .wrap_as_torch_function (jax_function )
183186 torch .autograd .gradcheck (torch_function , x_torch )
187+
188+
189+ def test_check_pytorch_available ():
190+ try :
191+ with patch .dict ("sys.modules" , torch = None ):
192+ reload (torch_wrapper )
193+ with pytest .raises (RuntimeError , match = "torch needs to be installed" ):
194+ torch_wrapper .check_torch_available ()
195+ with patch .dict ("sys.modules" , torch = MagicMock ()):
196+ reload (torch_wrapper )
197+ # We should not get an exception here irrespective of whether torch is
198+ # installed
199+ torch_wrapper .check_torch_available ()
200+ finally :
201+ # Ensure torch_wrapper always reloaded with original state irrespective of test
202+ # passing or failing
203+ reload (torch_wrapper )
You can’t perform that action at this time.
0 commit comments