Skip to content

Commit 8ca8fa2

Browse files
committed
"Adding block_until_ready to the _step_log call in _run_step to catch DATA_LOSS errors in the main thread instead of in the log emitting thread"
1 parent 7163eeb commit 8ca8fa2

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

axlearn/common/trainer.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1111,8 +1111,13 @@ def _run_step(
11111111
if self.step % 100 == 0 or 0 <= self.step <= 5:
11121112
self._step_log(
11131113
"loss=%s aux=%s",
1114-
outputs["loss"],
1115-
jax.tree.map(lambda x: x.item() if x.ndim == 0 else f"T{x.shape}", outputs["aux"]),
1114+
jax.block_until_ready(outputs["loss"]),
1115+
jax.block_until_ready(
1116+
jax.tree.map(
1117+
lambda x: x.item() if x.ndim == 0 else f"T{x.shape}",
1118+
outputs["aux"]
1119+
),
1120+
),
11161121
)
11171122

11181123
with self._record_event(

0 commit comments

Comments (0)