Skip to content

Commit c60a014

Browse files
authored
Merge pull request #171 from SmallDoges/auto-workflow
Add support for targeted GPU architecture builds
2 parents 0594911 + 8a8a456 commit c60a014

File tree

4 files changed

+40
-5
lines changed

4 files changed

+40
-5
lines changed

.github/workflows/_build.yml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ on:
3232
description: "Upload wheel to this release"
3333
required: false
3434
type: string
35+
arch:
36+
description: "Target a single compute capability. Leave empty to build default archs"
37+
required: false
38+
type: string
3539

3640
defaults:
3741
run:
@@ -59,6 +63,7 @@ jobs:
5963
echo "MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
6064
echo "WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
6165
echo "MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
66+
echo "MATRIX_ARCH=${{ inputs.arch }}" >> $GITHUB_ENV
6267
6368
- name: Free up disk space
6469
if: ${{ runner.os == 'Linux' }}
@@ -170,12 +175,21 @@ jobs:
170175
export FLASH_DMATTN_FORCE_BUILD="TRUE"
171176
export FLASH_DMATTN_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }}
172177
178+
# If specified, limit to a single compute capability to speed up build
179+
if [ -n "${MATRIX_ARCH}" ]; then
180+
export FLASH_DMATTN_CUDA_ARCHS="${MATRIX_ARCH}"
181+
fi
182+
173183
# 5h timeout since GH allows max 6h and we want some buffer
174184
EXIT_CODE=0
175185
timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$?
176186
177187
if [ $EXIT_CODE -eq 0 ]; then
178-
tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }}
188+
if [ -n "${MATRIX_ARCH}" ]; then
189+
tmpname=sm${MATRIX_ARCH}cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }}
190+
else
191+
tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }}
192+
fi
179193
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
180194
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
181195
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV

.github/workflows/build.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ on:
3333
description: "Upload wheel to this release"
3434
required: false
3535
type: string
36+
arch:
37+
description: "Target a single compute capability. Leave empty to use project default"
38+
required: false
39+
type: string
3640

3741
jobs:
3842
build-wheels:
@@ -45,3 +49,4 @@ jobs:
4549
cxx11_abi: ${{ inputs.cxx11_abi }}
4650
upload-to-release: ${{ inputs.upload-to-release }}
4751
release-version: ${{ inputs.release-version }}
52+
arch: ${{ inputs.arch }}

.github/workflows/publish.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ jobs:
5252
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
5353
# when building without C++11 ABI and using it on nvcr images.
5454
cxx11_abi: ["FALSE", "TRUE"]
55+
arch: ["80", "86", "89", "90", "100", "120"]
5556
include:
5657
- torch-version: "2.9.0.dev20250904"
5758
cuda-version: "13.0"
@@ -70,6 +71,7 @@ jobs:
7071
cuda-version: ${{ matrix.cuda-version }}
7172
torch-version: ${{ matrix.torch-version }}
7273
cxx11_abi: ${{ matrix.cxx11_abi }}
74+
arch: ${{ matrix.arch }}
7375
release-version: ${{ needs.setup_release.outputs.release-version }}
7476
upload-to-release: true
7577

setup.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
import re
88
import ast
99
import glob
10-
import shutil
1110
from pathlib import Path
1211
from packaging.version import parse, Version
1312
import platform
13+
from typing import Optional
1414

15-
from setuptools import setup, find_packages
15+
from setuptools import setup
1616
import subprocess
1717

1818
import urllib.request
@@ -22,7 +22,6 @@
2222
import torch
2323
from torch.utils.cpp_extension import (
2424
BuildExtension,
25-
CppExtension,
2625
CUDAExtension,
2726
CUDA_HOME,
2827
)
@@ -83,6 +82,20 @@ def cuda_archs():
8382
return os.getenv("FLASH_DMATTN_CUDA_ARCHS", "80;86;89;90;100;120").split(";")
8483

8584

85+
def detect_preferred_sm_arch() -> Optional[str]:
    """Return the SM compute capability of the current CUDA device, e.g. "90".

    Probes the active device via torch; yields None when CUDA is
    unavailable or any step of the probe raises.
    """
    try:
        if not torch.cuda.is_available():
            return None
        major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
        return f"{major}{minor}"
    except Exception:
        # Detection is best-effort only; any failure means "no preference".
        return None
97+
98+
8699
def get_platform():
87100
"""
88101
Returns the platform name as used in wheel filenames.
@@ -237,6 +250,7 @@ def get_package_version():
237250

238251

239252
def get_wheel_url():
253+
sm_arch = detect_preferred_sm_arch()
240254
torch_version_raw = parse(torch.__version__)
241255
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
242256
platform_name = get_platform()
@@ -255,7 +269,7 @@ def get_wheel_url():
255269
cuda_version = f"{torch_cuda_version.major}"
256270

257271
# Determine wheel URL based on CUDA version, torch version, python version and OS
258-
wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
272+
wheel_filename = f"{PACKAGE_NAME}-{flash_version}+sm{sm_arch}cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
259273

260274
wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
261275

0 commit comments

Comments
 (0)