@@ -223,12 +223,12 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
223223
224224 # Backend-specific branches
225225 if is_dask_namespace (xp ):
226- import dask # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
226+ import dask # pylint: disable=import-outside-toplevel
227227
228228 metas = [arg ._meta for arg in args if hasattr (arg , "_meta" )] # pylint: disable=protected-access
229229 meta_xp = array_namespace (* metas )
230230
231- wrapped = dask .delayed (
231+ wrapped = dask .delayed ( # type: ignore[attr-defined] # pyright: ignore[reportPrivateImportUsage]
232232 _lazy_apply_wrapper (func , as_numpy , multi_output , meta_xp ),
233233 pure = True ,
234234 )
@@ -239,7 +239,7 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
239239
240240 out = tuple (
241241 xp .from_delayed (
242- delayed_out [i ],
242+ delayed_out [i ], # pyright: ignore[reportIndexIssue]
243243 # Dask's unknown shapes diverge from the Array API specification
244244 shape = tuple (math .nan if s is None else s for s in shape ),
245245 dtype = dtype ,
@@ -254,7 +254,7 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
254254 # Instead, we delay calling wrapped, which will receive
255255 # as arguments and will return JAX eager arrays.
256256
257- import jax # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
257+ import jax # pylint: disable=import-outside-toplevel
258258
259259 wrapped = _lazy_apply_wrapper (func , as_numpy , multi_output , xp )
260260
@@ -265,18 +265,17 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
265265 out = wrapped (* args , ** kwargs )
266266
267267 else :
268- out = cast (
269- tuple [Array , ...],
270- jax .pure_callback (
271- wrapped ,
272- tuple (
273- jax .ShapeDtypeStruct (shape , dtype ) # pyright: ignore[reportUnknownArgumentType]
274- for shape , dtype in zip (shapes , dtypes , strict = True )
275- ),
276- * args ,
277- ** kwargs ,
268+ # FIXME jax typing bug
269+ out_jax = jax .pure_callback ( # type: ignore[func-returns-value]
270+ wrapped ,
271+ tuple (
272+ jax .ShapeDtypeStruct (shape , dtype )
273+ for shape , dtype in zip (shapes , dtypes , strict = True )
278274 ),
275+ * args ,
276+ ** kwargs ,
279277 )
278+ out = cast (tuple [Array , ...], cast (object , out_jax ))
280279
281280 else :
282281 # Eager backends
0 commit comments