Skip to content

Commit 20414a7

Browse files
nstarmanpatrick-kidger
authored andcommitted
perf: opt hash
And use in the hash Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent 9ce2a0c commit 20414a7

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

equinox/_module/_module.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,16 @@
66
import warnings
77
import weakref
88
from collections.abc import Callable, Hashable
9-
from typing import Any, cast, Final, Literal, ParamSpec, TYPE_CHECKING, TypeVar
9+
from typing import (
10+
Any,
11+
cast,
12+
Final,
13+
Literal,
14+
NamedTuple,
15+
ParamSpec,
16+
TYPE_CHECKING,
17+
TypeVar,
18+
)
1019
from typing_extensions import dataclass_transform
1120

1221
import jax
@@ -44,6 +53,13 @@ def StrictConfig(
4453
_flatten_sentinel = object()
4554

4655

56+
class _ModuleInfo(NamedTuple):
57+
"""Holds cached information about a Module subclass."""
58+
59+
names_tuple: tuple[str, ...]
60+
names_set: frozenset[str]
61+
62+
4763
# Used to provide a pretty repr when doing `jtu.tree_structure(SomeModule(...))`.
4864
@dataclasses.dataclass(slots=True)
4965
class _FlattenedData:
@@ -401,7 +417,11 @@ def __new__(
401417
cls.__init__.__doc__ = init_doc # pyright: ignore[reportPossiblyUnboundVariable]
402418

403419
# Cache the field names for later use.
404-
_module_info[cls] = frozenset(f.name for f in fields)
420+
names = tuple(f.name for f in fields)
421+
_module_info[cls] = _ModuleInfo(
422+
names_tuple=names,
423+
names_set=frozenset(names),
424+
)
405425

406426
flattener = _ModuleFlattener(fields) # pyright: ignore[reportArgumentType]
407427
jtu.register_pytree_with_keys(
@@ -617,9 +637,7 @@ def __repr__(self) -> str:
617637

618638
def __hash__(self) -> int:
619639
return hash(
620-
tuple(
621-
(f.name, getattr(self, f.name)) for f in dataclasses.fields(type(self))
622-
)
640+
tuple((k, getattr(self, k)) for k in _module_info[type(self)].names_tuple)
623641
)
624642

625643
def __eq__(self, other: object, /) -> bool | np.bool_ | Bool[Array, ""]: # pyright: ignore
@@ -629,7 +647,8 @@ def __eq__(self, other: object, /) -> bool | np.bool_ | Bool[Array, ""]: # pyri
629647

630648
def __setattr__(self, name: str, value: Any) -> None:
631649
if self in _currently_initialising and (
632-
name in _module_info[type(self)] or name in wrapper_field_names
650+
name in _module_info[type(self)].names_set
651+
or name in wrapper_field_names
633652
):
634653
_error_method_assignment(self, value)
635654
_warn_jax_transformed_function(type(self), value)

0 commit comments

Comments
 (0)