We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
jtu
flash_attention.py
1 parent 2b7b074 commit 6378984Copy full SHA for 6378984
jax/experimental/mosaic/gpu/examples/flash_attention.py
@@ -22,6 +22,7 @@
22
import jax
23
from jax import random
24
from jax._src.interpreters import mlir
25
+from jax._src import test_util as jtu
26
from jax.experimental.mosaic.gpu import profiler
27
from jax.experimental.mosaic.gpu import * # noqa: F403
28
import jax.numpy as jnp
0 commit comments