Skip to content

Commit 0eef5a9

Browse files
committed
Just use cmake, it seems easier to use.
1 parent be5d1b4 commit 0eef5a9

File tree

9 files changed

+259
-271
lines changed

9 files changed

+259
-271
lines changed

.clangd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ CompileFlags:
66
- --use_fast_math
77
- --threads
88
- -gencode
9+
- -forward-unknown-to-host-compiler
10+
- --generate-code=*
11+
- -Xcompiler=*
912
Add:
1013
- --no-cuda-version-check
1114
---

CMakeLists.txt

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
# Build script for the flash_attn JAX CUDA extension, driven by scikit-build-core.
# NOTE: cmake_minimum_required() must be the very first command — it establishes
# policy defaults for everything that follows.
cmake_minimum_required(VERSION 3.18)

project(flash_attn LANGUAGES CXX CUDA)

# Emit build/compile_commands.json for clangd and other tooling.
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Throttle concurrent nvcc invocations; each one is memory-hungry.
set(CMAKE_JOB_POOLS cuda=6)

# Resolve the CUDA runtime relative to the installed wheel at load time.
set(CMAKE_INSTALL_RPATH "$ORIGIN/nvidia/cuda_runtime/lib")
# Make sure RPATH is used instead of RUNPATH
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE)

# == Find dependencies ==
find_package(Python REQUIRED COMPONENTS Interpreter Development.Module)

# Ask the installed pybind11 python package where its CMake config lives.
execute_process(
  COMMAND ${Python_EXECUTABLE} -m pybind11 --cmakedir
  OUTPUT_VARIABLE pybind11_DIR
  OUTPUT_STRIP_TRAILING_WHITESPACE
)

find_package(pybind11 CONFIG REQUIRED)

# == Setup CUDA ==
# Strip any architecture flags injected via the environment; architectures are
# controlled exclusively through CMAKE_CUDA_ARCHITECTURES below.
string(REGEX REPLACE "--generate-code=arch=compute_[0-9]+,code=\\[?compute_[0-9]+,sm_[0-9]+\\]?" ""
       CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REGEX REPLACE "-gencode arch=compute_[0-9]+,code=sm_[0-9]+" ""
       CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")

# Informational only — STATUS, not WARNING.
message(STATUS "CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")

# Set up ccache for faster CUDA rebuilds, if available.
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
  set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
  message(STATUS "Using ccache: ${CCACHE_PROGRAM}")
endif()

# Options forwarded from environment variables by scikit-build-core.
option(FLASH_ATTENTION_FORCE_BUILD "Force building from source" OFF)
option(FLASH_ATTENTION_SKIP_CUDA_BUILD "Skip CUDA build" OFF)
option(FLASH_ATTENTION_FORCE_CXX11_ABI "Force using C++11 ABI" OFF)

# CUDA handling
# Get CUDA architectures from environment or use default
if(DEFINED ENV{FLASH_ATTN_CUDA_ARCHS})
  set(CMAKE_CUDA_ARCHITECTURES $ENV{FLASH_ATTN_CUDA_ARCHS})
else()
  # set(CMAKE_CUDA_ARCHITECTURES "80;90;100;120")
  set(CMAKE_CUDA_ARCHITECTURES "80")
endif()

find_package(CUDAToolkit REQUIRED)

# Extra flags applied only to CUDA translation units (see genex below).
set(CUDA_FLAGS
  -O3
  -std=c++20
  --use_fast_math
  --expt-relaxed-constexpr
  --expt-extended-lambda
  -U__CUDA_NO_HALF_OPERATORS__
  -U__CUDA_NO_HALF_CONVERSIONS__
  -U__CUDA_NO_HALF2_OPERATORS__
  -U__CUDA_NO_BFLOAT16_CONVERSIONS__
)

# Collect source files. CONFIGURE_DEPENDS re-evaluates the glob at build time
# so newly added kernel instantiations are picked up without a manual reconfigure.
file(GLOB CUDA_SOURCES CONFIGURE_DEPENDS
  "csrc/flash_attn/src/flash_fwd_hdim*.cu"
  "csrc/flash_attn/src/flash_bwd_hdim*.cu"
  "csrc/flash_attn/src/flash_fwd_split_hdim*.cu"
)

file(GLOB CC_SOURCES CONFIGURE_DEPENDS
  "csrc/flash_attn/*.cpp"
)

# Create the python extension module.
pybind11_add_module(flash_api
  ${CC_SOURCES}
  ${CUDA_SOURCES}
)

# Funnel this target's compile jobs through the limited 'cuda' pool.
set_property(TARGET flash_api PROPERTY JOB_POOL_COMPILE cuda)

target_compile_options(flash_api PRIVATE
  $<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>
)

target_include_directories(flash_api PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/csrc/flash_attn
  ${CMAKE_CURRENT_SOURCE_DIR}/csrc/flash_attn/src
  ${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass/include
)

target_link_libraries(flash_api PRIVATE
  CUDA::cudart
)

if(FLASH_ATTENTION_FORCE_CXX11_ABI)
  target_compile_definitions(flash_api PRIVATE
    _GLIBCXX_USE_CXX11_ABI=1
  )
endif()

# Installation: the compiled extension lands in flash_attn_jax_lib, the pure
# python package sources are installed alongside it.
install(TARGETS flash_api
  DESTINATION ${SKBUILD_PLATLIB_DIR}/flash_attn_jax_lib
)

install(DIRECTORY src/flash_attn_jax/
  DESTINATION ${SKBUILD_PLATLIB_DIR}/flash_attn_jax
  FILES_MATCHING PATTERN "*.py"
)

Tupfile

Lines changed: 0 additions & 15 deletions
This file was deleted.

csrc/flash_attn/src/flash_fwd_launch_template.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "static_switch.h"
1111
#include "flash.h"
1212
#include "flash_fwd_kernel.h"
13+
#include "kernel_traits.h"
1314

1415
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
1516
__global__ void flash_fwd_kernel(__grid_constant__ const Flash_fwd_params params) {

csrc/flash_attn/src/kernel_traits.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44

55
#pragma once
66

7-
#include "cute/algorithm/copy.hpp"
8-
9-
#include "cutlass/cutlass.h"
10-
#include "cutlass/layout/layout.h"
11-
#include <cutlass/numeric_types.h>
7+
#include "cute/tensor.hpp"
8+
#include "cute/atom/mma_atom.hpp"
9+
#include "cute/atom/copy_atom.hpp"
1210

1311
using namespace cute;
1412

make_compile_commands.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
"""Generate a top-level compile_commands.json for clangd.

Configures the project into ./build with cmake, then inlines the contents of
any nvcc ``--options-file`` response files into the compile commands, since
clangd does not read response files.
"""
import json
import os
import re
import subprocess

# Configure the project; fail loudly (check=True) instead of silently
# continuing to a confusing missing-file error if cmake fails.
subprocess.run(["cmake", ".", "-B", "build"], check=True)

with open("build/compile_commands.json", "r") as f:
    compile_commands = json.load(f)

# e.g.: --options-file CMakeFiles/flash_attn_2_cuda.dir/includes_CUDA.rsp
# '-' is included in the class so hyphenated target/dir names also match.
re_options = re.compile(r"--options-file ([A-Za-z0-9/\._-]*)")

for command in compile_commands:
    m = re_options.search(command["command"])
    if m:
        # Response file paths are relative to the build directory.
        with open(os.path.join("build", m.group(1)), "r") as f:
            options = f.read()
        command["command"] = command["command"].replace(m.group(0), options)

with open("compile_commands.json", "w") as f:
    json.dump(compile_commands, f, indent=2)

pyproject.toml

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,66 @@
11
[build-system]
2-
requires = ["setuptools", "wheel", "setuptools-cuda-cpp @ git+https://github.com/nshepperd/setuptools-cuda-cpp", "packaging", "pybind11"]
2+
requires = [
3+
"scikit-build-core>=0.8.0",
4+
"cmake>=3.18",
5+
"ninja>=1.10",
6+
"packaging",
7+
"psutil",
8+
"pybind11>=2.11.0",
9+
# "nvidia-cuda-runtime-cu12>=12.0",
10+
# "nvidia-cuda-nvrtc-cu12",
11+
# "nvidia-nvtx-cu12",
12+
"torch>=2.0.0",
13+
]
14+
build-backend = "scikit_build_core.build"
15+
16+
[project]
17+
name = "flash_attn_jax"
18+
dynamic = ["version"]
19+
description = "Flash Attention: Fast and Memory-Efficient Exact Attention"
20+
readme = "README.md"
21+
requires-python = ">=3.9"
22+
license = { text = "BSD-3-Clause" }
23+
authors = [
24+
{ name = "Tri Dao", email = "[email protected]" },
25+
{ name = "Emily Shepperd", email = "[email protected]" }
26+
]
27+
dependencies = []
28+
classifiers = [
29+
"Programming Language :: Python :: 3",
30+
"License :: OSI Approved :: BSD License",
31+
"Operating System :: Unix",
32+
]
33+
34+
[project.urls]
35+
Homepage = "https://github.com/nshepperd/flash_attn_jax"
36+
37+
[tool.scikit-build]
38+
wheel.expand-macos-universal-tags = false
39+
cmake.version = ">=3.26.1"
40+
ninja.version = ">=1.11"
41+
build.verbose = true
42+
cmake.build-type = "Release"
43+
cmake.args = []
44+
45+
[tool.scikit-build.cmake.define]
46+
SKBUILD = "ON"
47+
FLASH_ATTENTION_FORCE_BUILD = { env = "FLASH_ATTENTION_FORCE_BUILD" }
48+
FLASH_ATTENTION_SKIP_CUDA_BUILD = { env = "FLASH_ATTENTION_SKIP_CUDA_BUILD" }
49+
FLASH_ATTENTION_FORCE_CXX11_ABI = { env = "FLASH_ATTENTION_FORCE_CXX11_ABI" }
50+
FLASH_ATTENTION_TRITON_AMD_ENABLE = { env = "FLASH_ATTENTION_TRITON_AMD_ENABLE" }
51+
FLASH_ATTN_CUDA_ARCHS = { env = "FLASH_ATTN_CUDA_ARCHS" }
52+
CMAKE_VERBOSE_MAKEFILE = "ON"
53+
54+
[tool.scikit-build.metadata.version]
55+
provider = "scikit_build_core.metadata.regex"
56+
input = "src/flash_attn_jax/__init__.py"
357

458
[tool.cibuildwheel]
5-
manylinux-x86_64-image = "sameli/manylinux_2_28_x86_64_cuda_12.3"
59+
manylinux-x86_64-image = "quay.io/pypa/manylinux_2_28_x86_64:latest"
60+
before-all = "bash scripts/install-cuda-linux.sh"
61+
build = "cp312-manylinux_x86_64"
62+
repair-wheel-command = "auditwheel repair --exclude=libcudart.so* --exclude libtorch.so* -w {dest_dir} {wheel}"
63+
64+
[tool.cibuildwheel.environment]
65+
PATH="/opt/rh/gcc-toolset-13/root/usr/bin:/usr/local/cuda/bin:$PATH"
66+
CUDA_HOME="/usr/local/cuda"

scripts/install-cuda-linux.sh

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
#!/bin/bash
# Install the CUDA toolkit (compiler + minimal build packages) and GCC 13
# inside the manylinux_2_28 (EL8-based) cibuildwheel container.
set -eux

VER=${1:-12.4}
VER=${VER//./-} # Convert version to format used in package names (12.4 -> 12-4)

dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo

# Install GCC 13 (and remove any preinstalled GCC 14 toolset so it can't shadow it)

dnf -y install gcc-toolset-13
# Escaped so the wildcard always reaches dnf, even if a file in cwd matches.
dnf -y remove 'gcc-toolset-14-*'
echo ". /opt/rh/gcc-toolset-13/enable" > /etc/profile.d/gcc.sh
chmod +x /etc/profile.d/gcc.sh

# Create a fake package that Provides: gcc-c++, to stop the cuda metapackages
# from stupidly installing the distro gcc-8.5 over the toolset compiler.

dnf -y install rpm-build

mkdir -p ~/rpmbuild/{SPECS,RPMS,SOURCES}
cd ~/rpmbuild
cat > SPECS/gcc-dummy.spec <<EOF
Name: gcc-dummy
Version: 13
Release: 1%{?dist}
Summary: Dummy package to provide gcc-c++
License: MIT
BuildArch: noarch
Provides: gcc-c++ = 13

%description
Dummy package that provides gcc-c++ capabilities without actual compiler

%files

%changelog
* Wed Feb 12 2025 User <[email protected]> - 13-1
- Initial package
EOF
rpmbuild -bb SPECS/gcc-dummy.spec
rpm -ivh ~/rpmbuild/RPMS/noarch/gcc-dummy*.rpm --nodeps

# Install CUDA

dnf -y install \
    cuda-compiler-${VER} \
    cuda-minimal-build-${VER} \
    cuda-nvtx-${VER} \
    cuda-nvrtc-devel-${VER}

# cuda-libraries-devel-${VER} \
# dnf clean all

0 commit comments

Comments
 (0)