Skip to content

Commit 206598a

Browse files
committed
Adds configurable arch input for targeted builds
Enables building for a single compute capability to reduce build time when targeting specific GPU architectures. Updates wheel naming convention to include arch identifier when specified, ensuring proper artifact identification for architecture-specific builds.
1 parent 9609d4f commit 206598a

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

.github/workflows/_build.yml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ on:
3232
description: "Upload wheel to this release"
3333
required: false
3434
type: string
35+
arch:
36+
description: "Target a single compute capability. Leave empty to build default archs"
37+
required: false
38+
type: string
3539

3640
defaults:
3741
run:
@@ -59,6 +63,7 @@ jobs:
5963
echo "MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
6064
echo "WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
6165
echo "MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
66+
echo "MATRIX_ARCH=${{ inputs.arch }}" >> $GITHUB_ENV
6267
6368
- name: Free up disk space
6469
if: ${{ runner.os == 'Linux' }}
@@ -170,12 +175,21 @@ jobs:
170175
export FLASH_DMATTN_FORCE_BUILD="TRUE"
171176
export FLASH_DMATTN_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }}
172177
178+
# If specified, limit to a single compute capability to speed up build
179+
if [ -n "${MATRIX_ARCH}" ]; then
180+
export FLASH_DMATTN_CUDA_ARCHS="${MATRIX_ARCH}"
181+
fi
182+
173183
# 5h timeout since GH allows max 6h and we want some buffer
174184
EXIT_CODE=0
175185
timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$?
176186
177187
if [ $EXIT_CODE -eq 0 ]; then
178-
tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }}
188+
if [ -n "${MATRIX_ARCH}" ]; then
189+
tmpname=sm${MATRIX_ARCH}cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }}
190+
else
191+
tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }}
192+
fi
179193
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
180194
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
181195
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV

0 commit comments

Comments
 (0)