You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: README.md
+26-9Lines changed: 26 additions & 9 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -3,43 +3,60 @@ This repository provides a jax binding to <https://github.com/Dao-AILab/flash-at
3
3
4
4
Please see [Tri Dao's repo](https://github.com/Dao-AILab/flash-attention) for more information about flash attention.
5
5
6
-
## Usage
7
-
8
6
FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
9
7
Please cite (see below) and credit FlashAttention if you use it.
10
8
11
-
## Installation and features
9
+
## Installation
12
10
13
11
Requirements:
14
12
- CUDA 11.8 and above.
15
13
- Linux. Same story as with the pytorch repo. I haven't tested compilation of the jax bindings on windows.
16
14
- JAX >=`0.4.24`. The custom sharding used for ring attention requires some somewhat advanced features.
17
15
18
-
To install: For now, download the appropriate release from the releases page and install it with pip.
16
+
To install: `pip install flash-attn-jax` will get the latest release from pypi. This gives you the cuda 12.3 build. If you want to use the cuda 11.8 build, you can install from the releases page (but according to jax's documentation, 11.8 will stop being supported for newer versions of jax).
17
+
18
+
### Installing from source
19
+
20
+
Flash attention takes a long time to compile unless you have a powerful machine. But if you want to compile from source, I use `cibuildwheel` to compile the releases. You could do the same. Something like (for python 3.12):
cibuildwheel --only cp312-manylinux_x86_64 # I think cibuildwheel needs superuser privileges on some systems because of docker reasons?
26
+
```
27
+
28
+
This will create a wheel in the `wheelhouse` directory. You can then install it with `pip install wheelhouse/flash_attn_jax_0.2.0-cp312-cp312-manylinux_x86_64.whl`. Or you could use setup.py to build the wheel and install it. You need cuda toolkit installed in that case.
Accepts q,k,v with shape `[n, l, h, d]`, and returns `[n, l, h, d]`. `softmax_scale` is the
29
-
multiplier for the softmax, defaulting to `1/sqrt(d)`. Set window_size
30
-
to positive values for sliding window attention.
41
+
This supports multi-query and grouped-query attention (when hk != h). The `softmax_scale` is the multiplier for the softmax, defaulting to `1/sqrt(d)`. Set `window_size` to positive values for sliding window attention.
31
42
32
43
### Now Supports Ring Attention
33
44
34
-
Use jax.Array and shard your tensors along the length dimension, and flash_mha will automatically use the ring attention algorithm:
45
+
Use jax.Array and shard your tensors along the length dimension, and flash_mha will automatically use the ring attention algorithm (forward and backward).
sharding = NamedSharding(mesh, P(None,'len',None)) # n l d
51
+
sharding = NamedSharding(mesh, P(None,'len')) # n l
39
52
tokens = jax.device_put(tokens, sharding)
40
53
# invoke your jax.jit'd transformer.forward
41
54
```
42
55
56
+
It's not entirely reliable at hiding the communication latency though, depending on the whims of the xla optimizer. I'm waiting https://github.com/google/jax/issues/20864 to be fixed, then I can make it better.
57
+
58
+
### GPU support
59
+
43
60
FlashAttention-2 currently supports:
44
61
1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
45
62
GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
0 commit comments