Skip to content

Commit b75f3dd

Browse files
nstarmanpatrick-kidger
authored andcommitted
feat: converter = None as the sentinel (#996)
* feat: converter = None as the sentinel 1. Easier for users to access instead of a private sentinel. 2. simplifies later performance-related logic changes. 3. Broadens support to allow `dataclasses.field(metadata=dict(converter=None))`, not just `eqx.field`. Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com> * refactor: only set annotations if datclass init Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com> --------- Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
1 parent f8ca345 commit b75f3dd

File tree

1 file changed

+19
-23
lines changed

1 file changed

+19
-23
lines changed

equinox/_module.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,9 @@ def static_field(**kwargs):
4949
return field(**kwargs, static=True)
5050

5151

52-
_converter_sentinel: Any = doc_repr(object(), "lambda x: x")
53-
54-
5552
def field(
5653
*,
57-
converter: Callable[[Any], Any] = _converter_sentinel,
54+
converter: Callable[[Any], Any] | None = None,
5855
static: bool = False,
5956
**kwargs,
6057
):
@@ -67,6 +64,7 @@ def field(
6764
`bool`/`int`/`float`/`complex` values to JAX arrays. This is ran after the
6865
`__init__` method (i.e. when using a user-provided `__init__`), and after
6966
`__post_init__` (i.e. when using the default dataclass initialisation).
67+
If `converter` is `None`, then no converter is registered.
7068
- `static`: whether the field should not interact with any JAX transform at all (by
7169
making it part of the PyTree structure rather than a leaf).
7270
- `**kwargs`: All other keyword arguments are passed on to `dataclass.field`.
@@ -105,7 +103,7 @@ class MyModule(eqx.Module):
105103
except KeyError:
106104
metadata = {}
107105
if "converter" in metadata:
108-
raise ValueError("Cannot use metadata with `static` already set.")
106+
raise ValueError("Cannot use metadata with `converter` already set.")
109107
if "static" in metadata:
110108
raise ValueError("Cannot use metadata with `static` already set.")
111109
# We don't just use `lambda x: x` as the default, so that this works:
@@ -123,7 +121,7 @@ class MyModule(eqx.Module):
123121
# Oddities like the above are to be discouraged, of course, but in particular
124122
# `field(init=False)` was sometimes used to denote an abstract field (prior to the
125123
# introduction of `AbstractVar`), so we do want to support this.
126-
if converter is not _converter_sentinel:
124+
if converter is not None:
127125
metadata["converter"] = converter
128126
if static:
129127
metadata["static"] = True
@@ -303,14 +301,14 @@ def __new__(
303301
# checkers.
304302
# Note that mutating the `__init__.__annotations__` is okay, as it was created
305303
# by the dataclass decorator on the previous line, so nothing else owns it.
306-
for f in dataclasses.fields(cls):
307-
if f.name not in cls.__init__.__annotations__:
308-
continue # Odd behaviour, so skip.
309-
try:
310-
converter = f.metadata["converter"]
311-
except KeyError:
312-
pass
313-
else:
304+
if has_dataclass_init:
305+
for f in dataclasses.fields(cls):
306+
if not f.init:
307+
continue
308+
309+
if (converter := f.metadata.get("converter")) is None:
310+
continue # No converter, so skip.
311+
314312
try:
315313
signature = inspect.signature(converter)
316314
except ValueError:
@@ -575,18 +573,16 @@ def __call__(cls, *args, **kwargs):
575573
assert not _is_abstract(cls)
576574
# [Step 3] Run converters
577575
for field in dataclasses.fields(cls):
576+
if (converter := field.metadata.get("converter")) is None:
577+
continue
578+
578579
try:
579-
converter = field.metadata["converter"]
580-
except KeyError:
580+
value = getattr(self, field.name)
581+
except AttributeError:
582+
# Let the all-fields-are-filled check handle the error.
581583
pass
582584
else:
583-
try:
584-
value = getattr(self, field.name)
585-
except AttributeError:
586-
# Let the all-fields-are-filled check handle the error.
587-
pass
588-
else:
589-
setattr(self, field.name, converter(value))
585+
setattr(self, field.name, converter(value))
590586
# [Step 4] Check that all fields are occupied.
591587
missing_names = {
592588
field.name

0 commit comments

Comments
 (0)