@@ -124,6 +124,22 @@ def new_get_state(self, *args, **kwargs):
124124 return new_get_state
125125
126126
127+ def skip_link_metric (fn : _FnT ) -> _FnT :
128+ """Decorator to disable the `_link_metric_to_state` magic.
129+
130+ This is important for wrapper metrics like `kd.metrics.TreeReduce` where
131+ the state.parent should remain the wrapped metric.
132+
133+ Args:
134+ fn: The `get_state` or `empty` function to skip.
135+
136+ Returns:
137+ The function with the `_has_link_metric` flag set to True.
138+ """
139+ fn ._has_link_metric = True # pylint: disable=protected-access
140+ return fn
141+
142+
127143@flax .struct .dataclass
128144class TreeState (base_state .State ):
129145 """Holds a pytree of metric states."""
@@ -227,7 +243,7 @@ class TreeMap(_TreeMetric):
227243 class State (TreeState ):
228244 pass
229245
230- def get_state (self , ** kwargs ):
246+ def get_state (self , ** kwargs ) -> TreeMap . State :
231247 state_tree = self ._get_tree_state (** kwargs )
232248 return self .State (state_tree )
233249
@@ -239,6 +255,7 @@ class TreeReduce(_TreeMetric):
239255 The given metric defines the aggregation method.
240256 """
241257
258+ @skip_link_metric
242259 def get_state (self , ** kwargs ) -> base_state .State :
243260 state_tree = self ._get_tree_state (** kwargs )
244261 reduced_state = jax .tree .reduce (
@@ -249,6 +266,10 @@ def get_state(self, **kwargs) -> base_state.State:
249266 )
250267 return reduced_state
251268
269+ @skip_link_metric
270+ def empty (self ) -> base_state .State :
271+ return self .metric .empty ()
272+
252273
253274def _tree_map_with_kwargs (fun , ** kwargs ):
254275 """Same as jax.tree.map but taking and passing trees to fun as kwargs."""
0 commit comments