We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
_step_log
_run_step
1 parent 7163eeb commit 8ca8fa2Copy full SHA for 8ca8fa2
axlearn/common/trainer.py
@@ -1111,8 +1111,13 @@ def _run_step(
1111
if self.step % 100 == 0 or 0 <= self.step <= 5:
1112
self._step_log(
1113
"loss=%s aux=%s",
1114
- outputs["loss"],
1115
- jax.tree.map(lambda x: x.item() if x.ndim == 0 else f"T{x.shape}", outputs["aux"]),
+ jax.block_until_ready(outputs["loss"]),
+ jax.block_until_ready(
1116
+ jax.tree.map(
1117
+ lambda x: x.item() if x.ndim == 0 else f"T{x.shape}",
1118
+ outputs["aux"]
1119
+ ),
1120
1121
)
1122
1123
with self._record_event(
0 commit comments