
Commit a224720

Fix the library for jax 0.7.0
1 parent 0eef5a9 commit a224720

File tree

12 files changed: +647 -132 lines changed


.github/workflows/publish.yml

Lines changed: 7 additions & 9 deletions
@@ -41,8 +41,8 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-20.04]
-        python-version: ['cp39', 'cp310', 'cp311', 'cp312']
-        cuda-version: ['11.8', '12.3']
+        python-version: ['cp311', 'cp312']
+        cuda-version: ['12.8']

     steps:
     - name: Checkout
@@ -51,7 +51,7 @@ jobs:
     - name: Set up python
       uses: actions/setup-python@v4
       with:
-        python-version: '3.10'
+        python-version: '3.11'

     - name: Set CUDA and PyTorch versions
       run: |
@@ -76,7 +76,7 @@ jobs:
       uses: pypa/[email protected]
       env:
         CIBW_BUILD: ${{ matrix.python-version }}-manylinux_x86_64
-        CIBW_MANYLINUX_X86_64_IMAGE: sameli/manylinux2014_x86_64_cuda_${{ matrix.cuda-version }}
+        CIBW_BEFORE_ALL: bash scripts/install-cuda-linux.sh ${{ matrix.cuda-version }}
         CIBW_BUILD_VERBOSITY: 1

     - name: Log Built Wheels
@@ -128,17 +128,15 @@ jobs:

     - uses: actions/setup-python@v4
       with:
-        python-version: '3.10'
+        python-version: '3.11'

     - name: Install dependencies
       run: |
-        pip install setuptools==68.0.0
-        pip install git+https://github.com/nshepperd/setuptools-cuda-cpp
-        pip install ninja packaging wheel pybind11
+        pip install uv

     - name: Build core package
       run: |
-        CUDA_HOME=/ python setup.py sdist --dist-dir=dist
+        uv build --sdist

     - name: Retrieve release distributions
       uses: actions/download-artifact@v4

README.md

Lines changed: 9 additions & 23 deletions
@@ -1,19 +1,20 @@
 # FlashAttention JAX
 This repository provides a jax binding to <https://github.com/Dao-AILab/flash-attention>. To avoid depending on pytorch, since torch and jax installations often conflict, this is a fork of the official repo.

-Please see [Tri Dao's repo](https://github.com/Dao-AILab/flash-attention) for more information about flash attention.
+Please see [Tri Dao's repo](https://github.com/Dao-AILab/flash-attention) for more information about flash attention. Also check there for how to cite the authors if you used flash attention in your work.

 FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
 Please cite (see below) and credit FlashAttention if you use it.

 ## Installation

 Requirements:
-- CUDA 11.8 and above.
+- CUDA 12.8 and above.
 - Linux. Same story as with the pytorch repo. I haven't tested compilation of the jax bindings on windows.
-- JAX >=`0.4.24`. The custom sharding used for ring attention requires some somewhat advanced features.
+- JAX >= `0.5.*`. The custom call api changed in this version.

-To install: `pip install flash-attn-jax` will get the latest release from pypi. This gives you the cuda 12.3 build. If you want to use the cuda 11.8 build, you can install from the releases page (but according to jax's documentation, 11.8 will stop being supported for newer versions of jax).
+To install: `pip install flash-attn-jax` will get the latest release from pypi. This gives you the cuda 12.8
+build. CUDA 11 isn't supported any more (since jax stopped supporting it).

 ### Installing from source

@@ -25,7 +26,7 @@ cd flash-attn-jax
 cibuildwheel --only cp312-manylinux_x86_64 # I think cibuildwheel needs superuser privileges on some systems because of docker reasons?
 ```

-This will create a wheel in the `wheelhouse` directory. You can then install it with `pip install wheelhouse/flash_attn_jax_0.2.0-cp312-cp312-manylinux_x86_64.whl`. Or you could use setup.py to build the wheel and install it. You need cuda toolkit installed in that case.
+This will create a wheel in the `wheelhouse` directory. You can then install it with `pip install wheelhouse/flash_attn_jax_*.whl`. Or you could build it without docker using `uv build --wheel`. You need cuda installed in that case.

 ## Usage

@@ -45,15 +46,16 @@ This supports multi-query and grouped-query attention (when hk != h). The `softm
 Use jax.Array and shard your tensors along the length dimension, and flash_mha will automatically use the ring attention algorithm (forward and backward).

 ```py
-os.environ["XLA_FLAGS"] = '--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_collectives=true'
+os.environ["XLA_FLAGS"] = '--xla_gpu_enable_latency_hiding_scheduler=true'
 #...
 with Mesh(devices, axis_names=('len',)) as mesh:
     sharding = NamedSharding(mesh, P(None,'len')) # n l
     tokens = jax.device_put(tokens, sharding)
     # invoke your jax.jit'd transformer.forward
 ```

-It's not entirely reliable at hiding the communication latency though, depending on the whims of the xla optimizer. I'm waiting https://github.com/google/jax/issues/20864 to be fixed, then I can make it better.
+The latency hiding seems to be reliable now that some bugs have been fixed, as long as you enable the
+latency hiding scheduler as above.

 ### GPU support

@@ -63,19 +65,3 @@ FlashAttention-2 currently supports:
    GPUs for now.
 2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
 3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
-
-## Citation
-If you use this codebase, or otherwise found our work valuable, please cite:
-```
-@inproceedings{dao2022flashattention,
-  title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
-  author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
-  booktitle={Advances in Neural Information Processing Systems},
-  year={2022}
-}
-@article{dao2023flashattention2,
-  title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
-  author={Dao, Tri},
-  year={2023}
-}
-```
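
As a rough end-to-end illustration of the ring-attention snippet in the README hunk above (a sketch, not code from this commit): it assumes the `(n, l, h, d)` input layout and the `softmax_scale`/`is_causal`/`window_size` keywords visible in the abstract-eval signatures further down, and it needs several GPUs before the `'len'` axis actually shards anything.

```py
# Hypothetical usage sketch; `attend` and the shapes are illustrative only.
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_latency_hiding_scheduler=true"

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flash_attn_jax import flash_mha

devices = jax.devices()                      # e.g. 8 GPUs
n, l, h, d = 1, 8192, 8, 64                  # batch, length, heads, head dim

with Mesh(devices, axis_names=('len',)) as mesh:
    sharding = NamedSharding(mesh, P(None, 'len'))   # shard the length axis
    q = jax.device_put(jnp.ones((n, l, h, d), jnp.float16), sharding)
    k = jax.device_put(jnp.ones((n, l, h, d), jnp.float16), sharding)
    v = jax.device_put(jnp.ones((n, l, h, d), jnp.float16), sharding)

    @jax.jit
    def attend(q, k, v):
        # With q/k/v sharded along 'len', flash_mha should take the
        # ring-attention path described in the README.
        return flash_mha(q, k, v, is_causal=True)

    out = attend(q, k, v)                    # same (n, l, h, d) shape as q
```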

pyproject.toml

Lines changed: 12 additions & 8 deletions
@@ -6,31 +6,35 @@ requires = [
     "packaging",
     "psutil",
     "pybind11>=2.11.0",
-    # "nvidia-cuda-runtime-cu12>=12.0",
-    # "nvidia-cuda-nvrtc-cu12",
-    # "nvidia-nvtx-cu12",
-    "torch>=2.0.0",
 ]
 build-backend = "scikit_build_core.build"

 [project]
 name = "flash_attn_jax"
 dynamic = ["version"]
-description = "Flash Attention: Fast and Memory-Efficient Exact Attention"
+description = "Flash Attention port for JAX"
 readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.11"
 license = { text = "BSD-3-Clause" }
 authors = [
     { name = "Tri Dao", email = "[email protected]" },
     { name = "Emily Shepperd", email = "[email protected]" }
 ]
-dependencies = []
+dependencies = [
+    "jax>=0.5.0, <0.8.0"
+]
 classifiers = [
     "Programming Language :: Python :: 3",
     "License :: OSI Approved :: BSD License",
     "Operating System :: Unix",
 ]

+[dependency-groups]
+test = [
+    "pytest>=7.0.0",
+    "einops",
+    "jax[cuda12]",
+]
 [project.urls]
 Homepage = "https://github.com/nshepperd/flash_attn_jax"

@@ -59,7 +63,7 @@ input = "src/flash_attn_jax/__init__.py"
 manylinux-x86_64-image = "quay.io/pypa/manylinux_2_28_x86_64:latest"
 before-all = "bash scripts/install-cuda-linux.sh"
 build = "cp312-manylinux_x86_64"
-repair-wheel-command = "auditwheel repair --exclude=libcudart.so* --exclude libtorch.so* -w {dest_dir} {wheel}"
+repair-wheel-command = "auditwheel repair --exclude=libcudart.so* -w {dest_dir} {wheel}"

 [tool.cibuildwheel.environment]
 PATH="/opt/rh/gcc-toolset-13/root/usr/bin:/usr/local/cuda/bin:$PATH"

scripts/install-cuda-linux.sh

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 #!/bin/bash
 set -eux

-VER=${1:-12.4}
+VER=${1:-12.8}
 VER=${VER//./-} # Convert version to format used in package names

 dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo

src/flash_attn_jax/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 from .flash import flash_mha
-__version__ = 'v0.2.2'
+__version__ = 'v0.3.0'

src/flash_attn_jax/flash.py

Lines changed: 7 additions & 7 deletions
@@ -12,11 +12,11 @@
 from jax.lib import xla_client
 from jaxlib.hlo_helpers import custom_call
 from jax.experimental.custom_partitioning import custom_partitioning
+from jax.extend.core import Primitive

 from jax.sharding import PartitionSpec as P
 from jax.sharding import Mesh
 from jax.sharding import NamedSharding
-from jax.sharding import PositionalSharding

 from einops import rearrange
 import einops
@@ -31,11 +31,11 @@
 # about sharding or padding, which will be handled when they are
 # lowered to hlo, using the physical "hlo" primitives, which directly
 # lower to XLA CustomCall.
-_flash_mha_fwd_p = core.Primitive("flash_mha_fwd")
+_flash_mha_fwd_p = Primitive("flash_mha_fwd")
 _flash_mha_fwd_p.multiple_results = True
 _flash_mha_fwd_p.def_impl(partial(xla.apply_primitive, _flash_mha_fwd_p))

-_flash_mha_bwd_p = core.Primitive("flash_mha_bwd")
+_flash_mha_bwd_p = Primitive("flash_mha_bwd")
 _flash_mha_bwd_p.multiple_results = True
 _flash_mha_bwd_p.def_impl(partial(xla.apply_primitive, _flash_mha_bwd_p))

@@ -79,7 +79,7 @@ def _flash_mha_fwd_abstract(q, k, v, softmax_scale=None, is_causal=None, window_
     assert q_dtype == k_dtype and q_dtype == v_dtype
     assert q_dtype in [jnp.bfloat16, jnp.float16]
     return (
-        ShapedArray(q.shape, q_dtype, named_shape=q.named_shape),
+        ShapedArray(q.shape, q_dtype),
         ShapedArray([n, h, l], jnp.float32)
     )
 _flash_mha_fwd_p.def_abstract_eval(_flash_mha_fwd_abstract)
@@ -96,9 +96,9 @@ def _flash_mha_bwd_abstract(dout, q, k, v, out, lse, softmax_scale=None, is_caus
     assert len(set([dout_dtype, q_dtype, k_dtype, v_dtype, out_dtype])) == 1
     assert q_dtype in [jnp.bfloat16, jnp.float16]
     return (
-        ShapedArray(q.shape, q_dtype, named_shape=q.named_shape),
-        ShapedArray(k.shape, k_dtype, named_shape=k.named_shape),
-        ShapedArray(v.shape, v_dtype, named_shape=v.named_shape),
+        ShapedArray(q.shape, q_dtype),
+        ShapedArray(k.shape, k_dtype),
+        ShapedArray(v.shape, v_dtype),
     )
 _flash_mha_bwd_p.def_abstract_eval(_flash_mha_bwd_abstract)
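
The flash.py hunks are mostly a mechanical migration: `core.Primitive` becomes `jax.extend.core.Primitive`, and `ShapedArray` drops the removed `named_shape=` argument. A standalone sketch of that registration pattern (a toy primitive standing in for `_flash_mha_fwd_p`, and assuming `ShapedArray` is still imported from `jax.core` as the unchanged lines appear to do) looks roughly like this:

```py
# Minimal sketch of the jax>=0.5 primitive pattern used above; the toy primitive
# and its two outputs are hypothetical, standing in for _flash_mha_fwd_p.
import jax.numpy as jnp
from jax.core import ShapedArray          # abstract values for def_abstract_eval
from jax.extend.core import Primitive     # public home of Primitive

_toy_fwd_p = Primitive("toy_fwd")
_toy_fwd_p.multiple_results = True        # returns an (out, lse)-style pair

def _toy_fwd_impl(q):
    # Eager fallback; the real primitive dispatches to an HLO lowering instead.
    return [q, jnp.zeros(q.shape[:-1], jnp.float32)]
_toy_fwd_p.def_impl(_toy_fwd_impl)

def _toy_fwd_abstract(q):
    # named_shape= is gone from ShapedArray, so only shape and dtype are passed.
    return (
        ShapedArray(q.shape, q.dtype),
        ShapedArray(q.shape[:-1], jnp.float32),
    )
_toy_fwd_p.def_abstract_eval(_toy_fwd_abstract)

out, lse = _toy_fwd_p.bind(jnp.ones((2, 128, 64), jnp.float16))
```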

src/flash_attn_jax/flash_hlo.py

Lines changed: 20 additions & 16 deletions
@@ -12,10 +12,7 @@
 from jax.lib import xla_client
 from jax.experimental.custom_partitioning import custom_partitioning

-from jax.sharding import PartitionSpec as P
-from jax.sharding import Mesh
-from jax.sharding import NamedSharding
-from jax.sharding import PositionalSharding
+from jax.extend.core import Primitive

 from einops import rearrange
 import einops
@@ -25,15 +22,15 @@

 # ==== Register primitives ====

-_flash_mha_fwd_hlo_p = core.Primitive("flash_mha_fwd_hlo")
+_flash_mha_fwd_hlo_p = Primitive("flash_mha_fwd_hlo")
 _flash_mha_fwd_hlo_p.multiple_results = True
 _flash_mha_fwd_hlo_p.def_impl(partial(xla.apply_primitive, _flash_mha_fwd_hlo_p))

-_flash_mha_bwd_hlo_p = core.Primitive("flash_mha_bwd_hlo")
+_flash_mha_bwd_hlo_p = Primitive("flash_mha_bwd_hlo")
 _flash_mha_bwd_hlo_p.multiple_results = True
 _flash_mha_bwd_hlo_p.def_impl(partial(xla.apply_primitive, _flash_mha_bwd_hlo_p))

-_custom_call_p = core.Primitive("custom_call")
+_custom_call_p = Primitive("custom_call")
 _custom_call_p.multiple_results = True
 _custom_call_p.def_impl(partial(xla.apply_primitive, _custom_call_p))

@@ -48,13 +45,18 @@ def _flash_mha_bwd_hlo(dout, q, k, v, out, lse, softmax_scale, is_causal, window
     return dq, dk, dv

 def custom_call(*args, call_target_name, result_types, backend_config, operand_layouts, result_layouts):
-    return _custom_call_p.bind(*args, call_target_name=call_target_name, result_types=result_types, backend_config=backend_config, operand_layouts=operand_layouts, result_layouts=result_layouts)
+    return _custom_call_p.bind(*args, call_target_name=call_target_name,
+                               result_types=tuple(result_types),
+                               backend_config=backend_config,
+                               operand_layouts=tuple(operand_layouts),
+                               result_layouts=tuple(result_layouts))

 # ==== HLO lowerings ====

 # Register functions defined in gpu_ops as custom call target for GPUs
 for _name, _value in flash_api.get_registrations().items():
-    xla_client.register_custom_call_target(_name, _value, platform="gpu")
+    # xla_client.register_custom_call_target(_name, _value, platform="gpu")
+    jax.ffi.register_ffi_target(_name, _value, platform="gpu", api_version=0)

 def default_layouts(*shapes):
     def row_major(shape):
@@ -85,6 +87,7 @@ def _flash_mha_fwd_hlo_lowering(ctx, q, k, v, softmax_scale=None, is_causal=Fals
     [nk, lk, hk, dk] = k_shape
     assert k_shape == v_shape, "K and V must have the same shape"
     assert [n, d] == [nk, dk], "Q and K must have the same batch size and head size"
+    assert isinstance(window_size, (tuple, list))

     opaque = flash_api.make_flash_mha_fwd_args(
         0.0, # p_dropout
@@ -164,6 +167,7 @@ def _flash_mha_bwd_hlo_lowering(ctx, dout, q, k, v, out, lse, softmax_scale=None
     [nk, lk, hk, dk] = k_shape
     assert n == nk
     assert d == dk
+    assert isinstance(window_size, (tuple, list))

     assert (list(map(list, [dout_shape, q_shape, k_shape, v_shape, out_shape, lse_shape])) ==
             [[n, lq, hq, d], [n, lq, hq, d], [n, lk, hk, d], [n, lk, hk, d],
@@ -238,7 +242,7 @@ def _flash_mha_fwd_abstract(q, k, v, softmax_scale=None, is_causal=None, window_
     assert q_dtype == k_dtype and q_dtype == v_dtype
     assert q_dtype in [jnp.bfloat16, jnp.float16]
     return (
-        ShapedArray(q.shape, q_dtype, named_shape=q.named_shape),
+        ShapedArray(q.shape, q_dtype),
         ShapedArray([n, h, l], jnp.float32)
     )
 _flash_mha_fwd_hlo_p.def_abstract_eval(_flash_mha_fwd_abstract)
@@ -255,9 +259,9 @@ def _flash_mha_bwd_abstract(dout, q, k, v, out, lse, softmax_scale=None, is_caus
     assert len(set([dout_dtype, q_dtype, k_dtype, v_dtype, out_dtype])) == 1
     assert q_dtype in [jnp.bfloat16, jnp.float16]
     return (
-        ShapedArray(q.shape, q_dtype, named_shape=q.named_shape),
-        ShapedArray(k.shape, k_dtype, named_shape=k.named_shape),
-        ShapedArray(v.shape, v_dtype, named_shape=v.named_shape),
+        ShapedArray(q.shape, q_dtype),
+        ShapedArray(k.shape, k_dtype),
+        ShapedArray(v.shape, v_dtype),
     )
 _flash_mha_bwd_hlo_p.def_abstract_eval(_flash_mha_bwd_abstract)

@@ -278,10 +282,10 @@ def _custom_call_hlo_lowering(ctx, *args, call_target_name, result_types, backen
     out = mlir.custom_call(
         call_target_name,
         operands=args,
-        result_types=result_types,
+        result_types=list(result_types),
         backend_config=backend_config,
-        operand_layouts=operand_layouts,
-        result_layouts=result_layouts,
+        operand_layouts=list(operand_layouts),
+        result_layouts=list(result_layouts),
     ).results
     return out
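
Two details in this file are easy to miss: the `custom_call` wrapper now passes layouts and result types as tuples because primitive parameters need to be hashable, and the lowering converts them back to lists for `mlir.custom_call`; the legacy GPU targets are meanwhile registered through `jax.ffi.register_ffi_target(..., api_version=0)` instead of the commented-out `xla_client.register_custom_call_target`. The `default_layouts` helper visible in the context lines presumably builds row-major layouts along these lines (a sketch, not the file's exact code):

```py
# Hedged sketch of a default_layouts-style helper: for a shape of rank r, the
# row-major (minor-to-major) layout is (r-1, ..., 1, 0). Returning tuples keeps
# the layouts hashable, so they can travel through _custom_call_p.bind() as
# primitive parameters.
def default_layouts(*shapes):
    def row_major(shape):
        return tuple(range(len(shape) - 1, -1, -1))
    return tuple(row_major(shape) for shape in shapes)

# Layouts for an (n, l, h, d) activation and an (n, h, l) LSE tensor:
print(default_layouts((2, 128, 8, 64), (2, 8, 128)))
# -> ((3, 2, 1, 0), (2, 1, 0))
```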
