Skip to content

Commit b4b9e22

Browse files
yzh119cyx-6
andauthored
feat: update flashinfer-cli (#1613)
<!-- .github/pull_request_template.md --> ## 📌 Description - add functionality of clear cache/cubin, and show config - using click library - show progress bar when downloading the cubins Usage: ``` python3 -m flashinfer show-config python3 -m flashinfer list-cubin python3 -m flashinfer clear-cache python3 -m flashinfer clear-cubin python3 -m flashinfer download-cubin ``` More features will be coming in the future. ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> --------- Co-authored-by: Yaxing Cai <[email protected]>
1 parent 9d328b2 commit b4b9e22

File tree

4 files changed

+253
-33
lines changed

4 files changed

+253
-33
lines changed

flashinfer/__main__.py

Lines changed: 125 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,131 @@
1515
"""
1616

1717
# flashinfer-cli
18-
import argparse
18+
import click
19+
from tabulate import tabulate # type: ignore[import-untyped]
20+
21+
from .artifacts import (
22+
ArtifactPath,
23+
download_artifacts,
24+
clear_cubin,
25+
get_artifacts_status,
26+
)
27+
from .jit import clear_cache_dir
28+
from .jit.cubin_loader import FLASHINFER_CUBINS_REPOSITORY
29+
from .jit.env import FLASHINFER_CACHE_DIR, FLASHINFER_CUBIN_DIR
30+
from .jit.core import current_compilation_context
31+
from .jit.cpp_ext import get_cuda_path, get_cuda_version
32+
33+
34+
def _download_cubin():
35+
"""Helper function to download cubin"""
36+
try:
37+
download_artifacts()
38+
click.secho("✅ All cubin download tasks completed successfully.", fg="green")
39+
except Exception as e:
40+
click.secho(f"❌ Cubin download failed: {e}", fg="red")
41+
42+
43+
@click.group(invoke_without_command=True)
44+
@click.option(
45+
"--download-cubin", "download_cubin_flag", is_flag=True, help="Download artifacts"
46+
)
47+
@click.pass_context
48+
def cli(ctx, download_cubin_flag):
49+
"""FlashInfer CLI"""
50+
if download_cubin_flag:
51+
_download_cubin()
52+
elif ctx.invoked_subcommand is None:
53+
click.echo(ctx.get_help())
54+
55+
56+
# list of environment variables
57+
env_variables = {
58+
"FLASHINFER_CACHE_DIR": FLASHINFER_CACHE_DIR,
59+
"FLASHINFER_CUBIN_DIR": FLASHINFER_CUBIN_DIR,
60+
"CUDA_HOME": get_cuda_path(),
61+
"CUDA_VERSION": get_cuda_version(),
62+
"FLASHINFER_CUDA_ARCH_LIST": current_compilation_context.TARGET_CUDA_ARCHS,
63+
"FLASHINFER_CUBINS_REPOSITORY": FLASHINFER_CUBINS_REPOSITORY,
64+
}
65+
66+
67+
@cli.command("show-config")
68+
def show_config_cmd():
69+
"""Show configuration"""
70+
import torch
71+
72+
# Section: Torch Version Info
73+
click.secho("=== Torch Version Info ===", fg="yellow")
74+
click.secho("Torch version:", fg="magenta", nl=False)
75+
click.secho(f" {torch.__version__}", fg="cyan")
76+
click.secho("", fg="white")
77+
78+
# Section: Environment Variables
79+
click.secho("=== Environment Variables ===", fg="yellow")
80+
for name, value in env_variables.items():
81+
click.secho(f"{name}:", fg="magenta", nl=False)
82+
click.secho(f" {value}", fg="cyan")
83+
click.secho("", fg="white")
84+
85+
# Section: Artifact path
86+
click.secho("=== Artifact Path ===", fg="yellow")
87+
# list all artifact paths
88+
for name, path in ArtifactPath.__dict__.items():
89+
if not name.startswith("__"):
90+
click.secho(f"{name}:", fg="magenta", nl=False)
91+
click.secho(f" {path}", fg="cyan")
92+
click.secho("", fg="white")
93+
94+
# Section: Downloaded Cubins
95+
click.secho("=== Downloaded Cubins ===", fg="yellow")
96+
97+
status = get_artifacts_status()
98+
num_downloaded = sum(1 for _, _, exists in status if exists)
99+
total_cubins = len(status)
100+
101+
click.secho(f"Downloaded {num_downloaded}/{total_cubins} cubins", fg="cyan")
102+
103+
104+
@cli.command("list-cubins")
105+
def list_cubins_cmd():
106+
"""List downloaded cubins"""
107+
status = get_artifacts_status()
108+
table_data = []
109+
for name, extension, exists in status:
110+
status_str = "Downloaded" if exists else "Missing"
111+
color = "green" if exists else "red"
112+
table_data.append([f"{name}{extension}", click.style(status_str, fg=color)])
113+
114+
click.echo(tabulate(table_data, headers=["Cubin", "Status"], tablefmt="github"))
115+
click.secho("", fg="white")
116+
117+
118+
@cli.command("download-cubin")
119+
def download_cubin_cmd():
120+
"""Download artifacts"""
121+
_download_cubin()
122+
123+
124+
@cli.command("clear-cache")
125+
def clear_cache_cmd():
126+
"""Clear cache"""
127+
try:
128+
clear_cache_dir()
129+
click.secho("✅ Cache cleared successfully.", fg="green")
130+
except Exception as e:
131+
click.secho(f"❌ Cache clear failed: {e}", fg="red")
132+
133+
134+
@cli.command("clear-cubin")
135+
def clear_cubin_cmd():
136+
"""Clear cubin"""
137+
try:
138+
clear_cubin()
139+
click.secho("✅ Cubin cleared successfully.", fg="green")
140+
except Exception as e:
141+
click.secho(f"❌ Cubin clear failed: {e}", fg="red")
19142

20-
from .artifacts import download_artifacts
21143

22144
if __name__ == "__main__":
23-
parser = argparse.ArgumentParser("FlashInfer CLI")
24-
parser.add_argument(
25-
"--download-cubin", action="store_true", help="Download artifacts"
26-
)
27-
28-
args = parser.parse_args()
29-
30-
if args.download_cubin:
31-
if download_artifacts():
32-
print("✅ All cubin download tasks completed successfully.")
33-
else:
34-
print("❌ Some cubin download tasks failed.")
145+
cli()

flashinfer/artifacts.py

Lines changed: 113 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,70 @@
2020
from concurrent.futures import ThreadPoolExecutor, as_completed
2121

2222
import requests # type: ignore[import-untyped]
23+
import shutil
2324

2425
from .jit.core import logger
25-
from .jit.cubin_loader import FLASHINFER_CUBINS_REPOSITORY, get_cubin
26+
from .jit.cubin_loader import (
27+
FLASHINFER_CUBINS_REPOSITORY,
28+
get_cubin,
29+
FLASHINFER_CUBIN_DIR,
30+
)
31+
32+
33+
import logging
34+
from contextlib import contextmanager
35+
36+
37+
@contextmanager
38+
def temp_env_var(key, value):
39+
old_value = os.environ.get(key, None)
40+
os.environ[key] = value
41+
try:
42+
yield
43+
finally:
44+
if old_value is None:
45+
os.environ.pop(key, None)
46+
else:
47+
os.environ[key] = old_value
48+
49+
50+
@contextmanager
51+
def patch_logger_for_tqdm(logger):
52+
"""
53+
Context manager to patch the logger so that log messages are displayed using tqdm.write,
54+
preventing interference with tqdm progress bars.
55+
"""
56+
import tqdm
57+
58+
class TqdmLoggingHandler(logging.Handler):
59+
def emit(self, record):
60+
try:
61+
msg = self.format(record)
62+
tqdm.write(msg, end="\n")
63+
except Exception:
64+
self.handleError(record)
65+
66+
# Save original handlers and level
67+
original_handlers = logger.handlers[:]
68+
original_level = logger.level
69+
70+
# Remove all existing handlers to prevent duplicate output
71+
for h in original_handlers:
72+
logger.removeHandler(h)
73+
74+
# Add our tqdm-aware handler
75+
handler = TqdmLoggingHandler()
76+
handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
77+
logger.addHandler(handler)
78+
logger.setLevel(logging.INFO)
79+
try:
80+
yield
81+
finally:
82+
# Remove tqdm handler and restore original handlers and level
83+
logger.removeHandler(handler)
84+
for h in original_handlers:
85+
logger.addHandler(h)
86+
logger.setLevel(original_level)
2687

2788

2889
def get_available_cubin_files(source, retries=3, delay=5, timeout=10):
@@ -72,11 +133,9 @@ class MetaInfoHash:
72133
)
73134

74135

75-
def download_artifacts() -> bool:
76-
env_backup = os.environ.get("FLASHINFER_CUBIN_CHECKSUM_DISABLED", None)
77-
os.environ["FLASHINFER_CUBIN_CHECKSUM_DISABLED"] = "1"
136+
def get_cubin_file_list():
78137
cubin_files = [
79-
(ArtifactPath.TRTLLM_GEN_FMHA + "flashInferMetaInfo", ".h"),
138+
(ArtifactPath.TRTLLM_GEN_FMHA + "include/flashInferMetaInfo", ".h"),
80139
(ArtifactPath.TRTLLM_GEN_GEMM + "include/flashinferMetaInfo", ".h"),
81140
(ArtifactPath.TRTLLM_GEN_BMM + "include/flashinferMetaInfo", ".h"),
82141
]
@@ -92,19 +151,54 @@ def download_artifacts() -> bool:
92151
FLASHINFER_CUBINS_REPOSITORY + "/" + kernel
93152
)
94153
]
95-
pool = ThreadPoolExecutor(4)
96-
futures = []
154+
return cubin_files
155+
156+
157+
def download_artifacts():
158+
import tqdm
159+
160+
with temp_env_var("FLASHINFER_CUBIN_CHECKSUM_DISABLED", "1"):
161+
cubin_files = get_cubin_file_list()
162+
num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4"))
163+
pool = ThreadPoolExecutor(num_threads)
164+
futures = []
165+
for name, extension in cubin_files:
166+
ret = pool.submit(get_cubin, name, "", extension)
167+
futures.append(ret)
168+
results = []
169+
with (
170+
patch_logger_for_tqdm(logger),
171+
tqdm(total=len(futures), desc="Downloading cubins") as pbar,
172+
):
173+
for ret in as_completed(futures):
174+
result = ret.result()
175+
results.append(result)
176+
pbar.update(1)
177+
all_success = all(results)
178+
if not all_success:
179+
raise RuntimeError("Failed to download cubins")
180+
181+
182+
def get_artifacts_status():
183+
"""
184+
Check which cubins are already downloaded and return (num_downloaded, total).
185+
Does not download any cubins.
186+
"""
187+
cubin_files = get_cubin_file_list()
188+
status = []
97189
for name, extension in cubin_files:
98-
ret = pool.submit(get_cubin, name, "", extension)
99-
futures.append(ret)
100-
results = []
101-
for ret in as_completed(futures):
102-
result = ret.result()
103-
results.append(result)
104-
all_success = all(results)
105-
if not env_backup:
106-
os.environ.pop("FLASHINFER_CUBIN_CHECKSUM_DISABLED")
190+
# get_cubin stores cubins in FLASHINFER_CUBIN_DIR with the same relative path
191+
# Remove any leading slashes from name
192+
rel_path = name.lstrip("/")
193+
local_path = os.path.join(FLASHINFER_CUBIN_DIR, rel_path)
194+
exists = os.path.isfile(local_path + extension)
195+
status.append((name, extension, exists))
196+
return status
197+
198+
199+
def clear_cubin():
200+
if os.path.exists(FLASHINFER_CUBIN_DIR):
201+
print(f"Clearing cubin directory: {FLASHINFER_CUBIN_DIR}")
202+
shutil.rmtree(FLASHINFER_CUBIN_DIR)
107203
else:
108-
os.environ["FLASHINFER_CUBIN_CHECKSUM_DISABLED"] = env_backup
109-
110-
return all_success
204+
print(f"Cubin directory does not exist: {FLASHINFER_CUBIN_DIR}")

flashinfer/jit/cpp_ext.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,18 @@
2222
from ..compilation_context import CompilationContext
2323

2424

25+
@functools.cache
26+
def get_cuda_path() -> str:
27+
if CUDA_HOME is None:
28+
# get output of "which nvcc"
29+
result = subprocess.run(["which", "nvcc"], capture_output=True)
30+
if result.returncode != 0:
31+
raise RuntimeError("Could not find nvcc")
32+
return result.stdout.decode("utf-8").strip()
33+
else:
34+
return CUDA_HOME
35+
36+
2537
@functools.cache
2638
def get_cuda_version() -> Version:
2739
if CUDA_HOME is None:

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ def generate_build_meta(aot_build_meta: dict) -> None:
6060
"requests",
6161
"pynvml",
6262
"einops",
63+
"click",
64+
"tqdm",
65+
"tabulate",
6366
"packaging>=24.2",
6467
"nvidia-cudnn-frontend>=1.13.0",
6568
]

0 commit comments

Comments
 (0)