|
1 | | -# Copyright (c) 2025, Jingze Shi. |
2 | | - |
3 | | -import sys |
4 | | -import functools |
5 | | -import warnings |
6 | | -import os |
7 | | -import re |
8 | | -import ast |
9 | | -import glob |
10 | | -from pathlib import Path |
11 | | -from packaging.version import parse, Version |
12 | | -import platform |
13 | | -from typing import Optional |
14 | | - |
15 | 1 | from setuptools import setup |
16 | | -import subprocess |
17 | | - |
18 | | -import urllib.request |
19 | | -import urllib.error |
20 | | -from wheel.bdist_wheel import bdist_wheel as _bdist_wheel |
21 | | - |
22 | | -import torch |
23 | | -from torch.utils.cpp_extension import ( |
24 | | - BuildExtension, |
25 | | - CUDAExtension, |
26 | | - CUDA_HOME, |
27 | | -) |
28 | | - |
29 | | - |
# Long description for PyPI is taken verbatim from the README.
# NOTE(review): long_description is not passed to setup() below -- presumably
# package metadata lives in pyproject.toml/setup.cfg; confirm before removing.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()


# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

PACKAGE_NAME = "flash_sparse_attn"

# URL template for prebuilt release wheels; {tag_name} and {wheel_name} are
# filled in by get_wheel_url().
BASE_WHEEL_URL = (
    "https://github.com/flash-algo/flash-sparse-attention/releases/download/{tag_name}/{wheel_name}"
)

# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
# Also useful when user only wants Triton/Flex backends without CUDA compilation
FORCE_BUILD = os.getenv("FLASH_SPARSE_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("FLASH_SPARSE_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("FLASH_SPARSE_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"
# Auto-detect if user wants only Triton/Flex backends based on pip install command
# This helps avoid unnecessary CUDA compilation when user only wants Python backends
def should_skip_cuda_build():
    """Determine if CUDA build should be skipped based on installation context.

    Returns:
        bool: True when the env-var opt-out is set, or the command line looks
        like a Triton/Flex-only installation; False otherwise (including when
        FLASH_SPARSE_ATTENTION_FORCE_BUILD explicitly requests a build).
    """
    if SKIP_CUDA_BUILD:
        return True

    if FORCE_BUILD:
        return False  # User explicitly wants to build, respect that

    # Check command line arguments for installation hints.
    # BUGFIX: match whole lowercase words instead of raw substrings. The old
    # substring test over ' '.join(sys.argv) found "all" inside the "install"
    # command itself (and "dev"/"triton" inside arbitrary paths), which
    # silently defeated the Triton/Flex-only detection.
    if len(sys.argv) > 1:
        argv_words = set()
        for arg in sys.argv[1:]:
            argv_words.update(re.findall(r"[a-z]+", arg.lower()))

        # Check if Triton or Flex extras are requested
        has_triton_or_flex = bool(argv_words & {"triton", "flex"})
        has_all_or_dev = bool(argv_words & {"all", "dev"})

        if has_triton_or_flex and not has_all_or_dev:
            print("Detected Triton/Flex-only installation. Skipping CUDA compilation.")
            print("Set FLASH_SPARSE_ATTENTION_FORCE_BUILD=TRUE to force CUDA compilation.")
            return True

    return False

# Update SKIP_CUDA_BUILD based on auto-detection
SKIP_CUDA_BUILD = should_skip_cuda_build()
79 | | - |
@functools.lru_cache(maxsize=None)
def cuda_archs():
    """Return the target CUDA SM architectures as a list of strings (cached).

    Controlled by FLASH_SPARSE_ATTENTION_CUDA_ARCHS, a ';'-separated list;
    defaults to "80;90;100".
    """
    archs = os.getenv("FLASH_SPARSE_ATTENTION_CUDA_ARCHS", "80;90;100")
    return archs.split(";")
83 | | - |
84 | | - |
def detect_preferred_sm_arch() -> Optional[str]:
    """Detect the preferred SM arch from the current CUDA device.

    Returns a compact string such as "90" for SM 9.0, or None if CUDA is
    unavailable or detection fails for any reason.
    """
    try:
        if not torch.cuda.is_available():
            return None
        device = torch.cuda.current_device()
        capability = torch.cuda.get_device_capability(device)
        return "".join(str(part) for part in capability)
    except Exception:
        # Best-effort probe: any failure simply means "unknown".
        return None
97 | | - |
98 | | - |
def get_platform():
    """Return the platform tag as used in wheel filenames.

    Raises:
        ValueError: on platforms other than Linux, macOS, and Windows.
    """
    if sys.platform == "win32":
        return "win_amd64"
    if sys.platform == "darwin":
        major_minor = ".".join(platform.mac_ver()[0].split(".")[:2])
        return f"macosx_{major_minor}_x86_64"
    if sys.platform.startswith("linux"):
        return f'linux_{platform.uname().machine}'
    raise ValueError("Unsupported platform: {}".format(sys.platform))
112 | | - |
113 | | - |
def get_cuda_bare_metal_version(cuda_dir):
    """Run ``<cuda_dir>/bin/nvcc -V`` and parse the toolkit release version.

    Returns:
        tuple: (raw nvcc output string, parsed packaging Version).
    """
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    tokens = raw_output.split()
    # The version number follows the literal token "release", e.g. "release 12.3,"
    version_token = tokens[tokens.index("release") + 1]
    bare_metal_version = parse(version_token.split(",")[0])

    return raw_output, bare_metal_version
121 | | - |
122 | | - |
def check_if_cuda_home_none(global_option: str) -> None:
    """Emit a warning when CUDA_HOME (and thus nvcc) cannot be located.

    We warn instead of erroring because the user may be installing a prebuilt
    wheel, in which case nvcc is never needed.
    """
    if CUDA_HOME is None:
        warnings.warn(
            f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
            "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
            "only images whose names contain 'devel' will provide nvcc."
        )
133 | | - |
134 | | - |
def append_nvcc_threads(nvcc_extra_args):
    """Return *nvcc_extra_args* extended with a ``--threads N`` option.

    N is taken from the NVCC_THREADS env var (empty counts as unset),
    defaulting to 4. The input list is not mutated.
    """
    thread_count = os.getenv("NVCC_THREADS") or "4"
    return [*nvcc_extra_args, "--threads", thread_count]
138 | | - |
139 | | - |
# Collected below and handed to setup(); populated only for CUDA builds.
cmdclass = {}
ext_modules = []

# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
if os.path.isdir(".git"):
    # Git checkout: fetch/refresh the CUTLASS submodule in place.
    subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"], check=True)
else:
    # sdist/tarball: the CUTLASS headers must already be vendored alongside.
    assert (
        os.path.exists("csrc/cutlass/include/cutlass/cutlass.h")
    ), "csrc/cutlass is missing, please use source distribution or git clone"
151 | | - |
# Configure the native CUDA extension only when a CUDA build is wanted.
if not SKIP_CUDA_BUILD:
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    # NOTE(review): TORCH_MAJOR/TORCH_MINOR are computed but unused below --
    # presumably kept for future torch-version gating; confirm before removing.
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    check_if_cuda_home_none("flash_sparse_attn")
    # Check, if CUDA11 is installed for compute capability 8.0
    cc_flag = []
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.7"):
            raise RuntimeError(
                "Flash Sparse Attention is only supported on CUDA 11.7 and above. "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )

    # SM 8.0 is the baseline target; no toolkit gate beyond the 11.7 floor.
    if "80" in cuda_archs():
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_80,code=sm_80")

    # Newer architectures need newer toolkits. bare_metal_version is only
    # bound when CUDA_HOME was found, hence the repeated guard.
    if CUDA_HOME is not None:
        if bare_metal_version >= Version("11.8") and "86" in cuda_archs():
            cc_flag.append("-gencode")
            cc_flag.append("arch=compute_86,code=sm_86")
        if bare_metal_version >= Version("11.8") and "89" in cuda_archs():
            cc_flag.append("-gencode")
            cc_flag.append("arch=compute_89,code=sm_89")
        if bare_metal_version >= Version("11.8") and "90" in cuda_archs():
            cc_flag.append("-gencode")
            cc_flag.append("arch=compute_90,code=sm_90")
        if bare_metal_version >= Version("12.8") and "100" in cuda_archs():
            cc_flag.append("-gencode")
            cc_flag.append("arch=compute_100,code=sm_100")
        if bare_metal_version >= Version("12.8") and "120" in cuda_archs():
            cc_flag.append("-gencode")
            cc_flag.append("arch=compute_120,code=sm_120")

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    nvcc_flags = [
        "-O3",
        "-std=c++17",
        # Re-enable half/bfloat16 operators that torch headers suppress by default.
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
        # "--ptxas-options=-v",
        # "--ptxas-options=-O2",
        # "-lineinfo",
        # "-DFLASHATTENTION_DISABLE_BACKWARD",
        # "-DFLASHATTENTION_DISABLE_SOFTCAP",
        # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
    ]

    # Host (C++) compiler flags; replaced below when building with MSVC.
    compiler_c17_flag=["-O3", "-std=c++17"]
    # Add Windows-specific flags
    if sys.platform == "win32" and os.getenv('DISTUTILS_USE_SDK') == '1':
        nvcc_flags.extend(["-Xcompiler", "/Zc:__cplusplus"])
        compiler_c17_flag=["-O2", "/std:c++17", "/Zc:__cplusplus"]

    ext_modules.append(
        CUDAExtension(
            name="flash_sparse_attn_cuda",
            sources=(
                [
                    "csrc/flash_sparse_attn/flash_api.cpp",
                ]
                # sorted() keeps compilation order deterministic across filesystems.
                + sorted(glob.glob("csrc/flash_sparse_attn/src/instantiations/flash_*.cu"))
            ),
            extra_compile_args={
                "cxx": compiler_c17_flag,
                "nvcc": append_nvcc_threads(nvcc_flags + cc_flag),
            },
            # Absolute include paths: ninja builds fail with relative ones.
            include_dirs=[
                Path(this_dir) / "csrc" / "flash_sparse_attn",
                Path(this_dir) / "csrc" / "flash_sparse_attn" / "src",
                Path(this_dir) / "csrc" / "cutlass" / "include",
            ],
        )
    )
239 | | - |
240 | | - |
def get_package_version():
    """Read ``__version__`` out of flash_sparse_attn/__init__.py.

    When FLASH_SPARSE_ATTENTION_LOCAL_VERSION is set, it is appended as a
    PEP 440 local version ("+<local>").
    """
    init_path = Path(this_dir) / "flash_sparse_attn" / "__init__.py"
    with open(init_path, "r") as init_file:
        match = re.search(r"^__version__\s*=\s*(.*)$", init_file.read(), re.MULTILINE)
    public_version = ast.literal_eval(match.group(1))
    local_version = os.environ.get("FLASH_SPARSE_ATTENTION_LOCAL_VERSION")
    if not local_version:
        return str(public_version)
    return f"{public_version}+{local_version}"
250 | | - |
251 | | - |
def get_wheel_url():
    """Compute the (url, filename) pair for the prebuilt wheel matching this env.

    The wheel name encodes the detected SM arch, the CUDA major version torch
    was built against, torch major.minor, the C++11 ABI flag, the CPython tag
    and the platform tag.
    """
    sm_arch = detect_preferred_sm_arch()
    flash_version = get_package_version()
    platform_name = get_platform()
    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"

    torch_version_raw = parse(torch.__version__)
    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    # Determine the version numbers that will be used to determine the correct wheel
    # We're using the CUDA version used to build torch, not the one currently installed
    # NOTE(review): torch.version.cuda is None on CPU-only torch builds, which
    # would make parse() raise here -- confirm this path only runs on CUDA torch.
    torch_cuda_version = parse(torch.version.cuda)
    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
    # to save CI time. Minor versions should be compatible.
    torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
    cuda_version = f"{torch_cuda_version.major}"

    # Determine wheel URL based on CUDA version, torch version, python version and OS
    wheel_filename = (
        f"{PACKAGE_NAME}-{flash_version}+sm{sm_arch}cu{cuda_version}"
        f"torch{torch_version}cxx11abi{cxx11_abi}"
        f"-{python_version}-{python_version}-{platform_name}.whl"
    )
    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
    return wheel_url, wheel_filename
277 | | - |
278 | | - |
class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
    find an existing wheel (which is currently the case for all flash attention installs). We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuits the standard full build pipeline.
    """

    def run(self):
        # Honor the explicit opt-out of prebuilt wheels.
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            # Download the candidate wheel into the working directory.
            urllib.request.urlretrieve(wheel_url, wheel_filename)

            # Make the archive
            # Lifted from the root wheel processing command
            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
            if not os.path.exists(self.dist_dir):
                os.makedirs(self.dist_dir)

            # Rename the download to the exact tag this build would produce,
            # so pip treats it as the build product.
            impl_tag, abi_tag, plat_tag = self.get_tag()
            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
            print("Raw wheel path", wheel_path)
            os.rename(wheel_filename, wheel_path)
        except (urllib.error.HTTPError, urllib.error.URLError):
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()
312 | | - |
313 | | - |
class NinjaBuildExtension(BuildExtension):
    """BuildExtension that caps parallel ninja jobs to avoid OOM.

    With NVCC_THREADS=4 each compile job peaks at roughly 8-9 GB of RAM, so
    MAX_JOBS is derived from both core count and available memory unless the
    user already set it.
    """

    def __init__(self, *args, **kwargs) -> None:
        # do not override env MAX_JOBS if already exists
        if not os.environ.get("MAX_JOBS"):
            # calculate the maximum allowed NUM_JOBS based on cores
            max_num_jobs_cores = max(1, (os.cpu_count() or 1) // 2)

            # calculate the maximum allowed NUM_JOBS based on free memory;
            # each JOB peak memory cost is ~8-9GB when threads = 4.
            # ROBUSTNESS: psutil is an optional dependency -- fall back to the
            # core-based cap instead of crashing the whole build when absent.
            try:
                import psutil

                free_memory_gb = psutil.virtual_memory().available / (1024 ** 3)  # free memory in GB
                max_num_jobs_memory = int(free_memory_gb / 9)
            except ImportError:
                max_num_jobs_memory = max_num_jobs_cores

            # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation
            max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
            os.environ["MAX_JOBS"] = str(max_jobs)

        super().__init__(*args, **kwargs)
332 | | - |
333 | 2 |
|
# The cached-wheel command is always registered; the ninja-aware build_ext is
# only needed when there are native extensions to compile.
if ext_modules:
    command_overrides = {
        "bdist_wheel": CachedWheelsCommand,
        "build_ext": NinjaBuildExtension,
    }
else:
    command_overrides = {
        "bdist_wheel": CachedWheelsCommand,
    }

setup(
    ext_modules=ext_modules,
    cmdclass=command_overrides,
)
| 3 | +setup() |
0 commit comments