Skip to content

Commit f8d269a

Browse files
committed
Simplify conftest
1 parent f60d9af commit f8d269a

File tree

1 file changed

+7
-26
lines changed

1 file changed

+7
-26
lines changed

tests/conftest.py

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Callable, Generator
44
from contextlib import suppress
5-
from functools import cache, partial, wraps
5+
from functools import partial, wraps
66
from types import ModuleType
77
from typing import ParamSpec, TypeVar, cast
88

@@ -102,28 +102,6 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
102102
return wrapper
103103

104104

105-
@cache
106-
def _jax_cuda_device() -> Device | None:
107-
"""Return a CUDA device for JAX, if available."""
108-
import jax
109-
110-
try:
111-
return jax.devices("cuda")[0]
112-
except Exception:
113-
return None
114-
115-
116-
@cache
117-
def _torch_cuda_device() -> Device | None:
118-
"""Return a CUDA device for PyTorch, if available."""
119-
import torch
120-
121-
try:
122-
return torch.empty((0,), device=torch.device("cuda")).device
123-
except Exception:
124-
return None
125-
126-
127105
@pytest.fixture
128106
def xp(
129107
library: Backend, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
@@ -168,15 +146,18 @@ def xp(
168146
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore]
169147

170148
if library == Backend.JAX_GPU:
171-
device = _jax_cuda_device()
172-
if device is None:
149+
try:
150+
device = jax.devices("cuda")[0]
151+
except RuntimeError:
173152
pytest.skip("no cuda device available")
174153
else:
175154
device = jax.devices("cpu")[0]
176155
jax.config.update("jax_default_device", device)
177156

178157
elif library == Backend.TORCH_GPU:
179-
if _torch_cuda_device() is None:
158+
import torch.cuda
159+
160+
if not torch.cuda.is_available():
180161
pytest.skip("no cuda device available")
181162
xp.set_default_device("cuda")
182163

0 commit comments

Comments
 (0)