Skip to content
Merged

Dev #1170

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion equinox/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,10 @@ def _none_to_zero(ct, x):
if x is None:
return None
else:
aval = jax.core.get_aval(x)
if hasattr(jax, "typeof"):
aval = jax.typeof(x)
else:
aval = jax.core.get_aval(x)
if hasattr(aval, "to_tangent_aval"):
# Earlier versions of JAX were internally inconsistent, and expected
# e.g. integer primals to have integer tangents from `custom_{jvp,vjp}`
Expand Down
8 changes: 4 additions & 4 deletions equinox/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from typing import Literal


EQX_ON_ERROR: Literal["raise", "breakpoint", "nan"] = os.environ.get( # pyright: ignore
EQX_ON_ERROR: Literal["raise", "breakpoint", "nan", "off"] = os.environ.get(
"EQX_ON_ERROR", "raise"
)
if EQX_ON_ERROR not in ("raise", "breakpoint", "nan"):
) # pyright: ignore
if EQX_ON_ERROR not in ("raise", "breakpoint", "nan", "off"):
raise ValueError(
"Unrecognised value for `EQX_ON_ERROR`. Valid values are `EQX_ON_ERROR=raise`, "
"`EQX_ON_ERROR=breakpoint`, and `EQX_ON_ERROR=nan`."
"`EQX_ON_ERROR=breakpoint`, `EQX_ON_ERROR=nan`, and `EQX_ON_ERROR=off`."
)
if EQX_ON_ERROR == "breakpoint":
warnings.warn(
Expand Down
8 changes: 6 additions & 2 deletions equinox/_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax.core
import jax.numpy as jnp
import numpy as np
import wadler_lindig as wl
from jaxtyping import Array, ArrayLike, Bool, Int

from ._doc_utils import doc_repr
Expand Down Expand Up @@ -171,6 +172,10 @@ def __ne__(self, other) -> Bool[Array, ""]: # pyright: ignore
"Can only compare equality between enumerations of the same type."
)

def __pdoc__(self, **kwargs):
del kwargs
return wl.TextDoc(repr(self))

def __repr__(self):
prefix = f"{self._enumeration.__module__}.{self._enumeration.__qualname__}"
message = self._enumeration[self]
Expand Down Expand Up @@ -204,8 +209,7 @@ def is_traced(self) -> bool:

if TYPE_CHECKING:
import enum
from typing import ClassVar
from typing_extensions import Self
from typing import ClassVar, Self

class _Sequence(type):
def __getitem__(cls, item) -> str: ...
Expand Down
16 changes: 11 additions & 5 deletions equinox/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ def handle_error():

elif on_error == "nan":
return lax.cond(pred, ft.partial(jtu.tree_map, _nan_like), lambda y: y, x)
elif on_error == "off":
return x
else:
assert False

Expand Down Expand Up @@ -190,7 +192,7 @@ def error_if(
pred: Bool[ArrayLike, "..."],
msg: str,
*,
on_error: Literal["default", "raise", "breakpoint", "nan"] = "default",
on_error: Literal["default", "raise", "breakpoint", "nan", "off"] = "default",
) -> PyTree:
"""Throws an error based on runtime values. Works even under JIT.

Expand Down Expand Up @@ -222,6 +224,8 @@ def error_if(
`EQX_ON_ERROR_BREAKPOINT_FRAMES` environment variable to a small integer,
which specifies how many frames upwards the debugger should capture. The
JAX bug is triggered when taking too many frames.
- `EQX_ON_ERROR=off` turns off all error checking. This is useful for removing
performance penalties incurred from use of `error_if`.

After changing an environment variable, the Python process must be restarted.

Expand Down Expand Up @@ -253,7 +257,7 @@ def branched_error_if(
index: Int[ArrayLike, "..."],
msgs: Sequence[str],
*,
on_error: Literal["default", "raise", "breakpoint", "nan"] = "default",
on_error: Literal["default", "raise", "breakpoint", "nan", "off"] = "default",
) -> PyTree:
"""As [`equinox.error_if`][], but will raise one of
several `msgs` depending on the value of `index`. If `index` is vmap'd, then the
Expand All @@ -275,11 +279,11 @@ def branched_error_if_impl(
index: Int[ArrayLike, "..."],
msgs: Sequence[str],
*,
on_error: Literal["default", "raise", "breakpoint", "nan"],
on_error: Literal["default", "raise", "breakpoint", "off", "nan"],
) -> PyTree:
if on_error == "default":
on_error = EQX_ON_ERROR
elif on_error not in ("raise", "breakpoint", "nan"):
elif on_error not in ("raise", "breakpoint", "off", "nan"):
raise RuntimeError("Unrecognised value for `on_error`.")
with jax.ensure_compile_time_eval():
# This carefully does not perform any JAX operations if `pred` and `index` are
Expand Down Expand Up @@ -310,6 +314,8 @@ def branched_error_if_impl(
"`on_error='nan'`)."
)
return jtu.tree_map(_nan_like, x)
elif on_error == "off":
return x
else:
assert False
# else defer error to runtime, when the index is known.
Expand Down Expand Up @@ -348,7 +354,7 @@ def assert_dce(
x: PyTree,
msg: str,
*,
on_error: Literal["default", "raise", "breakpoint", "nan"] = "default",
on_error: Literal["default", "raise", "breakpoint", "off", "nan"] = "default",
) -> PyTree:
"""Asserts that a particular array (or PyTree of arrays) is DCE'd."""

Expand Down
4 changes: 2 additions & 2 deletions equinox/_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class JaxRuntimeError(Exception):
def wait_for_tokens2():
try:
wait_for_tokens()
except JaxRuntimeError:
except (JaxRuntimeError, ValueError):
pass

atexit.unregister(wait_for_tokens)
Expand Down Expand Up @@ -269,7 +269,7 @@ def _call(jit_wrapper: _JitWrapper, is_lower, args, kwargs):
marker, jax.Array
):
marker.block_until_ready()
except JaxRuntimeError as e:
except (JaxRuntimeError, ValueError) as e:
# Catch Equinox's runtime errors, and re-raise them with actually useful
# information. (By default XlaRuntimeError produces a lot of terrifying
# but useless information.)
Expand Down
9 changes: 6 additions & 3 deletions equinox/_module/_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> _Return:
return self.func(*self.args, *args, **kwargs, **self.keywords)


class Static(Module):
_Value = TypeVar("_Value")


class Static(Module, Generic[_Value]):
"""Wraps a value into a `eqx.field(static=True)`.

This is useful to treat something as just static metadata with respect to a JAX
Expand All @@ -93,12 +96,12 @@ class Static(Module):
_leaves: list[Any] = field(static=True)
_treedef: PyTreeDef = field(static=True) # pyright: ignore

def __init__(self, value: Any):
def __init__(self, value: _Value):
# By flattening, we handle pytrees without `__eq__` methods.
# When comparing static metadata for equality, this means we never actually
# call `value.__eq__`.
self._leaves, self._treedef = jtu.tree_flatten(value)

@property
def value(self):
def value(self) -> _Value:
return jtu.tree_unflatten(self._treedef, self._leaves)
4 changes: 2 additions & 2 deletions equinox/debug/_dce.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ def f(x):


def inspect_dce(name: Hashable = None):
"""Used in conjunction with `equinox.debug.check_dce`; see documentation there.
"""Used in conjunction with `equinox.debug.store_dce`; see documentation there.

Must be called outside of any JIT'd function.

**Arguments:**

- `name`: Optional argument. Whatever name was used with `check_dce`.
- `name`: Optional argument. Whatever name was used with `store_dce`.

**Returns:**

Expand Down
24 changes: 15 additions & 9 deletions equinox/internal/_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def _is_array_like_internal(x):

def _zero_from_primal(p):
assert type(p) is not ad.UndefinedPrimal
aval = jax.core.get_aval(p)
if hasattr(jax, "typeof"):
aval = jax.typeof(p)
else:
aval = jax.core.get_aval(p)
if hasattr(aval, "to_tangent_aval"):
# JAX >=0.4.34
aval = aval.to_tangent_aval() # pyright: ignore
Expand Down Expand Up @@ -342,22 +345,25 @@ def _vprim_impl(*inputs, prim, __axis_size, __axis_name, __batch_axes, params):
return impl(*inputs)


if hasattr(jax.extend.core, "mapped_aval"):
_mapped_aval = jax.extend.core.mapped_aval # pyright: ignore[reportAttributeAccessIssue]
else:
_mapped_aval = jax.core.mapped_aval
if hasattr(jax.extend.core, "unmapped_aval"):
_unmapped_aval = jax.extend.core.unmapped_aval # pyright: ignore[reportAttributeAccessIssue,reportAssignmentType]
else:
_unmapped_aval = jax.core.unmapped_aval # pyright: ignore[reportAssignmentType]
if jax.__version_info__ >= (0, 5, 1):
_old_unmapped_aval = _unmapped_aval

def _unmapped_aval(axis_size, axis_name, axis, aval):
del axis_name
return jax.core.unmapped_aval(axis_size, axis, aval) # pyright: ignore[reportCallIssue]

else:
# signature (axis_size, axis_name, axis, aval)
_unmapped_aval = jax.core.unmapped_aval # pyright: ignore[reportAssignmentType]
return _old_unmapped_aval(axis_size, axis, aval) # pyright: ignore[reportCallIssue]


def _vprim_abstract_eval(*inputs, prim, __axis_size, __axis_name, __batch_axes, params):
assert len(inputs) == len(__batch_axes)
inputs = [
jax.core.mapped_aval(__axis_size, b, x) for x, b in zip(inputs, __batch_axes)
]
inputs = [_mapped_aval(__axis_size, b, x) for x, b in zip(inputs, __batch_axes)]
abstract_eval = _vprim_abstract_eval_registry[prim]
outs = abstract_eval(*inputs, **dict(params))
outs = [_unmapped_aval(__axis_size, __axis_name, 0, x) for x in outs]
Expand Down
12 changes: 0 additions & 12 deletions equinox/nn/_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,6 @@ def __call__(self, x, ctx, *, key):
training_model_again = eqx.nn.inference_mode(inference_model, value=False)
```

This function is essentially equivalent to:
```python
has_inference = lambda leaf: hasattr(leaf, "inference")

def where(pytree):
return tuple(x.inference
for x in jtu.tree_leaves(pytree, is_leaf=has_inference)
if has_inference(x))

inference_pytree = equinox.tree_at(where, pytree, replace_fn=lambda _: value)
```

**Arguments:**

- `pytree`: the PyTree to modify.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ name = "equinox"
readme = "README.md"
requires-python = ">=3.10"
urls = {repository = "https://github.com/patrick-kidger/equinox"}
version = "0.13.2"
version = "0.13.3"

[project.optional-dependencies]
dev = [
Expand Down
29 changes: 29 additions & 0 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,22 @@ def f(x, pred):
assert jnp.isnan(y)


def test_off_tracetime():
@jax.jit
def f(x):
return eqx.error_if(x, True, "hi", on_error="off") + 1

assert jnp.isclose(f(1.0), 2.0)


def test_off():
@jax.jit
def f(x, pred):
return eqx.error_if(x, pred, "hi", on_error="off")

assert jnp.isclose(f(1.0, True), 1.0)


def test_assert_dce():
@jax.jit
def f(x):
Expand Down Expand Up @@ -167,3 +183,16 @@ def _raises():
# assert e.__cause__ is None # varies by Python version and JAX version.
assert "egads" in str(e)
assert "EQX_ON_ERROR" not in str(e)


# https://github.com/patrick-kidger/equinox/issues/1156
def test_error_after_success():
@eqx.filter_jit
def foo(x):
return eqx.error_if(x, x > 0.0, "foo")

foo(jnp.array(-1.0))
try:
foo(jnp.array(1.0))
except Exception as e:
assert type(e) is eqx.EquinoxRuntimeError
15 changes: 14 additions & 1 deletion tests/test_finalise_jaxpr.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Iterator
from typing import cast

import equinox.internal as eqxi
Expand Down Expand Up @@ -131,10 +132,22 @@ def fn(x):
_assert_jaxpr_equal(finalised_vmap_jaxpr, finalised_finalised_vmap_jaxpr)


# Stolen from `jax.core.subjaxprs` as it is being deprecated.
def _subjaxprs(jaxpr: jax.extend.core.Jaxpr) -> Iterator[jax.extend.core.Jaxpr]:
for eqn in jaxpr.eqns:
for val in eqn.params.values():
vals = val if isinstance(val, tuple) else (val,)
for v in vals:
if isinstance(v, jax.extend.core.Jaxpr):
yield v
elif isinstance(v, jax.extend.core.ClosedJaxpr):
yield v.jaxpr


def _assert_no_unvmap(jaxpr: jax.extend.core.Jaxpr):
for eqn in jaxpr.eqns:
assert eqn.primitive not in (eqxi.unvmap_any_p, eqxi.unvmap_all_p)
for subjaxpr in jax.core.subjaxprs(jaxpr):
for subjaxpr in _subjaxprs(jaxpr):
_assert_no_unvmap(subjaxpr)


Expand Down