Merged
16 changes: 15 additions & 1 deletion .github/workflows/_build.yml
@@ -32,6 +32,10 @@ on:
description: "Upload wheel to this release"
required: false
type: string
arch:
description: "Target a single compute capability. Leave empty to build default archs"
required: false
type: string

defaults:
run:
@@ -59,6 +63,7 @@ jobs:
echo "MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
echo "WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
echo "MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
echo "MATRIX_ARCH=${{ inputs.arch }}" >> $GITHUB_ENV

- name: Free up disk space
if: ${{ runner.os == 'Linux' }}
@@ -170,12 +175,21 @@ jobs:
export FLASH_DMATTN_FORCE_BUILD="TRUE"
export FLASH_DMATTN_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }}

# If specified, limit to a single compute capability to speed up build
if [ -n "${MATRIX_ARCH}" ]; then
export FLASH_DMATTN_CUDA_ARCHS="${MATRIX_ARCH}"
fi

# 5h timeout since GH allows max 6h and we want some buffer
EXIT_CODE=0
timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$?

if [ $EXIT_CODE -eq 0 ]; then
tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }}
if [ -n "${MATRIX_ARCH}" ]; then
tmpname=sm${MATRIX_ARCH}cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }}
else
tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }}
fi
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
Comment on lines +188 to 193

Copilot AI Sep 20, 2025


[nitpick] The wheel naming logic is duplicated between setup.py and the workflow file. Consider extracting this logic to a shared function or script to avoid inconsistencies and reduce maintenance burden.

Suggested change
if [ -n "${MATRIX_ARCH}" ]; then
tmpname=sm${MATRIX_ARCH}cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }}
else
tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }}
fi
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
# Use shared Python script to generate wheel name suffix
wheel_suffix=$(python scripts/wheel_name.py \
--arch "${MATRIX_ARCH}" \
--cuda-version "${WHEEL_CUDA_VERSION}" \
--torch-version "${MATRIX_TORCH_VERSION}" \
--cxx11-abi "${{ inputs.cxx11_abi }}")
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+${wheel_suffix}-/2")

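The review suggestion refers to `scripts/wheel_name.py`, which is not part of this diff. As a rough sketch only, the shared logic might look like the following. The function names, the argument order, and the `example_pkg` filename are all hypothetical; the suffix format and the "second hyphen" rule are taken from the workflow lines shown above.

```python
from typing import Optional


def wheel_local_suffix(cuda_major: str, torch_version: str, cxx11_abi: str,
                       arch: Optional[str] = None) -> str:
    """Build the suffix the workflow calls `tmpname` (hypothetical helper)."""
    base = f"cu{cuda_major}torch{torch_version}cxx11abi{cxx11_abi}"
    # An empty or missing arch keeps the pre-existing suffix unchanged,
    # mirroring the `if [ -n "${MATRIX_ARCH}" ]` branch in the workflow.
    return f"sm{arch}{base}" if arch else base


def tag_wheel_filename(filename: str, suffix: str) -> str:
    """Insert '+<suffix>' before the second hyphen, like sed 's/-/+suffix-/2'."""
    parts = filename.split("-", 2)
    if len(parts) < 3:
        raise ValueError(f"unexpected wheel filename: {filename!r}")
    name, version, rest = parts
    return f"{name}-{version}+{suffix}-{rest}"


if __name__ == "__main__":
    suffix = wheel_local_suffix("12", "2.8", "TRUE", arch="90")
    print(tag_wheel_filename("example_pkg-1.0.0-cp310-cp310-linux_x86_64.whl", suffix))
```

Keeping both the suffix construction and the rename in one place would let `setup.py` and the workflow call the same code, which is the inconsistency the comment is concerned about.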
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
5 changes: 5 additions & 0 deletions .github/workflows/build.yml
@@ -33,6 +33,10 @@ on:
description: "Upload wheel to this release"
required: false
type: string
arch:
description: "Target a single compute capability. Leave empty to use project default"
required: false
type: string

jobs:
build-wheels:
@@ -45,3 +49,4 @@ jobs:
cxx11_abi: ${{ inputs.cxx11_abi }}
upload-to-release: ${{ inputs.upload-to-release }}
release-version: ${{ inputs.release-version }}
arch: ${{ inputs.arch }}
2 changes: 2 additions & 0 deletions .github/workflows/publish.yml
@@ -52,6 +52,7 @@ jobs:
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
# when building without C++11 ABI and using it on nvcr images.
cxx11_abi: ["FALSE", "TRUE"]
arch: ["80", "86", "89", "90", "100", "120"]
include:
- torch-version: "2.9.0.dev20250904"
cuda-version: "13.0"
@@ -70,6 +71,7 @@ jobs:
cuda-version: ${{ matrix.cuda-version }}
torch-version: ${{ matrix.torch-version }}
cxx11_abi: ${{ matrix.cxx11_abi }}
arch: ${{ matrix.arch }}
release-version: ${{ needs.setup_release.outputs.release-version }}
upload-to-release: true

22 changes: 18 additions & 4 deletions setup.py
@@ -7,12 +7,12 @@
import re
import ast
import glob
import shutil
from pathlib import Path
from packaging.version import parse, Version
import platform
from typing import Optional

from setuptools import setup, find_packages
from setuptools import setup
import subprocess

import urllib.request
@@ -22,7 +22,6 @@
import torch
from torch.utils.cpp_extension import (
BuildExtension,
CppExtension,
CUDAExtension,
CUDA_HOME,
)
@@ -83,6 +82,20 @@ def cuda_archs():
return os.getenv("FLASH_DMATTN_CUDA_ARCHS", "80;86;89;90;100;120").split(";")


def detect_preferred_sm_arch() -> Optional[str]:
"""Detect the preferred SM arch from the current CUDA device.
Returns None if CUDA is unavailable or detection fails.
"""
try:
if torch.cuda.is_available():
idx = torch.cuda.current_device()
major, minor = torch.cuda.get_device_capability(idx)
return f"{major}{minor}"
except Exception:
pass
return None


def get_platform():
"""
Returns the platform name as used in wheel filenames.
@@ -237,6 +250,7 @@ def get_package_version():


def get_wheel_url():
sm_arch = detect_preferred_sm_arch()
torch_version_raw = parse(torch.__version__)
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_name = get_platform()
@@ -255,7 +269,7 @@ def get_wheel_url():
cuda_version = f"{torch_cuda_version.major}"

# Determine wheel URL based on CUDA version, torch version, python version and OS
wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
wheel_filename = f"{PACKAGE_NAME}-{flash_version}+sm{sm_arch}cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"

Copilot AI Sep 20, 2025


The wheel filename will include 'sm{sm_arch}' even when sm_arch is None, resulting in 'smNone' in the filename. This should be handled conditionally to maintain backward compatibility when CUDA is unavailable or detection fails.

Suggested change
wheel_filename = f"{PACKAGE_NAME}-{flash_version}+sm{sm_arch}cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
sm_arch_str = f"sm{sm_arch}" if sm_arch is not None else ""
plus = "+" if sm_arch_str else ""
wheel_filename = f"{PACKAGE_NAME}-{flash_version}{plus}{sm_arch_str}cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"

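One detail in the suggested change above: when `sm_arch` is `None`, `plus` becomes an empty string, so the `+` separator that the previous filename always carried (`...+cu{cuda_version}...`) would disappear as well. A minimal sketch of an alternative that keeps the `+` unconditionally and only makes the `sm` prefix optional is shown below. The function name `build_wheel_filename`, its parameters, and the `example_pkg` values are hypothetical; in `setup.py` the values come from the surrounding `get_wheel_url()` code.

```python
from typing import Optional


def build_wheel_filename(package_name: str, version: str, cuda_version: str,
                         torch_version: str, cxx11_abi: str,
                         python_version: str, platform_name: str,
                         sm_arch: Optional[str] = None) -> str:
    """Compose the wheel filename; hypothetical stand-in for the f-string in get_wheel_url()."""
    # Only the 'sm<arch>' prefix is optional; the '+' separator is always kept,
    # so a failed detection falls back to the pre-PR filename format.
    sm_prefix = f"sm{sm_arch}" if sm_arch is not None else ""
    local = f"{sm_prefix}cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}"
    return (
        f"{package_name}-{version}+{local}"
        f"-{python_version}-{python_version}-{platform_name}.whl"
    )


if __name__ == "__main__":
    # With a detected arch, then without one (e.g. no CUDA device at install time).
    print(build_wheel_filename("example_pkg", "1.0.0", "12", "2.8", "TRUE",
                               "cp310", "linux_x86_64", sm_arch="90"))
    print(build_wheel_filename("example_pkg", "1.0.0", "12", "2.8", "TRUE",
                               "cp310", "linux_x86_64"))
```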

wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
