Skip to content

Commit 7ee54c7

Browse files
authored
ci/cd: bring up flashinfer-cubin package (#1718)
<!-- .github/pull_request_template.md --> ## 📌 Description User prefers a standalone wheel for cubin files in flashinfer, this PR implements this feature. ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent c1ffbd0 commit 7ee54c7

File tree

11 files changed

+458
-4
lines changed

11 files changed

+458
-4
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ repos:
4747
rev: 'v1.17.1' # Use the sha / tag you want to point at
4848
hooks:
4949
- id: mypy
50+
args: ["--config-file", "pyproject.toml", "--exclude", "flashinfer-cubin/"]
51+
files: ^flashinfer/
52+
exclude: ^(flashinfer-cubin/|3rdparty/|build/)
5053

5154
- repo: https://github.com/astral-sh/ruff-pre-commit
5255
# Ruff version.

flashinfer-cubin/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
build/
2+
flashinfer_cubin/cubins/

flashinfer-cubin/MANIFEST.in

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
include README.md
2+
include LICENSE
3+
recursive-include flashinfer_cubin/cubins *
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Copyright (c) 2025 by FlashInfer team.
4+
5+
Licensed under the Apache License, Version 2.0 (the "License");
6+
you may not use this file except in compliance with the License.
7+
You may obtain a copy of the License at
8+
9+
http://www.apache.org/licenses/LICENSE-2.0
10+
11+
Unless required by applicable law or agreed to in writing, software
12+
distributed under the License is distributed on an "AS IS" BASIS,
13+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
See the License for the specific language governing permissions and
15+
limitations under the License.
16+
"""
17+
18+
import os
19+
import subprocess
20+
import sys
21+
from pathlib import Path
22+
23+
24+
def build_wheel():
    """Build the flashinfer-cubin wheel.

    Changes the working directory to the directory containing this script,
    removes any stale ``dist``, ``build`` and ``flashinfer_cubin.egg-info``
    directories, runs ``setup.py bdist_wheel`` with the current interpreter
    and prints the wheels found in ``dist/``.

    Exits the process with status 1 if the build command fails.
    """
    import shutil

    # Resolve to an absolute path *before* chdir. On Python < 3.9 ``__file__``
    # is relative when the script is launched through a relative path; a
    # relative script_dir would then be re-interpreted against the new working
    # directory, so the clean-up and wheel listing below would silently look
    # at paths that do not exist.
    script_dir = Path(__file__).resolve().parent
    os.chdir(script_dir)

    print("Building flashinfer-cubin wheel...")
    print(f"Working directory: {script_dir}")

    # Clean previous builds
    dist_dir = script_dir / "dist"
    build_dir = script_dir / "build"
    egg_info_dir = script_dir / "flashinfer_cubin.egg-info"

    for dir_to_clean in (dist_dir, build_dir, egg_info_dir):
        if dir_to_clean.exists():
            print(f"Cleaning {dir_to_clean}")
            shutil.rmtree(dir_to_clean)

    # Build wheel; keep the try body limited to the call that can fail.
    try:
        subprocess.run([sys.executable, "setup.py", "bdist_wheel"], check=True)
    except subprocess.CalledProcessError as e:
        print(f"Failed to build wheel: {e}")
        sys.exit(1)

    print("Wheel built successfully!")

    # List built wheels
    if dist_dir.exists():
        wheels = sorted(dist_dir.glob("*.whl"))
        if wheels:
            for wheel in wheels:
                print(f"Built wheel: {wheel}")
        else:
            print("No wheel files found in dist/")


if __name__ == "__main__":
    build_wheel()
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Copyright (c) 2025 by FlashInfer team.
4+
5+
Licensed under the Apache License, Version 2.0 (the "License");
6+
you may not use this file except in compliance with the License.
7+
You may obtain a copy of the License at
8+
9+
http://www.apache.org/licenses/LICENSE-2.0
10+
11+
Unless required by applicable law or agreed to in writing, software
12+
distributed under the License is distributed on an "AS IS" BASIS,
13+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
See the License for the specific language governing permissions and
15+
limitations under the License.
16+
"""
17+
18+
import os
19+
import sys
20+
import argparse
21+
from pathlib import Path
22+
23+
# Add parent directory to path to import flashinfer modules
24+
sys.path.insert(0, str(Path(__file__).parent.parent))
25+
26+
from flashinfer.artifacts import download_artifacts
27+
from flashinfer.jit.cubin_loader import FLASHINFER_CUBINS_REPOSITORY
28+
29+
30+
def main():
    """Download FlashInfer cubins into a local directory.

    Command-line options:
        --output-dir / -o: destination directory
            (default ``flashinfer_cubin/cubins``).
        --threads / -t: number of download threads (default 4, must be >= 1).
        --repository / -r: optional override for the cubins repository URL.

    The chosen settings are exported through the ``FLASHINFER_CUBINS_REPOSITORY``,
    ``FLASHINFER_CUBIN_DIR`` and ``FLASHINFER_CUBIN_DOWNLOAD_THREADS``
    environment variables before ``download_artifacts()`` is called.

    Exits with status 2 on invalid arguments (argparse convention) and with
    status 1 if the download raises.
    """
    parser = argparse.ArgumentParser(
        description="Download FlashInfer cubins from artifactory"
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default="flashinfer_cubin/cubins",
        help="Output directory for cubins (default: flashinfer_cubin/cubins)",
    )
    parser.add_argument(
        "--threads",
        "-t",
        type=int,
        default=4,
        help="Number of download threads (default: 4)",
    )
    parser.add_argument(
        "--repository",
        "-r",
        type=str,
        default=None,
        help="Override the cubins repository URL",
    )

    args = parser.parse_args()

    # Reject a non-positive worker count up front instead of handing it to
    # the downloader, where it would fail later with a less helpful error.
    if args.threads < 1:
        parser.error("--threads must be a positive integer")

    # Set environment variables to control download behavior.
    # NOTE(review): the flashinfer modules are imported at module load time,
    # i.e. before these variables are set. If they read the environment at
    # import time the overrides below will not take effect -- confirm.
    if args.repository:
        os.environ["FLASHINFER_CUBINS_REPOSITORY"] = args.repository

    os.environ["FLASHINFER_CUBIN_DIR"] = str(Path(args.output_dir).absolute())
    os.environ["FLASHINFER_CUBIN_DOWNLOAD_THREADS"] = str(args.threads)

    print(f"Downloading cubins to {args.output_dir}")
    print(
        f"Repository: {os.environ.get('FLASHINFER_CUBINS_REPOSITORY', FLASHINFER_CUBINS_REPOSITORY)}"
    )

    # Use the existing download_artifacts function. This is the script's
    # top-level boundary, so any failure is reported and turned into a
    # non-zero exit status.
    try:
        download_artifacts()
    except Exception as e:
        print(f"Download failed: {e}", file=sys.stderr)
        sys.exit(1)

    print("Download complete!")


if __name__ == "__main__":
    main()
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""
2+
Copyright (c) 2025 by FlashInfer team.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import os
18+
from pathlib import Path
19+
20+
# Get the path to the cubins directory within this package
21+
CUBIN_DIR = Path(__file__).parent / "cubins"
22+
23+
24+
def get_cubin_dir():
    """Return the packaged cubins directory (``CUBIN_DIR``) as a string."""
    return os.fspath(CUBIN_DIR)
27+
28+
29+
def list_cubins():
    """Return the sorted relative paths of every ``.cubin`` file.

    The tree rooted at ``CUBIN_DIR`` is walked recursively and each match is
    reported relative to ``CUBIN_DIR``. An empty list is returned when the
    directory does not exist.
    """
    if not CUBIN_DIR.exists():
        return []

    found = [
        os.path.relpath(os.path.join(dirpath, filename), CUBIN_DIR)
        for dirpath, _subdirs, filenames in os.walk(CUBIN_DIR)
        for filename in filenames
        if filename.endswith(".cubin")
    ]
    found.sort()
    return found
41+
42+
43+
def get_cubin_path(relative_path):
    """Return ``relative_path`` joined onto ``CUBIN_DIR``, as a string."""
    full_path = CUBIN_DIR.joinpath(relative_path)
    return os.fspath(full_path)
46+
47+
48+
# Read version from build metadata or fallback to main flashinfer version.txt
49+
def _get_version():
    """Return the package version string.

    Lookup order:
      1. ``__version__`` from the sibling ``_build_meta`` module, when that
         module can be imported (wheel distributions).
      2. The contents of ``version.txt`` located three directory levels above
         this file (development checkout of the main flashinfer repository).
      3. The literal ``"0.0.0"`` when neither source is available.
    """
    try:
        from . import _build_meta
    except ImportError:
        # No build metadata available; fall through to version.txt.
        pass
    else:
        return _build_meta.__version__

    package_dir = Path(__file__).parent
    version_file = package_dir.parent.parent / "version.txt"
    if not version_file.exists():
        return "0.0.0"
    return version_file.read_text().strip()
64+
65+
66+
__version__ = _get_version()
67+
__all__ = ["get_cubin_dir", "list_cubins", "get_cubin_path", "CUBIN_DIR"]
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
[build-system]
2+
requires = ["setuptools>=61.0", "wheel", "requests", "filelock", "torch", "tqdm"] # NOTE(Zihao): we should remove torch once https://github.com/flashinfer-ai/flashinfer/pull/1641 merged
3+
build-backend = "setuptools.build_meta"
4+
5+
[project]
6+
name = "flashinfer-cubin"
7+
dynamic = ["version"]
8+
description = "Pre-compiled cubins for FlashInfer"
9+
readme = {text = "This package contains pre-compiled CUDA kernels (cubins) for FlashInfer. It provides all necessary cubin files downloaded from the FlashInfer artifactory.", content-type = "text/plain"}
10+
requires-python = ">=3.8"
11+
license = {text = "Apache-2.0"}
12+
authors = [
13+
{name = "FlashInfer team"},
14+
]
15+
maintainers = [
16+
{name = "FlashInfer team"},
17+
]
18+
classifiers = [
19+
"Development Status :: 4 - Beta",
20+
"Intended Audience :: Developers",
21+
"License :: OSI Approved :: Apache Software License",
22+
"Operating System :: OS Independent",
23+
"Programming Language :: Python :: 3",
24+
"Programming Language :: Python :: 3.8",
25+
"Programming Language :: Python :: 3.9",
26+
"Programming Language :: Python :: 3.10",
27+
"Programming Language :: Python :: 3.11",
28+
"Programming Language :: Python :: 3.12",
29+
"Topic :: Software Development :: Libraries :: Python Modules",
30+
]
31+
dependencies = [
32+
"requests",
33+
"filelock",
34+
]
35+
36+
[project.urls]
37+
Homepage = "https://github.com/flashinfer-ai/flashinfer"
38+
Documentation = "https://github.com/flashinfer-ai/flashinfer"
39+
Repository = "https://github.com/flashinfer-ai/flashinfer"
40+
"Issue Tracker" = "https://github.com/flashinfer-ai/flashinfer/issues"
41+
42+
[tool.setuptools]
43+
packages = ["flashinfer_cubin"]
44+
include-package-data = true
45+
46+
[tool.setuptools.dynamic]
47+
version = {attr = "flashinfer_cubin.__version__"}
48+
49+
[tool.setuptools.package-data]
50+
flashinfer_cubin = ["cubins/**/*"]
51+
52+
[tool.setuptools.cmdclass]
53+
build_py = "setup.DownloadAndBuildPy"
54+
sdist = "setup.CustomSdist"

0 commit comments

Comments
(0)