@@ -48,12 +48,9 @@ def static_field(**kwargs):
4848 return field (** kwargs , static = True )
4949
5050
51- _converter_sentinel : Any = doc_repr (object (), "lambda x: x" )
52-
53-
5451def field (
5552 * ,
56- converter : Callable [[Any ], Any ] = _converter_sentinel ,
53+ converter : Callable [[Any ], Any ] | None = None ,
5754 static : bool = False ,
5855 ** kwargs ,
5956):
@@ -66,6 +63,7 @@ def field(
6663 `bool`/`int`/`float`/`complex` values to JAX arrays. This is ran after the
6764 `__init__` method (i.e. when using a user-provided `__init__`), and after
6865 `__post_init__` (i.e. when using the default dataclass initialisation).
66+ If `converter` is `None`, then no converter is registered.
6967 - `static`: whether the field should not interact with any JAX transform at all (by
7068 making it part of the PyTree structure rather than a leaf).
7169 - `**kwargs`: All other keyword arguments are passed on to `dataclass.field`.
@@ -104,7 +102,7 @@ class MyModule(eqx.Module):
104102 except KeyError :
105103 metadata = {}
106104 if "converter" in metadata :
107- raise ValueError ("Cannot use metadata with `static ` already set." )
105+ raise ValueError ("Cannot use metadata with `converter ` already set." )
108106 if "static" in metadata :
109107 raise ValueError ("Cannot use metadata with `static` already set." )
110108 # We don't just use `lambda x: x` as the default, so that this works:
@@ -122,7 +120,7 @@ class MyModule(eqx.Module):
122120 # Oddities like the above are to be discouraged, of course, but in particular
123121 # `field(init=False)` was sometimes used to denote an abstract field (prior to the
124122 # introduction of `AbstractVar`), so we do want to support this.
125- if converter is not _converter_sentinel :
123+ if converter is not None :
126124 metadata ["converter" ] = converter
127125 if static :
128126 metadata ["static" ] = True
@@ -302,14 +300,14 @@ def __new__(
302300 # checkers.
303301 # Note that mutating the `__init__.__annotations__` is okay, as it was created
304302 # by the dataclass decorator on the previous line, so nothing else owns it.
305- for f in dataclasses . fields ( cls ) :
306- if f . name not in cls . __init__ . __annotations__ :
307- continue # Odd behaviour, so skip.
308- try :
309- converter = f . metadata [ "converter" ]
310- except KeyError :
311- pass
312- else :
303+ if has_dataclass_init :
304+ for f in dataclasses . fields ( cls ) :
305+ if not f . init :
306+ continue
307+
308+ if ( converter := f . metadata . get ( "converter" )) is None :
309+ continue # No converter, so skip.
310+
313311 try :
314312 signature = inspect .signature (converter )
315313 except ValueError :
@@ -574,18 +572,16 @@ def __call__(cls, *args, **kwargs):
574572 assert not _is_abstract (cls )
575573 # [Step 3] Run converters
576574 for field in dataclasses .fields (cls ):
575+ if (converter := field .metadata .get ("converter" )) is None :
576+ continue
577+
577578 try :
578- converter = field .metadata ["converter" ]
579- except KeyError :
579+ value = getattr (self , field .name )
580+ except AttributeError :
581+ # Let the all-fields-are-filled check handle the error.
580582 pass
581583 else :
582- try :
583- value = getattr (self , field .name )
584- except AttributeError :
585- # Let the all-fields-are-filled check handle the error.
586- pass
587- else :
588- setattr (self , field .name , converter (value ))
584+ setattr (self , field .name , converter (value ))
589585 # [Step 4] Check that all fields are occupied.
590586 missing_names = {
591587 field .name
0 commit comments