This repository provides a jax binding to <https://github.com/Dao-AILab/flash-attention>. It is a fork of the official repo, maintained separately so as not to depend on pytorch, since torch and jax installations often conflict.
Please see [Tri Dao's repo](https://github.com/Dao-AILab/flash-attention) for more information about flash attention. Also check there for how to cite the authors if you used flash attention in your work.
FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
Please cite and credit FlashAttention if you use it; see Tri Dao's repo (linked above) for the citation entries.
## Installation
Requirements:
- CUDA 12.8 and above.
- Linux. Same story as with the pytorch repo. I haven't tested compilation of the jax bindings on windows.
- JAX >= `0.5`. The custom call API changed in this version.
To install: `pip install flash-attn-jax` will get the latest release from pypi. This gives you the cuda 12.8 build. CUDA 11 isn't supported any more (since jax stopped supporting it).
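After installing, a quick sanity check (a minimal sketch; it assumes a CUDA-enabled jax install and that the package exports `flash_mha` at the top level, as in the usage section below):

```python
import jax
from flash_attn_jax import flash_mha  # import succeeds if the wheel matches your environment

print(jax.__version__)  # expect >= 0.5
print(jax.devices())    # expect one or more CUDA devices
```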
### Installing from source
```
cd flash-attn-jax
cibuildwheel --only cp312-manylinux_x86_64 # I think cibuildwheel needs superuser privileges on some systems because of docker reasons?
```
This will create a wheel in the `wheelhouse` directory. You can then install it with `pip install wheelhouse/flash_attn_jax_*.whl`. Or you could build it without docker using `uv build --wheel`. You need cuda installed in that case.
## Usage
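A minimal sketch of calling `flash_mha` (the keyword names `softmax_scale` and `is_causal`, and the default scale of `1/sqrt(d)`, are my understanding of the API; check the package's docstrings if they differ). Tensors are laid out `[n, l, h, d]` (batch, length, heads, head dimension), with `hk` key/value heads:

```python
import jax.numpy as jnp
from flash_attn_jax import flash_mha

n, l, h, hk, d = 2, 1024, 8, 2, 64           # hk < h gives grouped-query attention
q = jnp.zeros((n, l, h,  d), dtype=jnp.float16)
k = jnp.zeros((n, l, hk, d), dtype=jnp.float16)
v = jnp.zeros((n, l, hk, d), dtype=jnp.float16)

# is_causal enables causal masking; omitting softmax_scale uses the default 1/sqrt(d).
out = flash_mha(q, k, v, is_causal=True)
print(out.shape)  # (n, l, h, d)
```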
This supports multi-query and grouped-query attention (when hk != h).
Use jax.Array and shard your tensors along the length dimension, and flash_mha will automatically use the ring attention algorithm (forward and backward). For example (the mesh construction here is a sketch, assuming `jax` is imported and a 1-D mesh named `'len'` over all devices; adapt it to your setup):

```
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=('len',))  # sketch: 1-D mesh over all devices
sharding = NamedSharding(mesh, P(None,'len')) # n l
tokens = jax.device_put(tokens, sharding)
# invoke your jax.jit'd transformer.forward
```
The latency hiding seems to be reliable now that some bugs have been fixed, as long as you enable the latency hiding scheduler as above.
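If your setup doesn't already enable it, one way to turn on the scheduler is through `XLA_FLAGS` before jax initializes its backend (a sketch using the standard XLA GPU flag; adjust to however your launcher passes XLA options):

```python
import os

# Append the latency-hiding scheduler flag; must run before jax creates the GPU backend.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_enable_latency_hiding_scheduler=true"
)

import jax  # imported after setting the flag so the backend picks it up
```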
### GPU support
FlashAttention-2 currently supports:

1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100); use FlashAttention 1.x for older (Turing) GPUs for now.
2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs); see the casting sketch after this list.
3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
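Because of the dtype requirement above, fp32 activations need a cast around the call; a minimal sketch (assuming `flash_mha` is imported as in the usage section):

```python
import jax.numpy as jnp
from flash_attn_jax import flash_mha

def mha_bf16(q, k, v):
    # Cast to a supported dtype for the kernel, then restore the caller's dtype.
    out = flash_mha(q.astype(jnp.bfloat16), k.astype(jnp.bfloat16), v.astype(jnp.bfloat16))
    return out.astype(q.dtype)
```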