66import warnings
77import weakref
88from collections .abc import Callable , Hashable
9- from typing import Any , cast , Final , Literal , ParamSpec , TYPE_CHECKING , TypeVar
9+ from typing import (
10+ Any ,
11+ cast ,
12+ Final ,
13+ Literal ,
14+ NamedTuple ,
15+ ParamSpec ,
16+ TYPE_CHECKING ,
17+ TypeVar ,
18+ )
1019from typing_extensions import dataclass_transform
1120
1221import jax
@@ -44,6 +53,13 @@ def StrictConfig(
4453_flatten_sentinel = object ()
4554
4655
56+ class _ModuleInfo (NamedTuple ):
57+ """Holds cached information about a Module subclass."""
58+
59+ names_tuple : tuple [str , ...]
60+ names_set : frozenset [str ]
61+
62+
4763# Used to provide a pretty repr when doing `jtu.tree_structure(SomeModule(...))`.
4864@dataclasses .dataclass (slots = True )
4965class _FlattenedData :
@@ -401,7 +417,11 @@ def __new__(
401417 cls .__init__ .__doc__ = init_doc # pyright: ignore[reportPossiblyUnboundVariable]
402418
403419 # Cache the field names for later use.
404- _module_info [cls ] = frozenset (f .name for f in fields )
420+ names = tuple (f .name for f in fields )
421+ _module_info [cls ] = _ModuleInfo (
422+ names_tuple = names ,
423+ names_set = frozenset (names ),
424+ )
405425
406426 flattener = _ModuleFlattener (fields ) # pyright: ignore[reportArgumentType]
407427 jtu .register_pytree_with_keys (
@@ -617,9 +637,7 @@ def __repr__(self) -> str:
617637
618638 def __hash__ (self ) -> int :
619639 return hash (
620- tuple (
621- (f .name , getattr (self , f .name )) for f in dataclasses .fields (type (self ))
622- )
640+ tuple ((k , getattr (self , k )) for k in _module_info [type (self )].names_tuple )
623641 )
624642
625643 def __eq__ (self , other : object , / ) -> bool | np .bool_ | Bool [Array , "" ]: # pyright: ignore
@@ -629,7 +647,8 @@ def __eq__(self, other: object, /) -> bool | np.bool_ | Bool[Array, ""]: # pyri
629647
630648 def __setattr__ (self , name : str , value : Any ) -> None :
631649 if self in _currently_initialising and (
632- name in _module_info [type (self )] or name in wrapper_field_names
650+ name in _module_info [type (self )].names_set
651+ or name in wrapper_field_names
633652 ):
634653 _error_method_assignment (self , value )
635654 _warn_jax_transformed_function (type (self ), value )
0 commit comments