Commit 565ee42

Update README with installation steps.

1 parent d43cbca commit 565ee42

File tree

3 files changed (+63 -43 lines):
- .github/workflows/publish.yml
- README.md
- pyproject.toml

.github/workflows/publish.yml

Lines changed: 33 additions & 33 deletions
@@ -109,36 +109,36 @@ jobs:
           name: ${{env.wheel_name}}
           path: ./wheelhouse/${{env.wheel_name}}
 
-  publish_package:
-    name: Publish package
-    needs: [build_wheels]
-
-    runs-on: ubuntu-latest
-    permissions:
-      id-token: write
-
-    steps:
-    - uses: actions/checkout@v3
-
-    - uses: actions/setup-python@v4
-      with:
-        python-version: '3.10'
-
-    - name: Install dependencies
-      run: |
-        pip install setuptools==68.0.0
-        pip install git+https://github.com/nshepperd/setuptools-cuda-cpp
-        pip install ninja packaging wheel pybind11
-
-    - name: Build core package
-      run: |
-        CUDA_HOME=/ python setup.py sdist --dist-dir=dist
-
-    - name: Retrieve release distributions
-      uses: actions/download-artifact@v4
-      with:
-        path: dist/
-        merge-multiple: true
-
-    - name: Publish release distributions to PyPI
-      uses: pypa/gh-action-pypi-publish@release/v1
+  # publish_package:
+  #   name: Publish package
+  #   needs: [build_wheels]
+
+  #   runs-on: ubuntu-latest
+  #   permissions:
+  #     id-token: write
+
+  #   steps:
+  #   - uses: actions/checkout@v3
+
+  #   - uses: actions/setup-python@v4
+  #     with:
+  #       python-version: '3.10'
+
+  #   - name: Install dependencies
+  #     run: |
+  #       pip install setuptools==68.0.0
+  #       pip install git+https://github.com/nshepperd/setuptools-cuda-cpp
+  #       pip install ninja packaging wheel pybind11
+
+  #   - name: Build core package
+  #     run: |
+  #       CUDA_HOME=/ python setup.py sdist --dist-dir=dist
+
+  #   - name: Retrieve release distributions
+  #     uses: actions/download-artifact@v4
+  #     with:
+  #       path: dist/
+  #       merge-multiple: true
+
+  #   - name: Publish release distributions to PyPI
+  #     uses: pypa/gh-action-pypi-publish@release/v1

README.md

Lines changed: 26 additions & 9 deletions
@@ -3,43 +3,60 @@ This repository provides a jax binding to <https://github.com/Dao-AILab/flash-at
 
 Please see [Tri Dao's repo](https://github.com/Dao-AILab/flash-attention) for more information about flash attention.
 
-## Usage
-
 FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
 Please cite (see below) and credit FlashAttention if you use it.
 
-## Installation and features
+## Installation
 
 Requirements:
 - CUDA 11.8 and above.
 - Linux. Same story as with the pytorch repo. I haven't tested compilation of the jax bindings on windows.
 - JAX >=`0.4.24`. The custom sharding used for ring attention requires some somewhat advanced features.
 
-To install: For now, download the appropriate release from the releases page and install it with pip.
+To install: `pip install flash-attn-jax` will get the latest release from pypi. This gives you the cuda 12.3 build. If you want to use the cuda 11.8 build, you can install from the releases page (but according to jax's documentation, 11.8 will stop being supported for newer versions of jax).
+
+### Installing from source
+
+Flash attention takes a long time to compile unless you have a powerful machine. But if you want to compile from source, I use `cibuildwheel` to compile the releases. You could do the same. Something like (for python 3.12):
+
+```sh
+git clone https://github.com/nshepperd/flash-attn-jax
+cd flash-attn-jax
+cibuildwheel --only cp312-manylinux_x86_64 # I think cibuildwheel needs superuser privileges on some systems because of docker reasons?
+```
+
+This will create a wheel in the `wheelhouse` directory. You can then install it with `pip install wheelhouse/flash_attn_jax_0.2.0-cp312-cp312-manylinux_x86_64.whl`. Or you could use setup.py to build the wheel and install it. You need cuda toolkit installed in that case.
+
+## Usage
 
 Interface: `src/flash_attn_jax/flash.py`
 
 ```py
 from flash_attn_jax import flash_mha
 
+# flash_mha : [n, l, h, d] x [n, lk, hk, d] x [n, lk, hk, d] -> [n, l, h, d]
 flash_mha(q,k,v,softmax_scale=None, is_causal=False, window_size=(-1,-1))
 ```
 
-Accepts q,k,v with shape `[n, l, h, d]`, and returns `[n, l, h, d]`. `softmax_scale` is the
-multiplier for the softmax, defaulting to `1/sqrt(d)`. Set window_size
-to positive values for sliding window attention.
+This supports multi-query and grouped-query attention (when hk != h). The `softmax_scale` is the multiplier for the softmax, defaulting to `1/sqrt(d)`. Set `window_size` to positive values for sliding window attention.
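
As an inline note on the shape contract in the added line above: a minimal sketch of a grouped-query call. The sizes, dtype, and zero-filled inputs are illustrative assumptions, not part of this commit.

```py
import jax.numpy as jnp
from flash_attn_jax import flash_mha

n, l, d = 2, 1024, 64   # batch, sequence length, head dimension
h, hk = 8, 2            # query heads and kv heads; hk != h means grouped-query
q = jnp.zeros((n, l, h, d), dtype=jnp.float16)
k = jnp.zeros((n, l, hk, d), dtype=jnp.float16)
v = jnp.zeros((n, l, hk, d), dtype=jnp.float16)

# Causal attention; softmax_scale defaults to 1/sqrt(d).
out = flash_mha(q, k, v, is_causal=True)
assert out.shape == (n, l, h, d)
```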
 
 ### Now Supports Ring Attention
 
-Use jax.Array and shard your tensors along the length dimension, and flash_mha will automatically use the ring attention algorithm:
+Use jax.Array and shard your tensors along the length dimension, and flash_mha will automatically use the ring attention algorithm (forward and backward).
 
 ```py
+os.environ["XLA_FLAGS"] = '--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_collectives=true'
+#...
 with Mesh(devices, axis_names=('len',)) as mesh:
-    sharding = NamedSharding(mesh, P(None,'len',None)) # n l d
+    sharding = NamedSharding(mesh, P(None,'len')) # n l
     tokens = jax.device_put(tokens, sharding)
     # invoke your jax.jit'd transformer.forward
 ```
 
+It's not entirely reliable at hiding the communication latency though, depending on the whims of the xla optimizer. I'm waiting for https://github.com/google/jax/issues/20864 to be fixed, then I can make it better.
+
+### GPU support
+
 FlashAttention-2 currently supports:
 1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
    GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
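
To round out the ring-attention snippet added in this diff, a self-contained sketch of the sharded setup. The device count, shapes, dtype, and zero-filled inputs are illustrative assumptions, not part of this commit.

```py
import os
# Per the README snippet above, XLA flags must be set before jax initializes.
os.environ["XLA_FLAGS"] = '--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_collectives=true'

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flash_attn_jax import flash_mha

devices = jax.devices()  # e.g. 8 GPUs; the length axis is split across them
with Mesh(devices, axis_names=('len',)) as mesh:
    # Shard dim 1 (length) of the [n, l, h, d] inputs; other dims replicated.
    sharding = NamedSharding(mesh, P(None, 'len'))
    q = jax.device_put(jnp.zeros((2, 4096, 8, 64), jnp.float16), sharding)
    k = jax.device_put(jnp.zeros((2, 4096, 8, 64), jnp.float16), sharding)
    v = jax.device_put(jnp.zeros((2, 4096, 8, 64), jnp.float16), sharding)
    # With q/k/v sharded along length, flash_mha uses the ring attention
    # algorithm (forward and backward) across the 'len' mesh axis.
    out = jax.jit(flash_mha)(q, k, v)
```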

pyproject.toml

Lines changed: 4 additions & 1 deletion
@@ -1,2 +1,5 @@
 [build-system]
-requires = ["setuptools", "wheel", "setuptools-cuda-cpp @ git+https://github.com/nshepperd/setuptools-cuda-cpp", "packaging", "pybind11"]
+requires = ["setuptools", "wheel", "setuptools-cuda-cpp @ git+https://github.com/nshepperd/setuptools-cuda-cpp", "packaging", "pybind11"]
+
+[tool.cibuildwheel]
+manylinux-x86_64-image = "sameli/manylinux_2_28_x86_64_cuda_12.3"
