@@ -49,12 +49,9 @@ def static_field(**kwargs):
4949 return field (** kwargs , static = True )
5050
5151
52- _converter_sentinel : Any = doc_repr (object (), "lambda x: x" )
53-
54-
5552def field (
5653 * ,
57- converter : Callable [[Any ], Any ] = _converter_sentinel ,
54+ converter : Callable [[Any ], Any ] | None = None ,
5855 static : bool = False ,
5956 ** kwargs ,
6057):
@@ -67,6 +64,7 @@ def field(
6764 `bool`/`int`/`float`/`complex` values to JAX arrays. This is ran after the
6865 `__init__` method (i.e. when using a user-provided `__init__`), and after
6966 `__post_init__` (i.e. when using the default dataclass initialisation).
67+ If `converter` is `None`, then no converter is registered.
7068 - `static`: whether the field should not interact with any JAX transform at all (by
7169 making it part of the PyTree structure rather than a leaf).
7270 - `**kwargs`: All other keyword arguments are passed on to `dataclass.field`.
@@ -105,7 +103,7 @@ class MyModule(eqx.Module):
105103 except KeyError :
106104 metadata = {}
107105 if "converter" in metadata :
108- raise ValueError ("Cannot use metadata with `static ` already set." )
106+ raise ValueError ("Cannot use metadata with `converter ` already set." )
109107 if "static" in metadata :
110108 raise ValueError ("Cannot use metadata with `static` already set." )
111109 # We don't just use `lambda x: x` as the default, so that this works:
@@ -123,7 +121,7 @@ class MyModule(eqx.Module):
123121 # Oddities like the above are to be discouraged, of course, but in particular
124122 # `field(init=False)` was sometimes used to denote an abstract field (prior to the
125123 # introduction of `AbstractVar`), so we do want to support this.
126- if converter is not _converter_sentinel :
124+ if converter is not None :
127125 metadata ["converter" ] = converter
128126 if static :
129127 metadata ["static" ] = True
@@ -303,14 +301,14 @@ def __new__(
303301 # checkers.
304302 # Note that mutating the `__init__.__annotations__` is okay, as it was created
305303 # by the dataclass decorator on the previous line, so nothing else owns it.
306- for f in dataclasses . fields ( cls ) :
307- if f . name not in cls . __init__ . __annotations__ :
308- continue # Odd behaviour, so skip.
309- try :
310- converter = f . metadata [ "converter" ]
311- except KeyError :
312- pass
313- else :
304+ if has_dataclass_init :
305+ for f in dataclasses . fields ( cls ) :
306+ if not f . init :
307+ continue
308+
309+ if ( converter := f . metadata . get ( "converter" )) is None :
310+ continue # No converter, so skip.
311+
314312 try :
315313 signature = inspect .signature (converter )
316314 except ValueError :
@@ -575,18 +573,16 @@ def __call__(cls, *args, **kwargs):
575573 assert not _is_abstract (cls )
576574 # [Step 3] Run converters
577575 for field in dataclasses .fields (cls ):
576+ if (converter := field .metadata .get ("converter" )) is None :
577+ continue
578+
578579 try :
579- converter = field .metadata ["converter" ]
580- except KeyError :
580+ value = getattr (self , field .name )
581+ except AttributeError :
582+ # Let the all-fields-are-filled check handle the error.
581583 pass
582584 else :
583- try :
584- value = getattr (self , field .name )
585- except AttributeError :
586- # Let the all-fields-are-filled check handle the error.
587- pass
588- else :
589- setattr (self , field .name , converter (value ))
585+ setattr (self , field .name , converter (value ))
590586 # [Step 4] Check that all fields are occupied.
591587 missing_names = {
592588 field .name
0 commit comments