Add support for targeted GPU architecture builds #171
Merged
4 commits
0ebcdb9 Adds arch input parameter to build workflow (algo-home)
9609d4f Adds GPU architecture matrix to publish workflow (algo-home)
206598a Adds configurable arch input for targeted builds (algo-home)
8a8a456 Adds SM architecture detection for wheel naming (algo-home)
@@ -7,12 +7,12 @@
 import re
 import ast
 import glob
-import shutil
 from pathlib import Path
 from packaging.version import parse, Version
 import platform
+from typing import Optional
 
-from setuptools import setup, find_packages
+from setuptools import setup
 import subprocess
 
 import urllib.request
@@ -22,7 +22,6 @@
 import torch
 from torch.utils.cpp_extension import (
     BuildExtension,
-    CppExtension,
     CUDAExtension,
     CUDA_HOME,
 )
@@ -83,6 +82,20 @@ def cuda_archs():
     return os.getenv("FLASH_DMATTN_CUDA_ARCHS", "80;86;89;90;100;120").split(";")
 
 
+def detect_preferred_sm_arch() -> Optional[str]:
+    """Detect the preferred SM arch from the current CUDA device.
+    Returns None if CUDA is unavailable or detection fails.
+    """
+    try:
+        if torch.cuda.is_available():
+            idx = torch.cuda.current_device()
+            major, minor = torch.cuda.get_device_capability(idx)
+            return f"{major}{minor}"
+    except Exception:
+        pass
+    return None
+
+
 def get_platform():
     """
     Returns the platform name as used in wheel filenames.
@@ -237,6 +250,7 @@ def get_package_version():
 
 
 def get_wheel_url():
+    sm_arch = detect_preferred_sm_arch()
     torch_version_raw = parse(torch.__version__)
     python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
     platform_name = get_platform()
@@ -255,7 +269,7 @@ def get_wheel_url():
     cuda_version = f"{torch_cuda_version.major}"
 
     # Determine wheel URL based on CUDA version, torch version, python version and OS
-    wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
+    wheel_filename = f"{PACKAGE_NAME}-{flash_version}+sm{sm_arch}cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"

Suggested change:
-    wheel_filename = f"{PACKAGE_NAME}-{flash_version}+sm{sm_arch}cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
+    sm_arch_str = f"sm{sm_arch}" if sm_arch is not None else ""
+    plus = "+" if sm_arch_str else ""
+    wheel_filename = f"{PACKAGE_NAME}-{flash_version}{plus}{sm_arch_str}cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
[nitpick] The wheel naming logic is duplicated between setup.py and the workflow file. Consider extracting this logic to a shared function or script to avoid inconsistencies and reduce maintenance burden.
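
One way to act on this nitpick, as a minimal sketch rather than the project's actual code: move the filename construction into a single helper that both setup.py and the workflow could call. The names `format_sm_arch` and `build_wheel_filename` are hypothetical, and the argument values in the usage lines are placeholders. The sketch also keeps the `+` separator when no SM arch is detected, so the name falls back to the pre-PR `+cu...` form instead of `smNone`. That differs from the suggested change above, which drops the `+` in that case. Which form matches the published wheel names is an assumption to verify.

```python
from typing import Optional, Tuple


def format_sm_arch(capability: Optional[Tuple[int, int]]) -> Optional[str]:
    """Turn a (major, minor) compute capability into an SM string, e.g. (8, 6) -> "86".

    Hypothetical helper: mirrors the f"{major}{minor}" formatting used in
    detect_preferred_sm_arch(), but takes the capability as an argument so it
    can be exercised without a GPU.
    """
    if capability is None:
        return None
    major, minor = capability
    return f"{major}{minor}"


def build_wheel_filename(
    package_name: str,
    flash_version: str,
    sm_arch: Optional[str],
    cuda_version: str,
    torch_version: str,
    cxx11_abi: str,
    python_version: str,
    platform_name: str,
) -> str:
    """Hypothetical shared helper for the wheel filename.

    Assumption: the "+" that starts the local version segment is always kept,
    and only the "sm<arch>" part is omitted when sm_arch is None.
    """
    sm_part = f"sm{sm_arch}" if sm_arch is not None else ""
    local_version = f"{sm_part}cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}"
    return (
        f"{package_name}-{flash_version}+{local_version}"
        f"-{python_version}-{python_version}-{platform_name}.whl"
    )


# Usage with placeholder values (not taken from the PR):
with_arch = build_wheel_filename(
    "flash_dmattn", "1.0.0", format_sm_arch((8, 6)), "12", "2.5", "TRUE", "cp310", "linux_x86_64"
)
without_arch = build_wheel_filename(
    "flash_dmattn", "1.0.0", format_sm_arch(None), "12", "2.5", "TRUE", "cp310", "linux_x86_64"
)
print(with_arch)
print(without_arch)
```

Keeping the helper free of `torch` imports means the workflow could call it from a plain Python step, with only the detection function depending on CUDA.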