Skip to content

Commit d429957

Browse files
committed
update checksum check
1 parent fd07386 commit d429957

File tree

6 files changed

+30
-24
lines changed

6 files changed

+30
-24
lines changed

flashinfer/artifacts.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@
3232

3333
from .jit.cubin_loader import (
3434
FLASHINFER_CUBINS_REPOSITORY,
35-
download_file,
3635
safe_urljoin,
3736
FLASHINFER_CUBIN_DIR,
3837
download_file,
38+
verify_cubin,
3939
)
4040

4141

@@ -91,22 +91,11 @@ class ArtifactPath:
9191
DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"
9292

9393

94-
# TODO (jimmyzho): Should be deprecated except DEEPGEMM
9594
@dataclass(frozen=True)
9695
class MetaInfoHash:
97-
TRTLLM_GEN_FMHA: str = (
98-
"2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d"
99-
)
100-
TRTLLM_GEN_BMM: str = (
101-
"9490085267aed30a387bfff024a0605e1ca4d39dfe06a5abc159d7d7e129bdf4"
102-
)
10396
DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
104-
TRTLLM_GEN_GEMM: str = (
105-
"7d8ef4e6d89b6990e3e90a3d3a21e96918824d819f8f897a9bfd994925b9ea67"
106-
)
10797

10898

109-
# @dataclass(frozen=True)
11099
class CheckSumHash:
111100
TRTLLM_GEN_FMHA: str = (
112101
"b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
@@ -197,12 +186,13 @@ def download_artifacts() -> None:
197186
with tqdm_logging_redirect(
198187
total=len(cubin_files), desc="Downloading cubins"
199188
) as pbar:
189+
200190
def update_pbar_cb(_) -> None:
201191
pbar.update(1)
202192

203193
with ThreadPoolExecutor(num_threads) as pool:
204194
futures = []
205-
for name in cubin_files:
195+
for name, _ in cubin_files:
206196
source = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, name)
207197
local_path = FLASHINFER_CUBIN_DIR / name
208198
# Ensure parent directory exists
@@ -219,6 +209,12 @@ def update_pbar_cb(_) -> None:
219209
if not all_success:
220210
raise RuntimeError("Failed to download cubins")
221211

212+
# Check checksums of all downloaded cubins
213+
for name, checksum in cubin_files:
214+
local_path = FLASHINFER_CUBIN_DIR / name
215+
if not verify_cubin(str(local_path), checksum):
216+
raise RuntimeError("Failed to download cubins: checksum mismatch")
217+
222218

223219
def get_artifacts_status() -> tuple[tuple[str, bool], ...]:
224220
"""

flashinfer/jit/attention/modules.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,15 +1568,15 @@ def gen_fmha_cutlass_sm100a_module(
15681568

15691569

15701570
def gen_trtllm_gen_fmha_module():
1571-
from ...artifacts import ArtifactPath, MetaInfoHash
1571+
from ...artifacts import ArtifactPath
15721572

15731573
include_path = f"{ArtifactPath.TRTLLM_GEN_FMHA}/include"
15741574
header_name = "flashInferMetaInfo"
15751575

15761576
# use `get_cubin` to get "flashinferMetaInfo.h"
15771577
metainfo = get_cubin(
15781578
f"{include_path}/{header_name}.h",
1579-
get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_FMHA}/checksums.txt"),
1579+
get_meta_hash(ArtifactPath.TRTLLM_GEN_FMHA),
15801580
)
15811581

15821582
# make sure "flashinferMetaInfo.h" is downloaded or cached

flashinfer/jit/cubin_loader.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,14 +132,29 @@ def get_meta_hash(checksum_path: str) -> str:
132132
Load the file from local cache (checksums.txt)
133133
and get the hash of corresponding flashinferMetaInfo.h file
134134
"""
135-
with open(checksum_path, "r") as f:
135+
with open(checksum_path + "/checksums.txt", "r") as f:
136136
for line in f:
137137
sha256, filename = line.strip().split()
138138
if ".h" in filename:
139139
return sha256
140140
raise ValueError(f"Invalid path: checksums.txt not found in {checksum_path}")
141141

142142

143+
def verify_cubin(cubin_path: str, expected_sha256: str) -> bool:
144+
"""
145+
Verify the cubin file against the sha256 checksum.
146+
"""
147+
with open(cubin_path, "rb") as f:
148+
data = f.read()
149+
actual_sha256 = hashlib.sha256(data).hexdigest()
150+
if actual_sha256 != expected_sha256:
151+
logger.warning(
152+
f"sha256 mismatch (expected {expected_sha256} actual {actual_sha256}) for {cubin_path}"
153+
)
154+
return False
155+
return True
156+
157+
143158
def load_cubin(cubin_path: str, sha256: str) -> bytes:
144159
"""
145160
Load a cubin from the provide local path and

flashinfer/jit/fused_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec:
181181
# use `get_cubin` to get "flashinferMetaInfo.h"
182182
metainfo = get_cubin(
183183
f"{include_path}/{header_name}.h",
184-
get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_BMM}/checksums.txt"),
184+
get_meta_hash(ArtifactPath.TRTLLM_GEN_BMM),
185185
)
186186
# make sure "flashinferMetaInfo.h" is downloaded or cached
187187
assert metainfo, f"{header_name}.h not found"

flashinfer/jit/gemm/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def gen_trtllm_gen_gemm_module() -> JitSpec:
364364
# use `get_cubin` to get "flashinferMetaInfo.h"
365365
metainfo = get_cubin(
366366
f"{include_path}/{header_name}.h",
367-
get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_GEMM}/checksums.txt"),
367+
get_meta_hash(ArtifactPath.TRTLLM_GEN_GEMM),
368368
)
369369
# make sure "flashinferMetaInfo.h" is downloaded or cached
370370
assert metainfo, f"{header_name}.h not found"

tests/utils/test_load_cubin_compile_race_condition.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,12 @@ def worker_process(temp_dir):
3434
os.environ["FLASHINFER_CUBIN_DIR"] = temp_dir
3535

3636
# Import here to ensure FLASHINFER_CUBIN_DIR is set before module loads
37-
from flashinfer.artifacts import ArtifactPath, MetaInfoHash
38-
from flashinfer.jit.cubin_loader import get_cubin
37+
from flashinfer.artifacts import ArtifactPath
3938

4039
# Define the target file - same for all processes
4140
include_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/include"
4241
header_name = "flashinferMetaInfo"
4342

44-
# Use get_cubin to get "flashinferMetaInfo.h"
45-
# Note: all processes target the same file name
46-
metainfo = get_cubin(f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_BMM) # noqa: F841
47-
4843
# Read the file from FLASHINFER_CUBIN_DIR
4944
# NOTE(Zihao): instead of using metainfo, we directly read from the file path,
5045
# that aligns with how we compile the kernel.

0 commit comments

Comments
 (0)