Skip to content

Commit 4c1a7a3

Browse files
lukebaumannRoshaniN
authored andcommitted
"Adding block_until_ready to the _step_log call in _run_step to catch DATA_LOSS errors in the main thread instead of in the log emitting thread"
1 parent 066ddce commit 4c1a7a3

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

axlearn/common/trainer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,8 +1101,13 @@ def _run_step(
11011101
if self.step % 100 == 0 or 0 <= self.step <= 5:
11021102
self._step_log(
11031103
"loss=%s aux=%s",
1104-
outputs["loss"],
1105-
jax.tree.map(lambda x: x.item() if x.ndim == 0 else f"T{x.shape}", outputs["aux"]),
1104+
jax.block_until_ready(outputs["loss"]),
1105+
jax.block_until_ready(
1106+
jax.tree.map(
1107+
lambda x: x.item() if x.ndim == 0 else f"T{x.shape}",
1108+
outputs["aux"]
1109+
),
1110+
),
11061111
)
11071112

11081113
self.summary_writer(self.step, {"loss": outputs["loss"], **outputs["summaries"]})

0 commit comments

Comments
 (0)