3737
3838PACKAGE_NAME = "flash_dmattn"
3939
40+ BASE_WHEEL_URL = (
41+ "https://github.com/SmallDoges/flash-dmattn/releases/download/{tag_name}/{wheel_name}"
42+ )
43+
4044# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
4145# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
4246# Also useful when user only wants Triton/Flex backends without CUDA compilation
@@ -307,6 +311,67 @@ def get_package_version():
307311 return str (public_version )
308312
309313
314+ def get_wheel_url ():
315+ torch_version_raw = parse (torch .__version__ )
316+ python_version = f"cp{ sys .version_info .major } { sys .version_info .minor } "
317+ platform_name = get_platform ()
318+ flash_version = get_package_version ()
319+ torch_version = f"{ torch_version_raw .major } .{ torch_version_raw .minor } "
320+ cxx11_abi = str (torch ._C ._GLIBCXX_USE_CXX11_ABI ).upper ()
321+
322+ # Determine the version numbers that will be used to determine the correct wheel
323+ # We're using the CUDA version used to build torch, not the one currently installed
324+ # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
325+ torch_cuda_version = parse (torch .version .cuda )
326+ # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
327+ # to save CI time. Minor versions should be compatible.
328+ torch_cuda_version = parse ("11.8" ) if torch_cuda_version .major == 11 else parse ("12.3" )
329+ # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
330+ cuda_version = f"{ torch_cuda_version .major } "
331+
332+ # Determine wheel URL based on CUDA version, torch version, python version and OS
333+ wheel_filename = f"{ PACKAGE_NAME } -{ flash_version } +cu{ cuda_version } torch{ torch_version } cxx11abi{ cxx11_abi } -{ python_version } -{ python_version } -{ platform_name } .whl"
334+
335+ wheel_url = BASE_WHEEL_URL .format (tag_name = f"v{ flash_version } " , wheel_name = wheel_filename )
336+
337+ return wheel_url , wheel_filename
338+
339+
340+ class CachedWheelsCommand (_bdist_wheel ):
341+ """
342+ The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
343+ find an existing wheel (which is currently the case for all flash attention installs). We use
344+ the environment parameters to detect whether there is already a pre-built version of a compatible
345+ wheel available and short-circuits the standard full build pipeline.
346+ """
347+
348+ def run (self ):
349+ if FORCE_BUILD :
350+ return super ().run ()
351+
352+ wheel_url , wheel_filename = get_wheel_url ()
353+ print ("Guessing wheel URL: " , wheel_url )
354+ try :
355+ urllib .request .urlretrieve (wheel_url , wheel_filename )
356+
357+ # Make the archive
358+ # Lifted from the root wheel processing command
359+ # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
360+ if not os .path .exists (self .dist_dir ):
361+ os .makedirs (self .dist_dir )
362+
363+ impl_tag , abi_tag , plat_tag = self .get_tag ()
364+ archive_basename = f"{ self .wheel_dist_name } -{ impl_tag } -{ abi_tag } -{ plat_tag } "
365+
366+ wheel_path = os .path .join (self .dist_dir , archive_basename + ".whl" )
367+ print ("Raw wheel path" , wheel_path )
368+ os .rename (wheel_filename , wheel_path )
369+ except (urllib .error .HTTPError , urllib .error .URLError ):
370+ print ("Precompiled wheel not found. Building from source..." )
371+ # If the wheel could not be downloaded, build from source
372+ super ().run ()
373+
374+
310375class NinjaBuildExtension (BuildExtension ):
311376 def __init__ (self , * args , ** kwargs ) -> None :
312377 # do not override env MAX_JOBS if already exists
@@ -329,7 +394,9 @@ def __init__(self, *args, **kwargs) -> None:
329394
330395setup (
331396 ext_modules = ext_modules ,
332- cmdclass = {"build_ext" : NinjaBuildExtension }
397+ cmdclass = {"bdist_wheel" : CachedWheelsCommand , " build_ext" : NinjaBuildExtension }
333398 if ext_modules
334- else {},
399+ else {
400+ "bdist_wheel" : CachedWheelsCommand ,
401+ },
335402)
0 commit comments