Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 19 additions & 23 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,9 @@ def static_field(**kwargs):
return field(**kwargs, static=True)


_converter_sentinel: Any = doc_repr(object(), "lambda x: x")


def field(
*,
converter: Callable[[Any], Any] = _converter_sentinel,
converter: Callable[[Any], Any] | None = None,
static: bool = False,
**kwargs,
):
Expand All @@ -66,6 +63,7 @@ def field(
`bool`/`int`/`float`/`complex` values to JAX arrays. This is ran after the
`__init__` method (i.e. when using a user-provided `__init__`), and after
`__post_init__` (i.e. when using the default dataclass initialisation).
If `converter` is `None`, then no converter is registered.
- `static`: whether the field should not interact with any JAX transform at all (by
making it part of the PyTree structure rather than a leaf).
- `**kwargs`: All other keyword arguments are passed on to `dataclass.field`.
Expand Down Expand Up @@ -104,7 +102,7 @@ class MyModule(eqx.Module):
except KeyError:
metadata = {}
if "converter" in metadata:
raise ValueError("Cannot use metadata with `static` already set.")
raise ValueError("Cannot use metadata with `converter` already set.")
if "static" in metadata:
raise ValueError("Cannot use metadata with `static` already set.")
# We don't just use `lambda x: x` as the default, so that this works:
Expand All @@ -122,7 +120,7 @@ class MyModule(eqx.Module):
# Oddities like the above are to be discouraged, of course, but in particular
# `field(init=False)` was sometimes used to denote an abstract field (prior to the
# introduction of `AbstractVar`), so we do want to support this.
if converter is not _converter_sentinel:
if converter is not None:
metadata["converter"] = converter
if static:
metadata["static"] = True
Expand Down Expand Up @@ -302,14 +300,14 @@ def __new__(
# checkers.
# Note that mutating the `__init__.__annotations__` is okay, as it was created
# by the dataclass decorator on the previous line, so nothing else owns it.
for f in dataclasses.fields(cls):
if f.name not in cls.__init__.__annotations__:
continue # Odd behaviour, so skip.
try:
converter = f.metadata["converter"]
except KeyError:
pass
else:
if has_dataclass_init:
for f in dataclasses.fields(cls):
if not f.init:
continue

if (converter := f.metadata.get("converter")) is None:
continue # No converter, so skip.

try:
signature = inspect.signature(converter)
except ValueError:
Expand Down Expand Up @@ -574,18 +572,16 @@ def __call__(cls, *args, **kwargs):
assert not _is_abstract(cls)
# [Step 3] Run converters
for field in dataclasses.fields(cls):
if (converter := field.metadata.get("converter")) is None:
continue

try:
converter = field.metadata["converter"]
except KeyError:
value = getattr(self, field.name)
except AttributeError:
# Let the all-fields-are-filled check handle the error.
pass
else:
try:
value = getattr(self, field.name)
except AttributeError:
# Let the all-fields-are-filled check handle the error.
pass
else:
setattr(self, field.name, converter(value))
setattr(self, field.name, converter(value))
# [Step 4] Check that all fields are occupied.
missing_names = {
field.name
Expand Down