
Commit 1df9b3b

Merge branch 'cmake'.
Switches the build system to cmake, which is more flexible, and fixes up the compatibility for jax 0.5.0--0.7.0.
2 parents be5d1b4 + ca2839a commit 1df9b3b

20 files changed: +921 −484 lines
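
The jax 0.5.0-0.7.0 compatibility work mentioned in the commit message most likely means moving the bindings onto the `jax.ffi` custom-call interface, which replaced the older custom-call registration mechanism around jax 0.5. Below is a minimal sketch of that binding style, not this repo's actual code: the target name `flash_mha_fwd`, the capsule attribute on the `flash_api` module, and the output shapes are all assumptions.

```py
import numpy as np
import jax
import jax.numpy as jnp
from flash_attn_jax_lib import flash_api  # compiled pybind11 module; the capsule attribute below is hypothetical

# Register a compiled CUDA handler with XLA under the jax>=0.5 FFI API.
jax.ffi.register_ffi_target("flash_mha_fwd", flash_api.flash_mha_fwd, platform="CUDA")

def mha_fwd(q, k, v, softmax_scale):
    # Assumed outputs: the attention result plus a float32 softmax-lse tensor.
    out = jax.ShapeDtypeStruct(q.shape, q.dtype)
    lse = jax.ShapeDtypeStruct(q.shape[:-1], jnp.float32)
    call = jax.ffi.ffi_call("flash_mha_fwd", (out, lse))
    return call(q, k, v, softmax_scale=np.float32(softmax_scale))
```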

.clangd

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,9 @@ CompileFlags:
     - --use_fast_math
     - --threads
     - -gencode
+    - -forward-unknown-to-host-compiler
+    - --generate-code=*
+    - -Xcompiler=*
   Add:
     - --no-cuda-version-check
 ---

.github/workflows/publish.yml

Lines changed: 7 additions & 9 deletions
@@ -41,8 +41,8 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-20.04]
-        python-version: ['cp39', 'cp310', 'cp311', 'cp312']
-        cuda-version: ['11.8', '12.3']
+        python-version: ['cp311', 'cp312']
+        cuda-version: ['12.8']

     steps:
       - name: Checkout
@@ -51,7 +51,7 @@ jobs:
       - name: Set up python
         uses: actions/setup-python@v4
         with:
-          python-version: '3.10'
+          python-version: '3.11'

       - name: Set CUDA and PyTorch versions
         run: |
@@ -76,7 +76,7 @@ jobs:
         uses: pypa/[email protected]
         env:
           CIBW_BUILD: ${{ matrix.python-version }}-manylinux_x86_64
-          CIBW_MANYLINUX_X86_64_IMAGE: sameli/manylinux2014_x86_64_cuda_${{ matrix.cuda-version }}
+          CIBW_BEFORE_ALL: bash scripts/install-cuda-linux.sh ${{ matrix.cuda-version }}
           CIBW_BUILD_VERBOSITY: 1

       - name: Log Built Wheels
@@ -128,17 +128,15 @@ jobs:
       - uses: actions/setup-python@v4
         with:
-          python-version: '3.10'
+          python-version: '3.11'

       - name: Install dependencies
         run: |
-          pip install setuptools==68.0.0
-          pip install git+https://github.com/nshepperd/setuptools-cuda-cpp
-          pip install ninja packaging wheel pybind11
+          pip install uv

       - name: Build core package
         run: |
-          CUDA_HOME=/ python setup.py sdist --dist-dir=dist
+          uv build --sdist

       - name: Retrieve release distributions
         uses: actions/download-artifact@v4

CMakeLists.txt

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+cmake_minimum_required(VERSION 3.18)
+
+project(flash_attn LANGUAGES CXX CUDA)
+
+set(CMAKE_JOB_POOLS cuda=6)
+set(CMAKE_INSTALL_RPATH "$ORIGIN/nvidia/cuda_runtime/lib")
+# Make sure RPATH is used instead of RUNPATH
+set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE)
+
+# == Find dependencies ==
+find_package(Python REQUIRED COMPONENTS Interpreter Development.Module)
+
+execute_process(
+  COMMAND ${Python_EXECUTABLE} -m pybind11 --cmakedir
+  OUTPUT_VARIABLE pybind11_DIR
+  OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+
+find_package(pybind11 CONFIG REQUIRED)
+
+# == Setup CUDA ==
+string(REGEX REPLACE "--generate-code=arch=compute_[0-9]+,code=\\[?compute_[0-9]+,sm_[0-9]+\\]?" ""
+  CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
+string(REGEX REPLACE "-gencode arch=compute_[0-9]+,code=sm_[0-9]+" ""
+  CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
+
+message(WARNING "CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
+
+# Set up ccache
+find_program(CCACHE_PROGRAM ccache)
+if(CCACHE_PROGRAM)
+  set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+  message(STATUS "Using ccache: ${CCACHE_PROGRAM}")
+endif()
+
+# CUDA handling
+# Get CUDA architectures from environment or use default
+if(DEFINED ENV{FLASH_ATTN_CUDA_ARCHS})
+  set(CMAKE_CUDA_ARCHITECTURES $ENV{FLASH_ATTN_CUDA_ARCHS})
+else()
+  # set(CMAKE_CUDA_ARCHITECTURES "80;90;100;120")
+  set(CMAKE_CUDA_ARCHITECTURES "80")
+endif()
+
+find_package(CUDAToolkit REQUIRED)
+
+
+# CUDA flags
+set(CUDA_FLAGS
+  -O3
+  -std=c++20
+  --use_fast_math
+  --expt-relaxed-constexpr
+  --expt-extended-lambda
+  -U__CUDA_NO_HALF_OPERATORS__
+  -U__CUDA_NO_HALF_CONVERSIONS__
+  -U__CUDA_NO_HALF2_OPERATORS__
+  -U__CUDA_NO_BFLOAT16_CONVERSIONS__
+  -DFLASHATTENTION_DISABLE_DROPOUT=1
+  -DFLASHATTENTION_DISABLE_ALIBI=1
+)
+
+# Collect source files
+file(GLOB CUDA_SOURCES
+  "csrc/flash_attn/src/flash_fwd_hdim*.cu"
+  "csrc/flash_attn/src/flash_bwd_hdim*.cu"
+  "csrc/flash_attn/src/flash_fwd_split_hdim*.cu"
+)
+
+file(GLOB CC_SOURCES
+  "csrc/flash_attn/*.cpp"
+)
+
+# Create CUDA extension
+pybind11_add_module(flash_api
+  ${CC_SOURCES}
+  ${CUDA_SOURCES}
+)
+
+set_property(TARGET flash_api PROPERTY JOB_POOL_COMPILE cuda)
+
+target_compile_options(flash_api PRIVATE
+  $<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>
+)
+
+target_include_directories(flash_api PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/csrc/flash_attn
+  ${CMAKE_CURRENT_SOURCE_DIR}/csrc/flash_attn/src
+  ${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass/include
+)
+
+target_link_libraries(flash_api PRIVATE
+  CUDA::cudart
+)
+
+if(FLASH_ATTENTION_FORCE_CXX11_ABI)
+  target_compile_definitions(flash_api PRIVATE
+    _GLIBCXX_USE_CXX11_ABI=1
+  )
+endif()
+
+# Installation
+install(TARGETS flash_api
+  DESTINATION ${SKBUILD_PLATLIB_DIR}/flash_attn_jax_lib
+)
+
+install(DIRECTORY src/flash_attn_jax/
+  DESTINATION ${SKBUILD_PLATLIB_DIR}/flash_attn_jax
+  FILES_MATCHING PATTERN "*.py"
+)
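
The install() rules above place the compiled pybind11 module under `flash_attn_jax_lib` and the pure-Python package under `flash_attn_jax` in the wheel's platlib directory. A quick post-install smoke test under that assumption (the printed listing is just illustrative):

```py
# Check that the CMake-built extension landed where the wheel installs it.
from flash_attn_jax_lib import flash_api

# List the functions the pybind11 module exports (names depend on the C++ bindings).
print([name for name in dir(flash_api) if not name.startswith("_")])
```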

README.md

Lines changed: 9 additions & 23 deletions
@@ -1,19 +1,20 @@
 # FlashAttention JAX
 This repository provides a jax binding to <https://github.com/Dao-AILab/flash-attention>. To avoid depending on pytorch, since torch and jax installations often conflict, this is a fork of the official repo.

-Please see [Tri Dao's repo](https://github.com/Dao-AILab/flash-attention) for more information about flash attention.
+Please see [Tri Dao's repo](https://github.com/Dao-AILab/flash-attention) for more information about flash attention. Also check there for how to cite the authors if you used flash attention in your work.

 FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
 Please cite (see below) and credit FlashAttention if you use it.

 ## Installation

 Requirements:
-- CUDA 11.8 and above.
+- CUDA 12.8 and above.
 - Linux. Same story as with the pytorch repo. I haven't tested compilation of the jax bindings on windows.
-- JAX >=`0.4.24`. The custom sharding used for ring attention requires some somewhat advanced features.
+- JAX >= `0.5.*`. The custom call api changed in this version.

-To install: `pip install flash-attn-jax` will get the latest release from pypi. This gives you the cuda 12.3 build. If you want to use the cuda 11.8 build, you can install from the releases page (but according to jax's documentation, 11.8 will stop being supported for newer versions of jax).
+To install: `pip install flash-attn-jax` will get the latest release from pypi. This gives you the cuda 12.8
+build. CUDA 11 isn't supported any more (since jax stopped supporting it).

 ### Installing from source

@@ -25,7 +26,7 @@ cd flash-attn-jax
 cibuildwheel --only cp312-manylinux_x86_64 # I think cibuildwheel needs superuser privileges on some systems because of docker reasons?
 ```

-This will create a wheel in the `wheelhouse` directory. You can then install it with `pip install wheelhouse/flash_attn_jax_0.2.0-cp312-cp312-manylinux_x86_64.whl`. Or you could use setup.py to build the wheel and install it. You need cuda toolkit installed in that case.
+This will create a wheel in the `wheelhouse` directory. You can then install it with `pip install wheelhouse/flash_attn_jax_*.whl`. Or you could build it without docker using `uv build --wheel`. You need cuda installed in that case.

 ## Usage

@@ -45,15 +46,16 @@ This supports multi-query and grouped-query attention (when hk != h). The `softm
 Use jax.Array and shard your tensors along the length dimension, and flash_mha will automatically use the ring attention algorithm (forward and backward).

 ```py
-os.environ["XLA_FLAGS"] = '--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_collectives=true'
+os.environ["XLA_FLAGS"] = '--xla_gpu_enable_latency_hiding_scheduler=true'
 #...
 with Mesh(devices, axis_names=('len',)) as mesh:
     sharding = NamedSharding(mesh, P(None,'len')) # n l
     tokens = jax.device_put(tokens, sharding)
     # invoke your jax.jit'd transformer.forward
 ```

-It's not entirely reliable at hiding the communication latency though, depending on the whims of the xla optimizer. I'm waiting https://github.com/google/jax/issues/20864 to be fixed, then I can make it better.
+The latency hiding seems to be reliable now that some bugs have been fixed, as long as you enable the
+latency hiding scheduler as above.

 ### GPU support

@@ -63,19 +65,3 @@ FlashAttention-2 currently supports:
    GPUs for now.
 2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
 3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
-
-## Citation
-If you use this codebase, or otherwise found our work valuable, please cite:
-```
-@inproceedings{dao2022flashattention,
-  title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
-  author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
-  booktitle={Advances in Neural Information Processing Systems},
-  year={2022}
-}
-@article{dao2023flashattention2,
-  title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
-  author={Dao, Tri},
-  year={2023}
-}
-```
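
For context, the README's Usage section (not shown in this hunk) centers on `flash_mha`. Here is a hedged sketch of a basic call, assuming `[n, l, h, d]`-shaped inputs (matching the `# n l` sharding comment above) and keyword arguments `softmax_scale` and `is_causal`; the exact signature should be checked against the installed package:

```py
import jax
import jax.numpy as jnp
from flash_attn_jax import flash_mha

n, l, h, d = 2, 1024, 8, 64
q = jax.random.normal(jax.random.PRNGKey(0), (n, l, h, d), dtype=jnp.float16)
k = jax.random.normal(jax.random.PRNGKey(1), (n, l, h, d), dtype=jnp.float16)
v = jax.random.normal(jax.random.PRNGKey(2), (n, l, h, d), dtype=jnp.float16)

# softmax_scale=None is assumed to default to 1/sqrt(d); is_causal enables causal masking.
out = flash_mha(q, k, v, softmax_scale=None, is_causal=True)
print(out.shape)  # (2, 1024, 8, 64)
```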

Tupfile

Lines changed: 0 additions & 15 deletions
This file was deleted.

csrc/flash_attn/src/flash_fwd_launch_template.h

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 #include "static_switch.h"
 #include "flash.h"
 #include "flash_fwd_kernel.h"
+#include "kernel_traits.h"

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
 __global__ void flash_fwd_kernel(__grid_constant__ const Flash_fwd_params params) {

csrc/flash_attn/src/kernel_traits.h

Lines changed: 3 additions & 5 deletions
@@ -4,11 +4,9 @@

 #pragma once

-#include "cute/algorithm/copy.hpp"
-
-#include "cutlass/cutlass.h"
-#include "cutlass/layout/layout.h"
-#include <cutlass/numeric_types.h>
+#include "cute/tensor.hpp"
+#include "cute/atom/mma_atom.hpp"
+#include "cute/atom/copy_atom.hpp"

 using namespace cute;

make_compile_commands.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+import os, sys
+from subprocess import Popen, PIPE
+import json
+import re
+
+Popen(["cmake", ".", "-B", "build"]).wait()
+with open("build/compile_commands.json", "r") as f:
+    compile_commands = json.load(f)
+
+# --options-file CMakeFiles/flash_attn_2_cuda.dir/includes_CUDA.rsp
+re_options = re.compile(r"--options-file ([A-Za-z0-9/\._]*)")
+
+for command in compile_commands:
+    if re_options.search(command["command"]):
+        m = re_options.search(command["command"])
+        options_file = m.group(1)
+        with open(os.path.join('build', options_file), "r") as f:
+            options = f.read()
+        command["command"] = command["command"].replace(m.group(0), options)
+
+with open("compile_commands.json", "w") as f:
+    json.dump(compile_commands, f, indent=2)
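
Presumably this helper is run from the repository root (`python make_compile_commands.py`) so that clangd, together with the `.clangd` tweaks above, can read a `compile_commands.json` whose nvcc `--options-file` response files have been inlined.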

pyproject.toml

Lines changed: 67 additions & 2 deletions
@@ -1,5 +1,70 @@
 [build-system]
-requires = ["setuptools", "wheel", "setuptools-cuda-cpp @ git+https://github.com/nshepperd/setuptools-cuda-cpp", "packaging", "pybind11"]
+requires = [
+    "scikit-build-core>=0.8.0",
+    "cmake>=3.18",
+    "ninja>=1.10",
+    "packaging",
+    "psutil",
+    "pybind11>=2.11.0",
+]
+build-backend = "scikit_build_core.build"
+
+[project]
+name = "flash_attn_jax"
+dynamic = ["version"]
+description = "Flash Attention port for JAX"
+readme = "README.md"
+requires-python = ">=3.11"
+license = { text = "BSD-3-Clause" }
+authors = [
+    { name = "Tri Dao", email = "[email protected]" },
+    { name = "Emily Shepperd", email = "[email protected]" }
+]
+dependencies = [
+    "jax>=0.5.0, <0.8.0"
+]
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: BSD License",
+    "Operating System :: Unix",
+]
+
+[dependency-groups]
+test = [
+    "pytest>=7.0.0",
+    "einops",
+    "jax[cuda12]",
+]
+[project.urls]
+Homepage = "https://github.com/nshepperd/flash_attn_jax"
+
+[tool.scikit-build]
+wheel.expand-macos-universal-tags = false
+cmake.version = ">=3.26.1"
+ninja.version = ">=1.11"
+build.verbose = true
+cmake.build-type = "Release"
+cmake.args = []
+
+[tool.scikit-build.cmake.define]
+SKBUILD = "ON"
+FLASH_ATTENTION_FORCE_BUILD = { env = "FLASH_ATTENTION_FORCE_BUILD" }
+FLASH_ATTENTION_SKIP_CUDA_BUILD = { env = "FLASH_ATTENTION_SKIP_CUDA_BUILD" }
+FLASH_ATTENTION_FORCE_CXX11_ABI = { env = "FLASH_ATTENTION_FORCE_CXX11_ABI" }
+FLASH_ATTENTION_TRITON_AMD_ENABLE = { env = "FLASH_ATTENTION_TRITON_AMD_ENABLE" }
+FLASH_ATTN_CUDA_ARCHS = { env = "FLASH_ATTN_CUDA_ARCHS" }
+CMAKE_VERBOSE_MAKEFILE = "ON"
+
+[tool.scikit-build.metadata.version]
+provider = "scikit_build_core.metadata.regex"
+input = "src/flash_attn_jax/__init__.py"

 [tool.cibuildwheel]
-manylinux-x86_64-image = "sameli/manylinux_2_28_x86_64_cuda_12.3"
+# manylinux-x86_64-image = "quay.io/pypa/manylinux_2_28_x86_64:latest"
+before-all = "bash scripts/install-cuda-linux.sh"
+build = "cp312-manylinux_x86_64"
+repair-wheel-command = "auditwheel repair --exclude=libcudart.so* -w {dest_dir} {wheel}"
+
+[tool.cibuildwheel.environment]
+PATH="/opt/rh/gcc-toolset-13/root/usr/bin:/usr/local/cuda/bin:$PATH"
+CUDA_HOME="/usr/local/cuda"
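
For reference, scikit-build-core's regex version provider reads `src/flash_attn_jax/__init__.py` and, with no explicit `regex` key, looks for a `__version__` assignment. A placeholder example of the line it expects (the actual version string will differ):

```py
# src/flash_attn_jax/__init__.py (excerpt; the value shown is illustrative)
__version__ = "0.0.0"
```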
