Skip to content

Commit 6378984

Browse files
authored
Add back the import of jtu in flash_attention.py
This was erroneously removed in de3191f.
1 parent 2b7b074 commit 6378984

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

jax/experimental/mosaic/gpu/examples/flash_attention.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import jax
2323
from jax import random
2424
from jax._src.interpreters import mlir
25+
from jax._src import test_util as jtu
2526
from jax.experimental.mosaic.gpu import profiler
2627
from jax.experimental.mosaic.gpu import * # noqa: F403
2728
import jax.numpy as jnp

0 commit comments

Comments
 (0)