Commit 2931569

[Perf] Cache device property functions to avoid recomputation (#1824)
## 📌 Description

Reduce CUDA launch overhead by caching device property functions. With small batches, we observed GPU bubbles, meaning that in some cases CPU-side work (e.g. CUDA launch preparation) delays GPU kernel launches. In this PR, we simply cache the device property functions to reduce that CPU overhead (a minimal sketch of the pattern is shown after this message).

**Before**

![Screenshot 2025-09-30 at 3 41 21 PM](https://github.com/user-attachments/assets/762d9334-da03-4359-91a1-8af9368a8bb5)

**After**

![Screenshot 2025-09-30 at 3 54 54 PM](https://github.com/user-attachments/assets/9c5389d4-eae8-4722-b117-ba6e822f1c43)

## 🔍 Related Issues

N/A

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

N/A

---------

Signed-off-by: Jialin Ouyang <[email protected]>
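For context, a minimal sketch of the pattern this commit applies (mirroring the `get_device_sm_count` change in the diff below): `functools.cache` memoizes the return value keyed on the hashable `torch.device` argument, so the device-properties query runs once per device instead of on every launch-preparation path.

```python
import functools

import torch


@functools.cache
def get_device_sm_count(device: torch.device) -> int:
    # First call per device queries the device properties through torch/CUDA;
    # later calls return the memoized count, keeping this lookup off the hot
    # CPU path that prepares each kernel launch.
    return torch.cuda.get_device_properties(device).multi_processor_count
```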
1 parent d50cfbc commit 2931569

File tree

1 file changed: +2 -0 lines changed


flashinfer/utils.py

Lines changed: 2 additions & 0 deletions
@@ -547,6 +547,7 @@ def set_log_level(lvl_str: str) -> None:
     get_logging_module().set_log_level(log_level_map[lvl_str].value)


+@functools.cache
 def device_support_pdl(device: torch.device) -> bool:
     if device.type != "cuda":
         return False
@@ -573,6 +574,7 @@ def round_up(x: int, y: int) -> int:
     return ceil_div(x, y) * y


+@functools.cache
 def get_device_sm_count(device: torch.device) -> int:
     return torch.cuda.get_device_properties(device).multi_processor_count
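Not part of the commit, but as a rough usage sketch (assuming `get_device_sm_count` is importable from `flashinfer.utils` as in this file and a CUDA device is present), the `functools.cache` wrapper exposes the standard cache statistics, which makes the memoization easy to verify:

```python
import torch

from flashinfer.utils import get_device_sm_count

dev = torch.device("cuda:0")
get_device_sm_count(dev)                 # first call: queries device properties
get_device_sm_count(dev)                 # second call: served from the cache
print(get_device_sm_count.cache_info())  # e.g. CacheInfo(hits=1, misses=1, maxsize=None, currsize=1)
```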
