Skip to content

Commit 8039762

Browse files
authored
Merge pull request #607 from mrava87/patch-jaximport
bug: protect use of jnp if jax is not installed
2 parents e111225 + b8d2c85 commit 8039762

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

pylops/utils/backend.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def get_array_module(x: npt.ArrayLike) -> ModuleType:
128128
129129
"""
130130
if deps.cupy_enabled or deps.jax_enabled:
131-
if isinstance(x, jnp.ndarray):
131+
if deps.jax_enabled and isinstance(x, jnp.ndarray):
132132
return jnp
133133
elif deps.cupy_enabled:
134134
return cp.get_array_module(x)
@@ -153,7 +153,7 @@ def get_convolve(x: npt.ArrayLike) -> Callable:
153153
154154
"""
155155
if deps.cupy_enabled or deps.jax_enabled:
156-
if isinstance(x, jnp.ndarray):
156+
if deps.jax_enabled and isinstance(x, jnp.ndarray):
157157
return j_convolve
158158
elif deps.cupy_enabled and cp.get_array_module(x) == cp:
159159
return cp_convolve
@@ -178,7 +178,7 @@ def get_fftconvolve(x: npt.ArrayLike) -> Callable:
178178
179179
"""
180180
if deps.cupy_enabled or deps.jax_enabled:
181-
if isinstance(x, jnp.ndarray):
181+
if deps.jax_enabled and isinstance(x, jnp.ndarray):
182182
return j_fftconvolve
183183
elif deps.cupy_enabled and cp.get_array_module(x) == cp:
184184
return cp_fftconvolve
@@ -203,7 +203,7 @@ def get_oaconvolve(x: npt.ArrayLike) -> Callable:
203203
204204
"""
205205
if deps.cupy_enabled or deps.jax_enabled:
206-
if isinstance(x, jnp.ndarray):
206+
if deps.jax_enabled and isinstance(x, jnp.ndarray):
207207
raise NotImplementedError(
208208
"oaconvolve not implemented in "
209209
"jax. Consider using a different"
@@ -232,7 +232,7 @@ def get_correlate(x: npt.ArrayLike) -> Callable:
232232
233233
"""
234234
if deps.cupy_enabled or deps.jax_enabled:
235-
if isinstance(x, jnp.ndarray):
235+
if deps.jax_enabled and isinstance(x, jnp.ndarray):
236236
return jax.scipy.signal.correlate
237237
elif deps.cupy_enabled and cp.get_array_module(x) == cp:
238238
return cp_correlate
@@ -303,7 +303,7 @@ def get_block_diag(x: npt.ArrayLike) -> Callable:
303303
304304
"""
305305
if deps.cupy_enabled or deps.jax_enabled:
306-
if isinstance(x, jnp.ndarray):
306+
if deps.jax_enabled and isinstance(x, jnp.ndarray):
307307
return jnp_block_diag
308308
elif deps.cupy_enabled and cp.get_array_module(x) == cp:
309309
return cp_block_diag
@@ -328,7 +328,7 @@ def get_toeplitz(x: npt.ArrayLike) -> Callable:
328328
329329
"""
330330
if deps.cupy_enabled or deps.jax_enabled:
331-
if isinstance(x, jnp.ndarray):
331+
if deps.jax_enabled and isinstance(x, jnp.ndarray):
332332
return jnp_toeplitz
333333
elif deps.cupy_enabled and cp.get_array_module(x) == cp:
334334
return cp_toeplitz

0 commit comments

Comments
 (0)