diff --git a/equinox/_module.py b/equinox/_module.py index 69a1bc79..497ca7a8 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -48,12 +48,9 @@ def static_field(**kwargs): return field(**kwargs, static=True) -_converter_sentinel: Any = doc_repr(object(), "lambda x: x") - - def field( *, - converter: Callable[[Any], Any] = _converter_sentinel, + converter: Callable[[Any], Any] | None = None, static: bool = False, **kwargs, ): @@ -66,6 +63,7 @@ def field( `bool`/`int`/`float`/`complex` values to JAX arrays. This is ran after the `__init__` method (i.e. when using a user-provided `__init__`), and after `__post_init__` (i.e. when using the default dataclass initialisation). + If `converter` is `None`, then no converter is registered. - `static`: whether the field should not interact with any JAX transform at all (by making it part of the PyTree structure rather than a leaf). - `**kwargs`: All other keyword arguments are passed on to `dataclass.field`. @@ -104,7 +102,7 @@ class MyModule(eqx.Module): except KeyError: metadata = {} if "converter" in metadata: - raise ValueError("Cannot use metadata with `static` already set.") + raise ValueError("Cannot use metadata with `converter` already set.") if "static" in metadata: raise ValueError("Cannot use metadata with `static` already set.") # We don't just use `lambda x: x` as the default, so that this works: @@ -122,7 +120,7 @@ class MyModule(eqx.Module): # Oddities like the above are to be discouraged, of course, but in particular # `field(init=False)` was sometimes used to denote an abstract field (prior to the # introduction of `AbstractVar`), so we do want to support this. - if converter is not _converter_sentinel: + if converter is not None: metadata["converter"] = converter if static: metadata["static"] = True @@ -302,14 +300,14 @@ def __new__( # checkers. # Note that mutating the `__init__.__annotations__` is okay, as it was created # by the dataclass decorator on the previous line, so nothing else owns it. - for f in dataclasses.fields(cls): - if f.name not in cls.__init__.__annotations__: - continue # Odd behaviour, so skip. - try: - converter = f.metadata["converter"] - except KeyError: - pass - else: + if has_dataclass_init: + for f in dataclasses.fields(cls): + if not f.init: + continue + + if (converter := f.metadata.get("converter")) is None: + continue # No converter, so skip. + try: signature = inspect.signature(converter) except ValueError: @@ -574,18 +572,16 @@ def __call__(cls, *args, **kwargs): assert not _is_abstract(cls) # [Step 3] Run converters for field in dataclasses.fields(cls): + if (converter := field.metadata.get("converter")) is None: + continue + try: - converter = field.metadata["converter"] - except KeyError: + value = getattr(self, field.name) + except AttributeError: + # Let the all-fields-are-filled check handle the error. pass else: - try: - value = getattr(self, field.name) - except AttributeError: - # Let the all-fields-are-filled check handle the error. - pass - else: - setattr(self, field.name, converter(value)) + setattr(self, field.name, converter(value)) # [Step 4] Check that all fields are occupied. missing_names = { field.name