|
15 | 15 | from importlib import import_module |
16 | 16 | from unittest.mock import patch |
17 | 17 |
|
| 18 | +from earthkit.utils.testing import get_array_backend |
| 19 | + |
18 | 20 | from earthkit.data import from_object |
19 | 21 | from earthkit.data import from_source |
20 | 22 | from earthkit.data.readers.text import TextReader |
@@ -121,17 +123,7 @@ def modules_installed(*modules): |
121 | 123 |
|
122 | 124 | NO_POLYTOPE = not os.path.exists(os.path.expanduser("~/.polytopeapirc")) |
123 | 125 | NO_COVJSONKIT = not modules_installed("covjsonkit") |
124 | | -NO_PYTORCH = not modules_installed("torch") |
125 | 126 | NO_RIOXARRAY = not modules_installed("rioxarray") |
126 | | -NO_CUPY = not modules_installed("cupy") |
127 | | -NO_JAX = not modules_installed("jax") |
128 | | -if not NO_CUPY: |
129 | | - try: |
130 | | - import cupy as cp |
131 | | - |
132 | | - a = cp.ones(2) |
133 | | - except Exception: |
134 | | - NO_CUPY = True |
135 | 127 |
|
136 | 128 | NO_S3_AUTH = not modules_installed("aws_requests_auth") |
137 | 129 | NO_GEO = not modules_installed("earthkit-data") |
@@ -187,34 +179,8 @@ def load_nc_or_xr_source(path, mode): |
187 | 179 | return from_object(xarray.open_dataset(path)) |
188 | 180 |
|
189 | 181 |
|
190 | | -def check_array_type(array, expected_backend, dtype=None): |
191 | | - from earthkit.data.utils.array import get_backend |
192 | | - |
193 | | - b1 = get_backend(array) |
194 | | - b2 = get_backend(expected_backend) |
195 | | - |
196 | | - assert b1 == b2, f"{b1=}, {b2=}" |
197 | | - |
198 | | - expected_dtype = dtype |
199 | | - if expected_dtype is not None: |
200 | | - assert b2.match_dtype(array, expected_dtype), f"{array.dtype}, {expected_dtype=}" |
201 | | - |
202 | | - |
203 | | -def get_array_namespace(backend): |
204 | | - if backend is None: |
205 | | - backend = "numpy" |
206 | | - |
207 | | - from earthkit.data.utils.array import get_backend |
208 | | - |
209 | | - return get_backend(backend).namespace |
210 | | - |
211 | | - |
212 | | -ARRAY_BACKENDS = ["numpy"] |
213 | | -if not NO_PYTORCH: |
214 | | - ARRAY_BACKENDS.append("pytorch") |
215 | | - |
216 | | -if not NO_CUPY: |
217 | | - ARRAY_BACKENDS.append("cupy") |
| 182 | +# Array backends |
| 183 | +ARRAY_BACKENDS = get_array_backend(["numpy", "torch", "cupy", "jax"], raise_on_missing=False) |
218 | 184 |
|
219 | 185 |
|
220 | 186 | def make_tgz(target_dir, target_name, paths): |
|
0 commit comments