Skip to content

Commit 8a8a456

Browse files
committed
Adds SM architecture detection for wheel naming
Introduces automatic detection of the preferred SM (Streaming Multiprocessor) architecture from the current CUDA device to improve wheel filename specificity. The detection function safely handles cases where CUDA is unavailable or detection fails by returning None. This enhancement allows for more precise wheel identification based on the actual hardware capabilities rather than relying solely on CUDA version information. Removes unused imports to clean up the codebase.
1 parent 206598a commit 8a8a456

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

setup.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
import re
88
import ast
99
import glob
10-
import shutil
1110
from pathlib import Path
1211
from packaging.version import parse, Version
1312
import platform
13+
from typing import Optional
1414

15-
from setuptools import setup, find_packages
15+
from setuptools import setup
1616
import subprocess
1717

1818
import urllib.request
@@ -22,7 +22,6 @@
2222
import torch
2323
from torch.utils.cpp_extension import (
2424
BuildExtension,
25-
CppExtension,
2625
CUDAExtension,
2726
CUDA_HOME,
2827
)
@@ -83,6 +82,20 @@ def cuda_archs():
8382
return os.getenv("FLASH_DMATTN_CUDA_ARCHS", "80;86;89;90;100;120").split(";")
8483

8584

85+
def detect_preferred_sm_arch() -> Optional[str]:
    """Return the SM compute-capability digits of the active CUDA device.

    Reads the current device's (major, minor) compute capability through
    torch and concatenates the two numbers (e.g. ``(9, 0)`` -> ``"90"``).

    Returns:
        The capability string, or ``None`` when CUDA is unavailable or any
        torch query raises — callers are expected to fall back gracefully.
    """
    try:
        # Best-effort probe: any failure here simply means "arch unknown".
        if not torch.cuda.is_available():
            return None
        capability = torch.cuda.get_device_capability(torch.cuda.current_device())
    except Exception:
        return None
    return "{}{}".format(*capability)
97+
98+
8699
def get_platform():
87100
"""
88101
Returns the platform name as used in wheel filenames.
@@ -237,6 +250,7 @@ def get_package_version():
237250

238251

239252
def get_wheel_url():
253+
sm_arch = detect_preferred_sm_arch()
240254
torch_version_raw = parse(torch.__version__)
241255
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
242256
platform_name = get_platform()
@@ -255,7 +269,7 @@ def get_wheel_url():
255269
cuda_version = f"{torch_cuda_version.major}"
256270

257271
# Determine wheel URL based on CUDA version, torch version, python version and OS
258-
wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
272+
wheel_filename = f"{PACKAGE_NAME}-{flash_version}+sm{sm_arch}cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
259273

260274
wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
261275

0 commit comments

Comments
 (0)