Skip to content

Commit e0f7323

Browse files
committed
ENH: New functions lazy_raise, lazy_warn, and lazy_wait_on
1 parent 1a8fb30 commit e0f7323

File tree

9 files changed

+376
-48
lines changed

9 files changed

+376
-48
lines changed

docs/api-lazy.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ lazy backends, e.g. Dask or Jax:
1010
:toctree: generated
1111
1212
lazy_apply
13+
lazy_raise
14+
lazy_wait_on
15+
lazy_warn
1316
testing.lazy_xp_function
1417
testing.patch_lazy_xp_functions
1518
```

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
"numpy": ("https://numpy.org/doc/stable", None),
5757
"dask": ("https://docs.dask.org/en/stable", None),
5858
"jax": ("https://jax.readthedocs.io/en/latest", None),
59+
"equinox": ("https://docs.kidger.site/equinox/", None),
5960
}
6061

6162
nitpick_ignore = [

pixi.lock

Lines changed: 42 additions & 44 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ classifiers = [
2626
"Typing :: Typed",
2727
]
2828
dynamic = ["version"]
29-
dependencies = ["array-api-compat>=1.10.0,<2"]
29+
# dependencies = ["array-api-compat>=1.10.0,<2"] # DNM
3030

3131
[project.urls]
3232
Homepage = "https://github.com/data-apis/array-api-extra"
@@ -48,10 +48,11 @@ platforms = ["linux-64", "osx-arm64", "win-64"]
4848

4949
[tool.pixi.dependencies]
5050
python = ">=3.10,<3.14"
51-
array-api-compat = ">=1.10.0,<2"
51+
# array-api-compat = ">=1.10.0,<2" # DNM
5252

5353
[tool.pixi.pypi-dependencies]
5454
array-api-extra = { path = ".", editable = true }
55+
array-api-compat = { git = "https://github.com/data-apis/array-api-compat" } # DNM
5556

5657
[tool.pixi.feature.lint.dependencies]
5758
typing-extensions = "*"

src/array_api_extra/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
setdiff1d,
1313
sinc,
1414
)
15-
from ._lib._lazy import lazy_apply
15+
from ._lib._lazy import lazy_apply, lazy_raise, lazy_wait_on, lazy_warn
1616

1717
__version__ = "0.6.1.dev0"
1818

@@ -27,6 +27,9 @@
2727
"isclose",
2828
"kron",
2929
"lazy_apply",
30+
"lazy_raise",
31+
"lazy_wait_on",
32+
"lazy_warn",
3033
"nunique",
3134
"pad",
3235
"setdiff1d",

src/array_api_extra/_lib/_lazy.py

Lines changed: 317 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,18 @@
44
from __future__ import annotations
55

66
import math
7+
import warnings
78
from collections.abc import Callable, Sequence
89
from functools import wraps
910
from types import ModuleType
1011
from typing import TYPE_CHECKING, Any, cast, overload
1112

12-
from ._utils._compat import array_namespace, is_dask_namespace, is_jax_namespace
13+
from ._utils._compat import (
14+
array_namespace,
15+
is_dask_namespace,
16+
is_jax_namespace,
17+
is_lazy_array,
18+
)
1319
from ._utils._typing import Array, DType
1420

1521
if TYPE_CHECKING:
@@ -319,3 +325,313 @@ def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
319325
return (xp.asarray(out),)
320326

321327
return wrapper
328+
329+
330+
def lazy_raise(  # numpydoc ignore=SA04
    x: Array,
    cond: bool | Array,
    exc: Exception,
    *,
    xp: ModuleType | None = None,
) -> Array:
    """
    Raise an exception if an eager check fails on a lazy array.

    Consider this snippet::

      >>> def f(x, xp):
      ...     if xp.any(x < 0):
      ...         raise ValueError("Some points are negative")
      ...     return x + 1

    The above code fails to compile when `x` is a JAX array and the function is
    wrapped by `jax.jit`; it is also extremely slow on Dask. Other lazy backends,
    e.g. ndonnx, are also expected to misbehave.

    `xp.any(x < 0)` is a 0-dimensional array with `dtype=bool`; the `if` statement
    calls `bool()` on the Array to convert it to a Python bool.

    On eager backends such as NumPy, this is not a problem. On Dask, `bool()`
    implicitly triggers a computation of the whole graph so far; what's worse is
    that the intermediate results are discarded to optimize memory usage, so when
    the user later explicitly calls `compute()` on their final output, `x` is
    recalculated from scratch. On JAX, `bool()` raises if the calling code is
    wrapped by `jax.jit`, for the same reason.

    You should rewrite the above code as follows::

      >>> def f(x, xp):
      ...     x = lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative"))
      ...     return x + 1

    When `xp` is eager, this is equivalent to the original code; if the error
    condition resolves to True, the function raises immediately and the next line
    `return x + 1` is never executed.
    When `xp` is lazy, the function always returns a lazy array. Only when the user
    eventually computes it, e.g. in Dask by calling `compute()` and in JAX by having
    their outermost function decorated with `@jax.jit` return, is the error
    condition evaluated. If True, the exception is raised and propagated as normal,
    and the following nodes of the graph are never executed (so if the health check
    was in place to prevent not only incorrect results but e.g. a segmentation
    fault, it's still going to achieve its purpose).

    Parameters
    ----------
    x : Array
        Any one Array, potentially lazy, that is used later on to produce the value
        returned by your function.
    cond : bool | Array
        Must be either a plain Python bool or a 0-dimensional Array with boolean
        dtype. If True, raise the exception. If False, return x.
    exc : Exception
        The exception instance to be raised.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    Array
        `x`. If both `x` and `cond` are lazy arrays, the graph underlying `x` is
        altered to raise `exc` if `cond` is True.

    Raises
    ------
    type(exc)
        If `cond` evaluates to True.

    Warnings
    --------
    The following function raises when x is eager, but quietly skips the check
    when x is lazy, because the output of `lazy_raise` is discarded::

      >>> def f(x, xp):
      ...     lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative"))
      ...     return x + 1

    And so does this one, as lazy_raise replaces `x` but it does so too late to
    contribute to the return value::

      >>> def f(x, xp):
      ...     y = x + 1
      ...     x = lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative"))
      ...     return y

    See Also
    --------
    lazy_apply
    lazy_warn
    lazy_wait_on
    dask.graph_manipulation.wait_on
    equinox.error_if

    Notes
    -----
    This function will raise if the :doc:`jax:transfer_guard` is active and `cond` is
    a JAX array on a non-CPU device
    (`jax-ml/jax#25995 <https://github.com/jax-ml/jax/issues/25998>`_).
    """

    def _lazy_raise(x: Array, cond: Array) -> Array:  # numpydoc ignore=PR01,RT01
        """Eager helper of `lazy_raise` running inside the lazy graph."""
        # Happy path first: the check passed, hand `x` through untouched.
        if not cond:
            return x
        raise exc

    # The shared implementation decides whether the helper runs immediately or is
    # embedded in the lazy graph of `x`.
    return _lazy_wait_on_impl(x, cond, _lazy_raise, xp=xp)
441+
442+
443+
# Signature of warnings.warn copied from python/typeshed
444+
@overload
def lazy_warn(  # type: ignore[no-any-explicit,no-any-decorated]  # numpydoc ignore=GL08
    x: Array,
    cond: bool | Array,
    message: str,
    category: type[Warning] | None = None,
    stacklevel: int = 1,
    source: Any | None = None,
    *,
    xp: ModuleType | None = None,
) -> Array: ...
@overload
def lazy_warn(  # type: ignore[no-any-explicit,no-any-decorated]  # numpydoc ignore=GL08
    x: Array,
    cond: bool | Array,
    message: Warning,
    category: Any = None,
    stacklevel: int = 1,
    source: Any | None = None,
    *,
    xp: ModuleType | None = None,
) -> Array: ...


def lazy_warn(  # type: ignore[no-any-explicit]  # numpydoc ignore=SA04,PR04
    x: Array,
    cond: bool | Array,
    message: str | Warning,
    category: Any = None,
    stacklevel: int = 1,
    source: Any | None = None,
    *,
    xp: ModuleType | None = None,
) -> Array:
    """
    Call `warnings.warn` if an eager check fails on a lazy array.

    This function works in the same way as `lazy_raise`; refer to it
    for the detailed explanation.

    You should replace::

      >>> def f(x, xp):
      ...     if xp.any(x < 0):
      ...         warnings.warn("Some points are negative", UserWarning, stacklevel=2)
      ...     return x + 1

    with::

      >>> def f(x, xp):
      ...     x = lazy_warn(x, xp.any(x < 0),
      ...                   "Some points are negative", UserWarning, stacklevel=2)
      ...     return x + 1

    Parameters
    ----------
    x : Array
        Any one Array, potentially lazy, that is used later on to produce the value
        returned by your function.
    cond : bool | Array
        Must be either a plain Python bool or a 0-dimensional Array with boolean
        dtype. If True, issue the warning. In either case, return x.
    message, category, stacklevel, source :
        Parameters to `warnings.warn`. `stacklevel` is automatically increased to
        compensate for the extra wrapper functions, so that on eager backends it
        has the same meaning as in a direct call to `warnings.warn`.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    Array
        `x`. If both `x` and `cond` are lazy arrays, the graph underlying `x` is
        altered to issue the warning if `cond` is True.

    See Also
    --------
    warnings.warn
    lazy_apply
    lazy_raise
    lazy_wait_on
    dask.graph_manipulation.wait_on

    Notes
    -----
    This function will raise if the :doc:`jax:transfer_guard` is active and `cond` is
    a JAX array on a non-CPU device
    (`jax-ml/jax#25995 <https://github.com/jax-ml/jax/issues/25998>`_).

    On Dask, the warning is typically going to appear on the log of the
    worker executing the function instead of on the client.
    """

    def _lazy_warn(x: Array, cond: Array) -> Array:  # numpydoc ignore=PR01,RT01
        """Eager helper of `lazy_warn` running inside the lazy graph."""
        if cond:
            # On the eager path three frames sit between `warnings.warn` and the
            # user's code: `_lazy_warn`, `_lazy_wait_on_impl` and `lazy_warn`.
            # Skip all of them so that `stacklevel=1` points at the caller of
            # `lazy_warn`, exactly like a direct `warnings.warn` call would.
            # NOTE(review): on lazy backends the helper is invoked by the
            # backend, so the reported location is presumably best-effort only.
            warnings.warn(message, category, stacklevel=stacklevel + 3, source=source)
        return x

    return _lazy_wait_on_impl(x, cond, _lazy_warn, xp=xp)
543+
544+
545+
def lazy_wait_on(
    x: Array, wait_on: object, *, xp: ModuleType | None = None
) -> Array:  # numpydoc ignore=SA04
    """
    Pause materialization of `x` until `wait_on` has been materialized.

    This is typically used to collect multiple calls to `lazy_raise` and/or
    `lazy_warn` from validation functions that would otherwise return None.
    If `wait_on` is not a lazy array, just return `x`.

    Read `lazy_raise` for a detailed explanation.

    If you use this validation pattern for eager backends::

        def validate(x, xp):
            if xp.any(x < 10):
                raise ValueError("Less than 10")
            if xp.any(x > 20):
                warnings.warn("More than 20", UserWarning)

        def f(x, xp):
            validate(x, xp=xp)
            return x + 1

    You should rewrite it as follows::

        def validate(x, xp):
            # Future that evaluates the checks. Contents are inconsequential.
            # It must be 0-dimensional to be accepted as `wait_on`.
            # Avoid zero-sized arrays, as they may be elided by the graph optimizer.
            future = xp.empty(())
            future = lazy_raise(future, xp.any(x < 10), ValueError("Less than 10"))
            future = lazy_warn(future, xp.any(x > 20), "More than 20", UserWarning)
            return future

        def f(x, xp):
            x = lazy_wait_on(x, validate(x, xp=xp), xp=xp)
            return x + 1

    Parameters
    ----------
    x : Array
        Any one Array, potentially lazy, that is used later on to produce the value
        returned by your function.
    wait_on : object
        Any object. If it's a lazy array, it must be 0-dimensional; the
        materialization of `x` is blocked until `wait_on` has been fully
        materialized.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    Array
        `x`. If both `x` and `wait_on` are lazy arrays, the graph
        underlying `x` is altered to wait until `wait_on` has been materialized.
        If `wait_on` raises, the exception is propagated to `x`.

    Raises
    ------
    ValueError
        If `wait_on` is a lazy array that is not 0-dimensional.

    See Also
    --------
    lazy_apply
    lazy_raise
    lazy_warn
    dask.graph_manipulation.wait_on
    """

    def _lazy_wait_on(x: Array, _: Array) -> Array:  # numpydoc ignore=PR01,RT01
        """Eager helper of `lazy_wait_on` running inside the lazy graph."""
        # The second argument is deliberately ignored: it is only passed so that
        # the helper depends on it.
        return x

    return _lazy_wait_on_impl(x, wait_on, _lazy_wait_on, xp=xp)
614+
615+
616+
def _lazy_wait_on_impl(  # numpydoc ignore=PR01,RT01
    x: Array,
    wait_on: object,
    eager_func: Callable[[Array, Array], Array],
    xp: ModuleType | None,
) -> Array:
    """
    Implementation of lazy_raise, lazy_warn, and lazy_wait_on.

    When `wait_on` is not a lazy array, `eager_func(x, wait_on)` is called right
    away and its result returned. Otherwise `wait_on` must be 0-dimensional
    (`ValueError` if not) and the call to `eager_func` is deferred: through
    `map_blocks` on Dask, through `lazy_apply` on every other lazy backend.
    """
    if not is_lazy_array(wait_on):
        # Nothing to defer: run the check (or no-op) immediately.
        return eager_func(x, wait_on)

    lazy_dep = cast(Array, wait_on)
    if lazy_dep.shape != ():
        msg = "cond/wait_on must be 0-dimensional"
        raise ValueError(msg)

    # Infer the namespace only when the caller did not provide one.
    namespace = array_namespace(x, lazy_dep) if xp is None else xp

    if not is_dask_namespace(namespace):
        return lazy_apply(
            eager_func, x, lazy_dep, shape=x.shape, dtype=x.dtype, xp=namespace
        )

    # lazy_apply would rechunk x
    return namespace.map_blocks(
        eager_func,
        x,
        lazy_dep,
        dtype=x.dtype,
        meta=x._meta,  # pylint: disable=protected-access
    )

0 commit comments

Comments
 (0)