Skip to content

Commit 755beff

Browse files
authored
feat: add an env to allow specify cubins directory (#1462)
<!-- .github/pull_request_template.md --> ## 📌 Description Add a env `FLASHINFER_CUBIN_DIR` for cubins directory which gives more flexibility for aot deployment. If it's not specified, will fall back to `FLASHINFER_CACHE_DIR / "cubins"` which is what we currently have. ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent cd33db6 commit 755beff

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

flashinfer/jit/cubin_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import filelock
2424

2525
from .core import logger
26-
from .env import FLASHINFER_CACHE_DIR
26+
from .env import FLASHINFER_CUBIN_DIR
2727

2828
# This is the storage path for the cubins, it can be replaced
2929
# with a local path for testing.
@@ -145,7 +145,7 @@ def get_cubin(name, sha256, file_extension=".cubin"):
145145
None on failure.
146146
"""
147147
cubin_fname = name + file_extension
148-
cubin_path = FLASHINFER_CACHE_DIR / "cubins" / cubin_fname
148+
cubin_path = FLASHINFER_CUBIN_DIR / cubin_fname
149149
cubin = load_cubin(cubin_path, sha256)
150150
if cubin:
151151
return cubin

flashinfer/jit/env.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
)
3131

3232
FLASHINFER_CACHE_DIR = FLASHINFER_BASE_DIR / ".cache" / "flashinfer"
33+
FLASHINFER_CUBIN_DIR = pathlib.Path(
34+
os.getenv("FLASHINFER_CUBIN_DIR", (FLASHINFER_CACHE_DIR / "cubins").as_posix())
35+
)
3336

3437

3538
def _get_workspace_dir_name() -> pathlib.Path:

0 commit comments

Comments
 (0)