Skip to content

Commit c4741b0

Browse files
nstarmanpatrick-kidger
authored andcommitted
feat: converter = None as the sentinel (#996)
* feat: converter = None as the sentinel 1. Easier for users to access instead of a private sentinel. 2. simplifies later performance-related logic changes. 3. Broadens support to allow `dataclasses.field(metadata=dict(converter=None))`, not just `eqx.field`. Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com> * refactor: only set annotations if datclass init Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com> --------- Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
1 parent d3d4eb9 commit c4741b0

File tree

1 file changed

+19
-23
lines changed

1 file changed

+19
-23
lines changed

equinox/_module.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,9 @@ def static_field(**kwargs):
4848
return field(**kwargs, static=True)
4949

5050

51-
_converter_sentinel: Any = doc_repr(object(), "lambda x: x")
52-
53-
5451
def field(
5552
*,
56-
converter: Callable[[Any], Any] = _converter_sentinel,
53+
converter: Callable[[Any], Any] | None = None,
5754
static: bool = False,
5855
**kwargs,
5956
):
@@ -66,6 +63,7 @@ def field(
6663
`bool`/`int`/`float`/`complex` values to JAX arrays. This is ran after the
6764
`__init__` method (i.e. when using a user-provided `__init__`), and after
6865
`__post_init__` (i.e. when using the default dataclass initialisation).
66+
If `converter` is `None`, then no converter is registered.
6967
- `static`: whether the field should not interact with any JAX transform at all (by
7068
making it part of the PyTree structure rather than a leaf).
7169
- `**kwargs`: All other keyword arguments are passed on to `dataclass.field`.
@@ -104,7 +102,7 @@ class MyModule(eqx.Module):
104102
except KeyError:
105103
metadata = {}
106104
if "converter" in metadata:
107-
raise ValueError("Cannot use metadata with `static` already set.")
105+
raise ValueError("Cannot use metadata with `converter` already set.")
108106
if "static" in metadata:
109107
raise ValueError("Cannot use metadata with `static` already set.")
110108
# We don't just use `lambda x: x` as the default, so that this works:
@@ -122,7 +120,7 @@ class MyModule(eqx.Module):
122120
# Oddities like the above are to be discouraged, of course, but in particular
123121
# `field(init=False)` was sometimes used to denote an abstract field (prior to the
124122
# introduction of `AbstractVar`), so we do want to support this.
125-
if converter is not _converter_sentinel:
123+
if converter is not None:
126124
metadata["converter"] = converter
127125
if static:
128126
metadata["static"] = True
@@ -302,14 +300,14 @@ def __new__(
302300
# checkers.
303301
# Note that mutating the `__init__.__annotations__` is okay, as it was created
304302
# by the dataclass decorator on the previous line, so nothing else owns it.
305-
for f in dataclasses.fields(cls):
306-
if f.name not in cls.__init__.__annotations__:
307-
continue # Odd behaviour, so skip.
308-
try:
309-
converter = f.metadata["converter"]
310-
except KeyError:
311-
pass
312-
else:
303+
if has_dataclass_init:
304+
for f in dataclasses.fields(cls):
305+
if not f.init:
306+
continue
307+
308+
if (converter := f.metadata.get("converter")) is None:
309+
continue # No converter, so skip.
310+
313311
try:
314312
signature = inspect.signature(converter)
315313
except ValueError:
@@ -574,18 +572,16 @@ def __call__(cls, *args, **kwargs):
574572
assert not _is_abstract(cls)
575573
# [Step 3] Run converters
576574
for field in dataclasses.fields(cls):
575+
if (converter := field.metadata.get("converter")) is None:
576+
continue
577+
577578
try:
578-
converter = field.metadata["converter"]
579-
except KeyError:
579+
value = getattr(self, field.name)
580+
except AttributeError:
581+
# Let the all-fields-are-filled check handle the error.
580582
pass
581583
else:
582-
try:
583-
value = getattr(self, field.name)
584-
except AttributeError:
585-
# Let the all-fields-are-filled check handle the error.
586-
pass
587-
else:
588-
setattr(self, field.name, converter(value))
584+
setattr(self, field.name, converter(value))
589585
# [Step 4] Check that all fields are occupied.
590586
missing_names = {
591587
field.name

0 commit comments

Comments
 (0)