Skip to content

Commit bd9476e

Browse files
committed
2 parents 0576b36 + da8b8a4 commit bd9476e

File tree

2 files changed

+73
-3
lines changed

2 files changed

+73
-3
lines changed

.github/workflows/publish.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ jobs:
4545
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
4646
os: [ubuntu-22.04, ubuntu-22.04-arm]
4747
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
48-
torch-version: ["2.4.0", "2.5.1", "2.6.0", "2.7.1", "2.8.0"]
48+
torch-version: ["2.5.1", "2.6.0", "2.7.1", "2.8.0"]
4949
cuda-version: ["12.9.1"]
5050
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
5151
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
@@ -60,6 +60,9 @@ jobs:
6060
# Pytorch < 2.5 does not support Python 3.13
6161
- torch-version: "2.4.0"
6262
python-version: "3.13"
63+
- torch-version: "2.5.1"
64+
python-version: "3.13"
65+
os: ubuntu-22.04-arm
6366
uses: ./.github/workflows/_build.yml
6467
with:
6568
runs-on: ${{ matrix.os }}

setup.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@
3737

3838
PACKAGE_NAME = "flash_dmattn"
3939

40+
BASE_WHEEL_URL = (
41+
"https://github.com/SmallDoges/flash-dmattn/releases/download/{tag_name}/{wheel_name}"
42+
)
43+
4044
# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
4145
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
4246
# Also useful when user only wants Triton/Flex backends without CUDA compilation
@@ -307,6 +311,67 @@ def get_package_version():
307311
return str(public_version)
308312

309313

314+
def get_wheel_url():
    """Build the expected download URL and filename for a prebuilt wheel.

    The wheel tag encodes the CUDA major version, the torch major.minor
    version, the C++11 ABI flag, the CPython version and the platform,
    matching the artifact names published on the GitHub release page.

    Returns:
        tuple[str, str]: (full release-asset URL, bare wheel filename).
    """
    torch_ver = parse(torch.__version__)
    py_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
    plat_tag = get_platform()
    pkg_ver = get_package_version()
    torch_tag = f"{torch_ver.major}.{torch_ver.minor}"
    abi_flag = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    # Determine the version numbers that will be used to determine the correct wheel
    # We're using the CUDA version used to build torch, not the one currently installed
    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
    cuda_ver = parse(torch.version.cuda)
    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
    # to save CI time. Minor versions should be compatible.
    cuda_ver = parse("11.8") if cuda_ver.major == 11 else parse("12.3")
    # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
    cuda_tag = f"{cuda_ver.major}"

    # Determine wheel URL based on CUDA version, torch version, python version and OS
    wheel_filename = (
        f"{PACKAGE_NAME}-{pkg_ver}+cu{cuda_tag}torch{torch_tag}"
        f"cxx11abi{abi_flag}-{py_tag}-{py_tag}-{plat_tag}.whl"
    )

    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{pkg_ver}", wheel_name=wheel_filename)

    return wheel_url, wheel_filename
340+
class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist_wheel, which is run by pip when it cannot
    find an existing wheel (which is currently the case for all flash attention installs). We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuits the standard full build pipeline.
    """

    def run(self):
        """Try to download a matching prebuilt wheel; fall back to building from source."""
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)

            # Make the archive
            # Lifted from the root wheel processing command
            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
            # exist_ok avoids the check-then-create race of the original
            # `if not os.path.exists(...)` dance.
            os.makedirs(self.dist_dir, exist_ok=True)

            impl_tag, abi_tag, plat_tag = self.get_tag()
            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
            print("Raw wheel path", wheel_path)
            # os.replace overwrites a stale wheel left by a previous run;
            # unlike os.rename, it also succeeds on Windows when the
            # destination file already exists.
            os.replace(wheel_filename, wheel_path)
        except (urllib.error.HTTPError, urllib.error.URLError):
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()
310375
class NinjaBuildExtension(BuildExtension):
311376
def __init__(self, *args, **kwargs) -> None:
312377
# do not override env MAX_JOBS if already exists
@@ -329,7 +394,9 @@ def __init__(self, *args, **kwargs) -> None:
329394

330395
# The cached-wheel bdist_wheel is registered unconditionally; the
# Ninja-backed build_ext is only needed when there are extensions to compile.
_cmdclass = {"bdist_wheel": CachedWheelsCommand}
if ext_modules:
    _cmdclass["build_ext"] = NinjaBuildExtension

setup(
    ext_modules=ext_modules,
    cmdclass=_cmdclass,
)

0 commit comments

Comments
 (0)