
Commit 075775e

bugfix: remove the append "a" logic if user specifies cuda arch explicitly (#1798)
## 📌 Description

In our design of the compilation context we append "a" to CUDA archs whose major version is >= 9; however, we shouldn't change them when the user specifies the archs explicitly. This PR changes the logic so that "a" is only appended when `FLASHINFER_CUDA_ARCH_LIST` is not set; otherwise the user-specified CUDA archs are used as-is.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
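As a rough sketch of the intended behavior (not the actual FlashInfer implementation; the helper name `resolve_target_archs` is made up for illustration):

```python
import os

# Illustrative sketch only; names and structure are simplified.
def resolve_target_archs():
    explicit = os.environ.get("FLASHINFER_CUDA_ARCH_LIST")
    archs = set()
    if explicit is not None:
        # The user-specified list is honored verbatim: "9.0" stays "9.0",
        # "9.0a" stays "9.0a"; no "a" suffix is force-appended anymore.
        for arch in explicit.split(" "):
            major, minor = arch.split(".")
            archs.add((int(major), str(minor)))
    else:
        # Auto-detection path (omitted here): "a" may still be appended for
        # detected devices with compute capability major >= 9.
        pass
    return archs
```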
1 parent 2931569 commit 075775e


2 files changed: +3 −6


flashinfer/compilation_context.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -36,9 +36,6 @@ def __init__(self):
             for arch in os.environ["FLASHINFER_CUDA_ARCH_LIST"].split(" "):
                 major, minor = arch.split(".")
                 major = int(major)
-                if major >= 9:
-                    if minor.isdigit():
-                        minor = str(minor) + "a"
                 self.TARGET_CUDA_ARCHS.add((int(major), str(minor)))
         else:
             try:
```
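A quick illustration (hypothetical input, not taken from the PR) of what the parsing loop above now produces for an explicitly set arch list:

```python
import os

# Hypothetical example: the user-supplied list is kept verbatim.
os.environ["FLASHINFER_CUDA_ARCH_LIST"] = "9.0 12.0a"

archs = set()
for arch in os.environ["FLASHINFER_CUDA_ARCH_LIST"].split(" "):
    major, minor = arch.split(".")
    archs.add((int(major), str(minor)))

print(sorted(archs))  # [(9, '0'), (12, '0a')] -- "9.0" no longer becomes "9.0a"
```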

scripts/task_test_aot_build_import.sh

Lines changed: 3 additions & 3 deletions
```diff
@@ -20,13 +20,13 @@ export MAX_JOBS
 export FLASHINFER_CUDA_ARCH_LIST=$(python3 -c '
 import torch
 cuda_ver = torch.version.cuda
-arches = ["7.5", "8.0", "8.9", "9.0"]
+arches = ["7.5", "8.0", "8.9", "9.0a"]
 if cuda_ver is not None:
     try:
         major, minor = map(int, cuda_ver.split(".")[:2])
         if (major, minor) >= (12, 8):
-            arches.append("10.0")
-            arches.append("12.0")
+            arches.append("10.0a")
+            arches.append("12.0a")
     except Exception:
         pass
print(" ".join(arches))
```
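Since the compilation context no longer appends "a" to explicitly set arch lists, the script now spells out the suffix itself. A standalone sketch of what the updated snippet exports, with the CUDA version hard-coded instead of read from torch (hypothetical values):

```python
# Standalone sketch; in the script above, cuda_ver comes from torch.version.cuda.
cuda_ver = "12.8"  # hypothetical value for illustration

arches = ["7.5", "8.0", "8.9", "9.0a"]
if cuda_ver is not None:
    try:
        major, minor = map(int, cuda_ver.split(".")[:2])
        if (major, minor) >= (12, 8):
            arches.append("10.0a")
            arches.append("12.0a")
    except Exception:
        pass

# CUDA 12.8+: "7.5 8.0 8.9 9.0a 10.0a 12.0a"; older CUDA: "7.5 8.0 8.9 9.0a"
print(" ".join(arches))
```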
