Skip to content

Commit b4b6b73

Browse files
Make Static be Generic
1 parent 0e8d06f commit b4b6b73

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

equinox/_module/_prebuilt.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> _Return:
8282
return self.func(*self.args, *args, **kwargs, **self.keywords)
8383

8484

85-
class Static(Module):
85+
_Value = TypeVar("_Value")
86+
87+
88+
class Static(Module, Generic[_Value]):
8689
"""Wraps a value into a `eqx.field(static=True)`.
8790
8891
This is useful to treat something as just static metadata with respect to a JAX
@@ -93,12 +96,12 @@ class Static(Module):
9396
_leaves: list[Any] = field(static=True)
9497
_treedef: PyTreeDef = field(static=True) # pyright: ignore
9598

96-
def __init__(self, value: Any):
99+
def __init__(self, value: _Value):
97100
# By flattening, we handle pytrees without `__eq__` methods.
98101
# When comparing static metadata for equality, this means we never actually
99102
# call `value.__eq__`.
100103
self._leaves, self._treedef = jtu.tree_flatten(value)
101104

102105
@property
103-
def value(self):
106+
def value(self) -> _Value:
104107
return jtu.tree_unflatten(self._treedef, self._leaves)

0 commit comments

Comments
 (0)