Skip to content

Commit 7208d9f

Browse files
committed
Add test for function checking torch available
1 parent 4990ca4 commit 7208d9f

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

tests/test_torch_wrapper.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from importlib import reload
2+
from unittest.mock import MagicMock, patch
3+
14
import jax
25
import numpy as np
36
import pytest
@@ -181,3 +184,20 @@ def test_wrap_as_torch_function_single_arg_autograd_check(
181184
x_torch = torch.tensor(x_numpy, requires_grad=True)
182185
torch_function = torch_wrapper.wrap_as_torch_function(jax_function)
183186
torch.autograd.gradcheck(torch_function, x_torch)
187+
188+
189+
def test_check_pytorch_available():
190+
try:
191+
with patch.dict("sys.modules", torch=None):
192+
reload(torch_wrapper)
193+
with pytest.raises(RuntimeError, match="torch needs to be installed"):
194+
torch_wrapper.check_torch_available()
195+
with patch.dict("sys.modules", torch=MagicMock()):
196+
reload(torch_wrapper)
197+
# We should not get an exception here irrespective of whether torch is
198+
# installed
199+
torch_wrapper.check_torch_available()
200+
finally:
201+
# Ensure torch_wrapper always reloaded with original state irrespective of test
202+
# passing or failing
203+
reload(torch_wrapper)

0 commit comments

Comments
 (0)