Skip to content

Commit 68a0a08

Browse files
committed
Remove unused code and dependencies from setup.py
1 parent 3caf64c commit 68a0a08

File tree

1 file changed

+1
-339
lines changed

1 file changed

+1
-339
lines changed

setup.py

Lines changed: 1 addition & 339 deletions
Original file line number | Diff line number | Diff line change
@@ -1,341 +1,3 @@
1-
# Copyright (c) 2025, Jingze Shi.
2-
3-
import sys
4-
import functools
5-
import warnings
6-
import os
7-
import re
8-
import ast
9-
import glob
10-
from pathlib import Path
11-
from packaging.version import parse, Version
12-
import platform
13-
from typing import Optional
14-
151
from setuptools import setup
16-
import subprocess
17-
18-
import urllib.request
19-
import urllib.error
20-
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
21-
22-
import torch
23-
from torch.utils.cpp_extension import (
24-
BuildExtension,
25-
CUDAExtension,
26-
CUDA_HOME,
27-
)
28-
29-
30-
# PyPI long description is the README, verbatim.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

PACKAGE_NAME = "flash_sparse_attn"

BASE_WHEEL_URL = (
    "https://github.com/flash-algo/flash-sparse-attention/releases/download/{tag_name}/{wheel_name}"
)


def _env_flag(name):
    # An env toggle counts as enabled only when set to the literal "TRUE".
    return os.getenv(name, "FALSE") == "TRUE"


# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
# Also useful when user only wants Triton/Flex backends without CUDA compilation
FORCE_BUILD = _env_flag("FLASH_SPARSE_ATTENTION_FORCE_BUILD")
SKIP_CUDA_BUILD = _env_flag("FLASH_SPARSE_ATTENTION_SKIP_CUDA_BUILD")
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = _env_flag("FLASH_SPARSE_ATTENTION_FORCE_CXX11_ABI")
# Auto-detect if user wants only Triton/Flex backends based on pip install command
# This helps avoid unnecessary CUDA compilation when user only wants Python backends
def should_skip_cuda_build():
    """Determine if CUDA build should be skipped based on installation context.

    Returns True when the environment explicitly requests skipping, or when the
    command line suggests a Triton/Flex-only install (and no 'all'/'dev' extra).
    FORCE_BUILD always wins and forces a build.
    """
    if SKIP_CUDA_BUILD:
        return True

    if FORCE_BUILD:
        return False  # User explicitly wants to build, respect that

    # Check command line arguments for installation hints.
    # BUGFIX: the previous substring test ('all' in args) matched the word
    # "install" itself, and scanning argv[0] let the script *path* (e.g. a
    # directory named "dev") trigger a match. Use whole-word matches over
    # argv[1:] only.
    install_args = " ".join(sys.argv[1:])
    if install_args:
        # Check if Triton or Flex extras are requested
        has_triton_or_flex = re.search(r"\b(?:triton|flex)\b", install_args) is not None
        has_all_or_dev = re.search(r"\b(?:all|dev)\b", install_args) is not None

        if has_triton_or_flex and not has_all_or_dev:
            print("Detected Triton/Flex-only installation. Skipping CUDA compilation.")
            print("Set FLASH_SPARSE_ATTENTION_FORCE_BUILD=TRUE to force CUDA compilation.")
            return True

    return False

# Update SKIP_CUDA_BUILD based on auto-detection
SKIP_CUDA_BUILD = should_skip_cuda_build()
80-
@functools.lru_cache(maxsize=None)
def cuda_archs():
    """Return the list of target CUDA SM architectures (cached after first call)."""
    archs = os.getenv("FLASH_SPARSE_ATTENTION_CUDA_ARCHS", "80;90;100")
    return archs.split(";")
85-
def detect_preferred_sm_arch() -> Optional[str]:
    """Detect the preferred SM arch from the current CUDA device.
    Returns None if CUDA is unavailable or detection fails.
    """
    try:
        if not torch.cuda.is_available():
            return None
        capability = torch.cuda.get_device_capability(torch.cuda.current_device())
        return "{}{}".format(*capability)
    except Exception:
        # Any detection failure degrades gracefully to "unknown".
        return None
99-
def get_platform():
    """
    Returns the platform name as used in wheel filenames.
    """
    if sys.platform.startswith("linux"):
        return "linux_{}".format(platform.uname().machine)
    if sys.platform == "darwin":
        # macOS wheels are tagged with the major.minor OS version.
        major_minor = ".".join(platform.mac_ver()[0].split(".")[:2])
        return "macosx_{}_x86_64".format(major_minor)
    if sys.platform == "win32":
        return "win_amd64"
    raise ValueError("Unsupported platform: {}".format(sys.platform))
114-
def get_cuda_bare_metal_version(cuda_dir):
    """Run ``nvcc -V`` from *cuda_dir* and return (raw output, parsed Version).

    Raises subprocess.CalledProcessError if nvcc fails, and ValueError if the
    output contains no "release" token.
    """
    # os.path.join instead of string concatenation so the path is also
    # correct on Windows installs of the CUDA toolkit.
    nvcc = os.path.join(cuda_dir, "bin", "nvcc")
    raw_output = subprocess.check_output([nvcc, "-V"], universal_newlines=True)
    output = raw_output.split()
    # The token right after "release" is e.g. "12.3," -> strip the comma.
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])

    return raw_output, bare_metal_version
123-
def check_if_cuda_home_none(global_option: str) -> None:
    """Emit a warning when CUDA_HOME/nvcc is not available; no-op otherwise."""
    if CUDA_HOME is None:
        # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
        # in that case.
        warnings.warn(
            f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
            "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
            "only images whose names contain 'devel' will provide nvcc."
        )
135-
def append_nvcc_threads(nvcc_extra_args):
    """Return a new list: *nvcc_extra_args* plus the nvcc --threads option.

    Thread count comes from the NVCC_THREADS env var, defaulting to "4".
    """
    threads = os.getenv("NVCC_THREADS") or "4"
    return [*nvcc_extra_args, "--threads", threads]
140-
cmdclass = {}
ext_modules = []

# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
if os.path.isdir(".git"):
    subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"], check=True)
else:
    assert os.path.exists(
        "csrc/cutlass/include/cutlass/cutlass.h"
    ), "csrc/cutlass is missing, please use source distribution or git clone"
152-
if not SKIP_CUDA_BUILD:
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    # NOTE(review): these two are not referenced below in this block — kept for parity.
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    check_if_cuda_home_none("flash_sparse_attn")
    # Check, if CUDA11 is installed for compute capability 8.0
    cc_flag = []
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.7"):
            raise RuntimeError(
                "Flash Sparse Attention is only supported on CUDA 11.7 and above. "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )

    # SM80 is not gated on the toolkit version.
    if "80" in cuda_archs():
        cc_flag += ["-gencode", "arch=compute_80,code=sm_80"]

    if CUDA_HOME is not None:
        # (minimum nvcc version, SM arch) pairs, checked in this exact order.
        for min_cuda, arch in (
            ("11.8", "86"),
            ("11.8", "89"),
            ("11.8", "90"),
            ("12.8", "100"),
            ("12.8", "120"),
        ):
            if bare_metal_version >= Version(min_cuda) and arch in cuda_archs():
                cc_flag += ["-gencode", f"arch=compute_{arch},code=sm_{arch}"]

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    nvcc_flags = [
        "-O3",
        "-std=c++17",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
        # "--ptxas-options=-v",
        # "--ptxas-options=-O2",
        # "-lineinfo",
        # "-DFLASHATTENTION_DISABLE_BACKWARD",
        # "-DFLASHATTENTION_DISABLE_SOFTCAP",
        # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
    ]

    compiler_c17_flag = ["-O3", "-std=c++17"]
    # Add Windows-specific flags
    if sys.platform == "win32" and os.getenv("DISTUTILS_USE_SDK") == "1":
        nvcc_flags += ["-Xcompiler", "/Zc:__cplusplus"]
        compiler_c17_flag = ["-O2", "/std:c++17", "/Zc:__cplusplus"]

    ext_modules.append(
        CUDAExtension(
            name="flash_sparse_attn_cuda",
            sources=(
                ["csrc/flash_sparse_attn/flash_api.cpp"]
                + sorted(glob.glob("csrc/flash_sparse_attn/src/instantiations/flash_*.cu"))
            ),
            extra_compile_args={
                "cxx": compiler_c17_flag,
                "nvcc": append_nvcc_threads(nvcc_flags + cc_flag),
            },
            include_dirs=[
                Path(this_dir) / "csrc" / "flash_sparse_attn",
                Path(this_dir) / "csrc" / "flash_sparse_attn" / "src",
                Path(this_dir) / "csrc" / "cutlass" / "include",
            ],
        )
    )
241-
def get_package_version():
    """Read __version__ from the package __init__.py; append an optional local suffix.

    The local suffix comes from FLASH_SPARSE_ATTENTION_LOCAL_VERSION.
    """
    init_path = Path(this_dir) / "flash_sparse_attn" / "__init__.py"
    with open(init_path, "r") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    # literal_eval strips the quoting from the assigned string literal.
    public_version = ast.literal_eval(version_match.group(1))
    local_version = os.environ.get("FLASH_SPARSE_ATTENTION_LOCAL_VERSION")
    if local_version:
        return f"{public_version}+{local_version}"
    return str(public_version)
252-
def get_wheel_url():
    """Compose the (url, filename) of the prebuilt wheel matching this environment."""
    sm_arch = detect_preferred_sm_arch()
    torch_version_raw = parse(torch.__version__)
    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = get_platform()
    flash_version = get_package_version()
    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    # Determine the version numbers that will be used to determine the correct wheel
    # We're using the CUDA version used to build torch, not the one currently installed
    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
    torch_cuda_version = parse(torch.version.cuda)
    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
    # to save CI time. Minor versions should be compatible.
    torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
    # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
    cuda_version = f"{torch_cuda_version.major}"

    # Determine wheel URL based on CUDA version, torch version, python version and OS
    wheel_filename = (
        f"{PACKAGE_NAME}-{flash_version}+sm{sm_arch}cu{cuda_version}"
        f"torch{torch_version}cxx11abi{cxx11_abi}"
        f"-{python_version}-{python_version}-{platform_name}.whl"
    )
    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
    return wheel_url, wheel_filename
279-
class CachedWheelsCommand(_bdist_wheel):
    """
    Plugs into the default bdist_wheel command, which pip runs when it cannot
    find an existing wheel. Tries to download a compatible pre-built wheel
    first, and only falls back to the full source build when none is found.
    """

    def run(self):
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)

            # Make the archive
            # Lifted from the root wheel processing command
            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
            if not os.path.exists(self.dist_dir):
                os.makedirs(self.dist_dir)

            impl_tag, abi_tag, plat_tag = self.get_tag()
            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
            print("Raw wheel path", wheel_path)
            os.rename(wheel_filename, wheel_path)
        except (urllib.error.HTTPError, urllib.error.URLError):
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()
314-
class NinjaBuildExtension(BuildExtension):
    """BuildExtension that auto-sizes MAX_JOBS from CPU count and free memory."""

    def __init__(self, *args, **kwargs) -> None:
        # do not override env MAX_JOBS if already exists
        if not os.environ.get("MAX_JOBS"):
            import psutil

            # cap by cores: use half the logical CPUs, at least one job
            cores_cap = max(1, (os.cpu_count() or 1) // 2)

            # cap by free memory: each job peaks at ~8-9GB when nvcc threads = 4
            free_memory_gb = psutil.virtual_memory().available / (1024 ** 3)
            memory_cap = int(free_memory_gb / 9)

            # take the tighter of the two caps to minimize OOM/swap during compilation
            os.environ["MAX_JOBS"] = str(max(1, min(cores_cap, memory_cap)))

        super().__init__(*args, **kwargs)
3332

334-
# Wheel caching and ninja-aware build only apply when a CUDA extension is built.
if ext_modules:
    _cmdclass = {"bdist_wheel": CachedWheelsCommand, "build_ext": NinjaBuildExtension}
else:
    _cmdclass = {"bdist_wheel": CachedWheelsCommand}

setup(
    ext_modules=ext_modules,
    cmdclass=_cmdclass,
)

0 commit comments

Comments (0)