We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
_step_log
_run_step
1 parent 066ddce commit 4c1a7a3Copy full SHA for 4c1a7a3
axlearn/common/trainer.py
@@ -1101,8 +1101,13 @@ def _run_step(
1101
if self.step % 100 == 0 or 0 <= self.step <= 5:
1102
self._step_log(
1103
"loss=%s aux=%s",
1104
- outputs["loss"],
1105
- jax.tree.map(lambda x: x.item() if x.ndim == 0 else f"T{x.shape}", outputs["aux"]),
+ jax.block_until_ready(outputs["loss"]),
+ jax.block_until_ready(
1106
+ jax.tree.map(
1107
+ lambda x: x.item() if x.ndim == 0 else f"T{x.shape}",
1108
+ outputs["aux"]
1109
+ ),
1110
1111
)
1112
1113
self.summary_writer(self.step, {"loss": outputs["loss"], **outputs["summaries"]})
0 commit comments