|
2 | 2 |
|
3 | 3 | from collections.abc import Callable, Generator |
4 | 4 | from contextlib import suppress |
5 | | -from functools import cache, partial, wraps |
| 5 | +from functools import partial, wraps |
6 | 6 | from types import ModuleType |
7 | 7 | from typing import ParamSpec, TypeVar, cast |
8 | 8 |
|
@@ -102,28 +102,6 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 |
102 | 102 | return wrapper |
103 | 103 |
|
104 | 104 |
|
105 | | -@cache |
106 | | -def _jax_cuda_device() -> Device | None: |
107 | | - """Return a CUDA device for JAX, if available.""" |
108 | | - import jax |
109 | | - |
110 | | - try: |
111 | | - return jax.devices("cuda")[0] |
112 | | - except Exception: |
113 | | - return None |
114 | | - |
115 | | - |
116 | | -@cache |
117 | | -def _torch_cuda_device() -> Device | None: |
118 | | - """Return a CUDA device for PyTorch, if available.""" |
119 | | - import torch |
120 | | - |
121 | | - try: |
122 | | - return torch.empty((0,), device=torch.device("cuda")).device |
123 | | - except Exception: |
124 | | - return None |
125 | | - |
126 | | - |
127 | 105 | @pytest.fixture |
128 | 106 | def xp( |
129 | 107 | library: Backend, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch |
@@ -168,15 +146,18 @@ def xp( |
168 | 146 | jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore] |
169 | 147 |
|
170 | 148 | if library == Backend.JAX_GPU: |
171 | | - device = _jax_cuda_device() |
172 | | - if device is None: |
| 149 | + try: |
| 150 | + device = jax.devices("cuda")[0] |
| 151 | + except RuntimeError: |
173 | 152 | pytest.skip("no cuda device available") |
174 | 153 | else: |
175 | 154 | device = jax.devices("cpu")[0] |
176 | 155 | jax.config.update("jax_default_device", device) |
177 | 156 |
|
178 | 157 | elif library == Backend.TORCH_GPU: |
179 | | - if _torch_cuda_device() is None: |
| 158 | + import torch.cuda |
| 159 | + |
| 160 | + if not torch.cuda.is_available(): |
180 | 161 | pytest.skip("no cuda device available") |
181 | 162 | xp.set_default_device("cuda") |
182 | 163 |
|
|
0 commit comments