77import re
88import ast
99import glob
10- import shutil
1110from pathlib import Path
1211from packaging .version import parse , Version
1312import platform
13+ from typing import Optional
1414
15- from setuptools import setup , find_packages
15+ from setuptools import setup
1616import subprocess
1717
1818import urllib .request
2222import torch
2323from torch .utils .cpp_extension import (
2424 BuildExtension ,
25- CppExtension ,
2625 CUDAExtension ,
2726 CUDA_HOME ,
2827)
@@ -83,6 +82,20 @@ def cuda_archs():
8382 return os .getenv ("FLASH_DMATTN_CUDA_ARCHS" , "80;86;89;90;100;120" ).split (";" )
8483
8584
85+ def detect_preferred_sm_arch () -> Optional [str ]:
86+ """Detect the preferred SM arch from the current CUDA device.
87+ Returns None if CUDA is unavailable or detection fails.
88+ """
89+ try :
90+ if torch .cuda .is_available ():
91+ idx = torch .cuda .current_device ()
92+ major , minor = torch .cuda .get_device_capability (idx )
93+ return f"{ major } { minor } "
94+ except Exception :
95+ pass
96+ return None
97+
98+
8699def get_platform ():
87100 """
88101 Returns the platform name as used in wheel filenames.
@@ -237,6 +250,7 @@ def get_package_version():
237250
238251
239252def get_wheel_url ():
253+ sm_arch = detect_preferred_sm_arch ()
240254 torch_version_raw = parse (torch .__version__ )
241255 python_version = f"cp{ sys .version_info .major } { sys .version_info .minor } "
242256 platform_name = get_platform ()
@@ -255,7 +269,7 @@ def get_wheel_url():
255269 cuda_version = f"{ torch_cuda_version .major } "
256270
257271 # Determine wheel URL based on CUDA version, torch version, python version and OS
258- wheel_filename = f"{ PACKAGE_NAME } -{ flash_version } +cu{ cuda_version } torch{ torch_version } cxx11abi{ cxx11_abi } -{ python_version } -{ python_version } -{ platform_name } .whl"
272+ wheel_filename = f"{ PACKAGE_NAME } -{ flash_version } +sm { sm_arch } cu{ cuda_version } torch{ torch_version } cxx11abi{ cxx11_abi } -{ python_version } -{ python_version } -{ platform_name } .whl"
259273
260274 wheel_url = BASE_WHEEL_URL .format (tag_name = f"v{ flash_version } " , wheel_name = wheel_filename )
261275
0 commit comments