@@ -128,7 +128,7 @@ def get_array_module(x: npt.ArrayLike) -> ModuleType:
128128
129129 """
130130 if deps .cupy_enabled or deps .jax_enabled :
131- if isinstance (x , jnp .ndarray ):
131+ if deps . jax_enabled and isinstance (x , jnp .ndarray ):
132132 return jnp
133133 elif deps .cupy_enabled :
134134 return cp .get_array_module (x )
@@ -153,7 +153,7 @@ def get_convolve(x: npt.ArrayLike) -> Callable:
153153
154154 """
155155 if deps .cupy_enabled or deps .jax_enabled :
156- if isinstance (x , jnp .ndarray ):
156+ if deps . jax_enabled and isinstance (x , jnp .ndarray ):
157157 return j_convolve
158158 elif deps .cupy_enabled and cp .get_array_module (x ) == cp :
159159 return cp_convolve
@@ -178,7 +178,7 @@ def get_fftconvolve(x: npt.ArrayLike) -> Callable:
178178
179179 """
180180 if deps .cupy_enabled or deps .jax_enabled :
181- if isinstance (x , jnp .ndarray ):
181+ if deps . jax_enabled and isinstance (x , jnp .ndarray ):
182182 return j_fftconvolve
183183 elif deps .cupy_enabled and cp .get_array_module (x ) == cp :
184184 return cp_fftconvolve
@@ -203,7 +203,7 @@ def get_oaconvolve(x: npt.ArrayLike) -> Callable:
203203
204204 """
205205 if deps .cupy_enabled or deps .jax_enabled :
206- if isinstance (x , jnp .ndarray ):
206+ if deps . jax_enabled and isinstance (x , jnp .ndarray ):
207207 raise NotImplementedError (
208208 "oaconvolve not implemented in "
209209 "jax. Consider using a different"
@@ -232,7 +232,7 @@ def get_correlate(x: npt.ArrayLike) -> Callable:
232232
233233 """
234234 if deps .cupy_enabled or deps .jax_enabled :
235- if isinstance (x , jnp .ndarray ):
235+ if deps . jax_enabled and isinstance (x , jnp .ndarray ):
236236 return jax .scipy .signal .correlate
237237 elif deps .cupy_enabled and cp .get_array_module (x ) == cp :
238238 return cp_correlate
@@ -303,7 +303,7 @@ def get_block_diag(x: npt.ArrayLike) -> Callable:
303303
304304 """
305305 if deps .cupy_enabled or deps .jax_enabled :
306- if isinstance (x , jnp .ndarray ):
306+ if deps . jax_enabled and isinstance (x , jnp .ndarray ):
307307 return jnp_block_diag
308308 elif deps .cupy_enabled and cp .get_array_module (x ) == cp :
309309 return cp_block_diag
@@ -328,7 +328,7 @@ def get_toeplitz(x: npt.ArrayLike) -> Callable:
328328
329329 """
330330 if deps .cupy_enabled or deps .jax_enabled :
331- if isinstance (x , jnp .ndarray ):
331+ if deps . jax_enabled and isinstance (x , jnp .ndarray ):
332332 return jnp_toeplitz
333333 elif deps .cupy_enabled and cp .get_array_module (x ) == cp :
334334 return cp_toeplitz
0 commit comments