Skip to content

Commit 7f14de0

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[mosaic_gpu] Warmup before measuring the running time in profiler.measure
PiperOrigin-RevId: 700650380
1 parent 7a2070e commit 7f14de0

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

jax/experimental/mosaic/gpu/profiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def run(*args, **kwargs):
104104
raise ValueError("Can only measure functions with at least one output")
105105
return outs, _event_elapsed(start_event, end_event)
106106

107+
jax.block_until_ready(run(*args, **kwargs)) # Warmup.
107108
outs, elapsed = run(*args, **kwargs)
108109
return outs, float(elapsed)
109110

0 commit comments

Comments
 (0)