Skip to content

Commit a50f9dd

Browse files
QwlouseThe kauldron Authors
authored andcommitted
WIP: fix the state.parent for TreeReduce
PiperOrigin-RevId: 672502390
1 parent 3bcf564 commit a50f9dd

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

kauldron/metrics/base.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,22 @@ def new_get_state(self, *args, **kwargs):
124124
return new_get_state
125125

126126

127+
def skip_link_metric(fn: _FnT) -> _FnT:
128+
"""Decorator to disable the `_link_metric_to_state` magic.
129+
130+
This is important for wrapper metrics like `kd.metrics.TreeReduce` where
131+
the state.parent should remain the wrapped metric.
132+
133+
Args:
134+
fn: The `get_state` or `empty` function to skip.
135+
136+
Returns:
137+
The function with the `_has_link_metric` flag set to True.
138+
"""
139+
fn._has_link_metric = True # pylint: disable=protected-access
140+
return fn
141+
142+
127143
@flax.struct.dataclass
128144
class TreeState(base_state.State):
129145
"""Holds a pytree of metric states."""
@@ -227,7 +243,7 @@ class TreeMap(_TreeMetric):
227243
class State(TreeState):
228244
pass
229245

230-
def get_state(self, **kwargs):
246+
def get_state(self, **kwargs) -> TreeMap.State:
231247
state_tree = self._get_tree_state(**kwargs)
232248
return self.State(state_tree)
233249

@@ -239,6 +255,7 @@ class TreeReduce(_TreeMetric):
239255
The given metric defines the aggregation method.
240256
"""
241257

258+
@skip_link_metric
242259
def get_state(self, **kwargs) -> base_state.State:
243260
state_tree = self._get_tree_state(**kwargs)
244261
reduced_state = jax.tree.reduce(
@@ -249,6 +266,10 @@ def get_state(self, **kwargs) -> base_state.State:
249266
)
250267
return reduced_state
251268

269+
@skip_link_metric
270+
def empty(self) -> base_state.State:
271+
return self.metric.empty()
272+
252273

253274
def _tree_map_with_kwargs(fun, **kwargs):
254275
"""Same as jax.tree.map but taking and passing trees to fun as kwargs."""

0 commit comments

Comments
 (0)