From 352b3ad9d38f29cba8686ab779008af8d6c12582 Mon Sep 17 00:00:00 2001 From: jimmzhou Date: Mon, 29 Sep 2025 23:22:29 +0000 Subject: [PATCH 01/10] cute dsl --- tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py b/tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py index 4e087e3c5f..b9ac6e6b0d 100644 --- a/tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py +++ b/tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py @@ -43,6 +43,10 @@ def create_mc_tensor(torch_tensor_cpu, dtype, leading_dim, is_dynamic_layout=Tru torch_symm_tensor.copy_(torch_tensor_cpu) symm = symm_mem.rendezvous(torch_symm_tensor, group=dist.group.WORLD.group_name) mc_ptr = symm.multicast_ptr + + if not mc_ptr: + raise ValueError("Multicast support is not available") + # create MC tensor memref cute_tensor_mc = from_dlpack( cutlass_torch.as_tensor(mc_ptr, torch_tensor_cpu.shape, torch_tensor_cpu.dtype), From 38d494099dff725fd251a282783384a1a196952c Mon Sep 17 00:00:00 2001 From: jimmzhou Date: Tue, 23 Sep 2025 23:39:56 +0000 Subject: [PATCH 02/10] checksum check --- flashinfer/artifacts.py | 88 +++++++++++++++++++++++++++++------------ 1 file changed, 63 insertions(+), 25 deletions(-) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index fe4aeab60b..e0cf7e8042 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -23,6 +23,7 @@ from typing import Generator import requests # type: ignore[import-untyped] import shutil +import hashlib # Create logger for artifacts module to avoid circular import with jit.core logger = logging.getLogger("flashinfer.artifacts") @@ -35,6 +36,7 @@ download_file, safe_urljoin, FLASHINFER_CUBIN_DIR, + download_file, ) @@ -78,50 +80,88 @@ def get_available_cubin_files( return tuple() -@dataclass(frozen=True) class ArtifactPath: - TRTLLM_GEN_FMHA: str = "7206d64e67f4c8949286246d6e2e07706af5d223/fmha/trtllm-gen" + TRTLLM_GEN_FMHA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( - "e6f22dcc3fdeb29ff87af2f4a2cb3d30b8d273e0/batched_gemm-45beda1-ee6a802" + "a72d85b019dc125b9f711300cb989430f762f5a6/batched_gemm-145d1b1-9e1d49a/" ) TRTLLM_GEN_GEMM: str = ( - "037e528e719ec3456a7d7d654f26b805e44c63b1/gemm-8704aa4-f91dc9e" + "a72d85b019dc125b9f711300cb989430f762f5a6/gemm-145d1b1-f91dc9e/" ) - CUDNN_SDPA: str = "4c623163877c8fef5751c9c7a59940cd2baae02e/fmha/cudnn" - DEEPGEMM: str = "51d730202c9eef782f06ecc950005331d85c5d4b/deep-gemm" - + CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/" + DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/" +# TODO: Should be deprecated @dataclass(frozen=True) class MetaInfoHash: TRTLLM_GEN_FMHA: str = ( "2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d" ) TRTLLM_GEN_BMM: str = ( - "c98b4ce69a39fd41556d67033c30ea814ef76b0a2fe16e798e55baf0104acc34" + "9490085267aed30a387bfff024a0605e1ca4d39dfe06a5abc159d7d7e129bdf4" ) DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48" TRTLLM_GEN_GEMM: str = ( - "0345358c916d990709f9670e113e93f35c76aa22715e2d5128ec2ca8740be5ba" + "7d8ef4e6d89b6990e3e90a3d3a21e96918824d819f8f897a9bfd994925b9ea67" ) -def get_cubin_file_list() -> Generator[str, None, None]: - base = FLASHINFER_CUBINS_REPOSITORY +# @dataclass(frozen=True) +class CheckSumHash: + TRTLLM_GEN_FMHA: str = "b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4" + TRTLLM_GEN_BMM: str = "efb9379c924193f6d3cb792bafb12b0811cab8eaa12bf324c7c410636c7769cd" + DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf" + TRTLLM_GEN_GEMM: str = "e475e37989eed16418e0e858e2868ff07cb4b650cc48759cc23012f1afea310a" - # The meta info header files first. - yield safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h") - yield safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h") - yield safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h") + map_checksums: [dict[str, str]] = { + safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "checksums.txt"): TRTLLM_GEN_FMHA, + safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "checksums.txt"): TRTLLM_GEN_BMM, + safe_urljoin(ArtifactPath.DEEPGEMM, "checksums.txt"): DEEPGEMM, + safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "checksums.txt"): TRTLLM_GEN_GEMM, + } - # All the actual kernel cubin's. - for kernel in [ + +def get_checksums(subdirs): + checksums = {} + for subdir in subdirs: + uri = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, safe_urljoin(subdir, "checksums.txt")) + checksum_path = FLASHINFER_CUBIN_DIR / safe_urljoin(subdir, "checksums.txt") + download_file(uri, checksum_path) + with open(checksum_path, "r") as f: + for line in f: + sha256, filename = line.strip().split() + + # Distinguish between all meta info header files + if ".h" in filename: + filename = safe_urljoin(subdir, filename) + checksums[filename] = sha256 + return checksums + + +def get_subdir_file_list(): + base = FLASHINFER_CUBINS_REPOSITORY + + cubin_dirs = [ ArtifactPath.TRTLLM_GEN_FMHA, ArtifactPath.TRTLLM_GEN_BMM, ArtifactPath.TRTLLM_GEN_GEMM, ArtifactPath.DEEPGEMM, - ]: - for name in get_available_cubin_files(safe_urljoin(base, kernel)): - yield safe_urljoin(kernel, name) + ] + + # Get checksums of all files + checksums = get_checksums(cubin_dirs) + + # The meta info header files first. + yield (safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h"), checksums[safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h")]) + yield (safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h"), checksums[safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h")]) + yield (safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h"), checksums[safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h")]) + + # All the actual kernel cubin's. + for cubin_dir in cubin_dirs: + checksum_path = safe_urljoin(cubin_dir, "checksums.txt") + yield (checksum_path, CheckSumHash.map_checksums[checksum_path]) + for name in get_available_cubin_files(safe_urljoin(base, cubin_dir)): + yield (safe_urljoin(cubin_dir, name), checksums[name]) def download_artifacts() -> None: @@ -130,13 +170,11 @@ def download_artifacts() -> None: # use a shared session to make use of HTTP keep-alive and reuse of # HTTPS connections. session = requests.Session() - - cubin_files = list(get_cubin_file_list()) + cubin_files = list(get_subdir_file_list()) num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4")) with tqdm_logging_redirect( total=len(cubin_files), desc="Downloading cubins" ) as pbar: - def update_pbar_cb(_) -> None: pbar.update(1) @@ -165,7 +203,7 @@ def get_artifacts_status() -> tuple[tuple[str, bool], ...]: Check which cubins are already downloaded and return (num_downloaded, total). Does not download any cubins. """ - cubin_files = get_cubin_file_list() + cubin_files = get_subdir_file_list() def _check_file_status(file_name: str) -> tuple[str, bool]: # get_cubin stores cubins in FLASHINFER_CUBIN_DIR with the same relative path @@ -174,7 +212,7 @@ def _check_file_status(file_name: str) -> tuple[str, bool]: exists = os.path.isfile(local_path) return (file_name, exists) - return tuple(_check_file_status(file_name) for file_name in cubin_files) + return tuple(_check_file_status(file_name) for file_name, _ in cubin_files) def clear_cubin(): From 29a1484ea8dffbb5488b6b8693226ae7c76a0977 Mon Sep 17 00:00:00 2001 From: jimmzhou Date: Wed, 24 Sep 2025 23:44:54 +0000 Subject: [PATCH 03/10] Trigger CI pipeline From 20222e3b438dda901611f6afb38548506eca1420 Mon Sep 17 00:00:00 2001 From: jimmzhou Date: Thu, 25 Sep 2025 21:05:22 +0000 Subject: [PATCH 04/10] trigger cicd From 6d0f0ca040e3a18b0f7dd288f783391aa9de221c Mon Sep 17 00:00:00 2001 From: jimmzhou Date: Tue, 30 Sep 2025 03:48:56 +0000 Subject: [PATCH 05/10] fix --- tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py b/tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py index b9ac6e6b0d..4e087e3c5f 100644 --- a/tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py +++ b/tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py @@ -43,10 +43,6 @@ def create_mc_tensor(torch_tensor_cpu, dtype, leading_dim, is_dynamic_layout=Tru torch_symm_tensor.copy_(torch_tensor_cpu) symm = symm_mem.rendezvous(torch_symm_tensor, group=dist.group.WORLD.group_name) mc_ptr = symm.multicast_ptr - - if not mc_ptr: - raise ValueError("Multicast support is not available") - # create MC tensor memref cute_tensor_mc = from_dlpack( cutlass_torch.as_tensor(mc_ptr, torch_tensor_cpu.shape, torch_tensor_cpu.dtype), From c0b83c7c5109e8ea4abf63fcaafd3b41f93f865a Mon Sep 17 00:00:00 2001 From: jimmzhou Date: Mon, 6 Oct 2025 22:48:55 +0000 Subject: [PATCH 06/10] fix types --- flashinfer/artifacts.py | 45 +++++++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index e0cf7e8042..011d0966ff 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -23,7 +23,6 @@ from typing import Generator import requests # type: ignore[import-untyped] import shutil -import hashlib # Create logger for artifacts module to avoid circular import with jit.core logger = logging.getLogger("flashinfer.artifacts") @@ -91,6 +90,7 @@ class ArtifactPath: CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/" DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/" + # TODO: Should be deprecated @dataclass(frozen=True) class MetaInfoHash: @@ -108,12 +108,18 @@ class MetaInfoHash: # @dataclass(frozen=True) class CheckSumHash: - TRTLLM_GEN_FMHA: str = "b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4" - TRTLLM_GEN_BMM: str = "efb9379c924193f6d3cb792bafb12b0811cab8eaa12bf324c7c410636c7769cd" - DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf" - TRTLLM_GEN_GEMM: str = "e475e37989eed16418e0e858e2868ff07cb4b650cc48759cc23012f1afea310a" + TRTLLM_GEN_FMHA: str = ( + "b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4" + ) + TRTLLM_GEN_BMM: str = ( + "efb9379c924193f6d3cb792bafb12b0811cab8eaa12bf324c7c410636c7769cd" + ) + DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf" + TRTLLM_GEN_GEMM: str = ( + "e475e37989eed16418e0e858e2868ff07cb4b650cc48759cc23012f1afea310a" + ) - map_checksums: [dict[str, str]] = { + map_checksums: dict[str, str] = { safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "checksums.txt"): TRTLLM_GEN_FMHA, safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "checksums.txt"): TRTLLM_GEN_BMM, safe_urljoin(ArtifactPath.DEEPGEMM, "checksums.txt"): DEEPGEMM, @@ -124,7 +130,9 @@ class CheckSumHash: def get_checksums(subdirs): checksums = {} for subdir in subdirs: - uri = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, safe_urljoin(subdir, "checksums.txt")) + uri = safe_urljoin( + FLASHINFER_CUBINS_REPOSITORY, safe_urljoin(subdir, "checksums.txt") + ) checksum_path = FLASHINFER_CUBIN_DIR / safe_urljoin(subdir, "checksums.txt") download_file(uri, checksum_path) with open(checksum_path, "r") as f: @@ -138,7 +146,7 @@ def get_checksums(subdirs): return checksums -def get_subdir_file_list(): +def get_subdir_file_list() -> Generator[tuple[str, str], None, None]: base = FLASHINFER_CUBINS_REPOSITORY cubin_dirs = [ @@ -152,9 +160,24 @@ def get_subdir_file_list(): checksums = get_checksums(cubin_dirs) # The meta info header files first. - yield (safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h"), checksums[safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h")]) - yield (safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h"), checksums[safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h")]) - yield (safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h"), checksums[safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h")]) + yield ( + safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h"), + checksums[ + safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h") + ], + ) + yield ( + safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h"), + checksums[ + safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h") + ], + ) + yield ( + safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h"), + checksums[ + safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h") + ], + ) # All the actual kernel cubin's. for cubin_dir in cubin_dirs: From bfad066df12849d46ed0fc10ba4cd3743c17f303 Mon Sep 17 00:00:00 2001 From: jimmzhou Date: Mon, 6 Oct 2025 22:51:22 +0000 Subject: [PATCH 07/10] style --- flashinfer/artifacts.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 011d0966ff..3f6cedaa0c 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -118,7 +118,6 @@ class CheckSumHash: TRTLLM_GEN_GEMM: str = ( "e475e37989eed16418e0e858e2868ff07cb4b650cc48759cc23012f1afea310a" ) - map_checksums: dict[str, str] = { safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "checksums.txt"): TRTLLM_GEN_FMHA, safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "checksums.txt"): TRTLLM_GEN_BMM, From fd07386605851dd9278f4f87dd0bb8870dd25c4c Mon Sep 17 00:00:00 2001 From: jimmzhou Date: Wed, 8 Oct 2025 20:21:18 +0000 Subject: [PATCH 08/10] deprecate metainfohash --- flashinfer/artifacts.py | 2 +- flashinfer/jit/attention/modules.py | 7 ++++--- flashinfer/jit/cubin_loader.py | 13 +++++++++++++ flashinfer/jit/fused_moe.py | 9 ++++++--- flashinfer/jit/gemm/core.py | 6 +++--- 5 files changed, 27 insertions(+), 10 deletions(-) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 3f6cedaa0c..0648bb3fd0 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -91,7 +91,7 @@ class ArtifactPath: DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/" -# TODO: Should be deprecated +# TODO (jimmyzho): Should be deprecated except DEEPGEMM @dataclass(frozen=True) class MetaInfoHash: TRTLLM_GEN_FMHA: str = ( diff --git a/flashinfer/jit/attention/modules.py b/flashinfer/jit/attention/modules.py index 3fb4a289d3..ff25928a90 100644 --- a/flashinfer/jit/attention/modules.py +++ b/flashinfer/jit/attention/modules.py @@ -28,7 +28,7 @@ sm90a_nvcc_flags, current_compilation_context, ) -from ...jit.cubin_loader import get_cubin +from ...jit.cubin_loader import get_cubin, get_meta_hash from ..utils import ( dtype_map, filename_safe_dtype_map, @@ -1575,7 +1575,8 @@ def gen_trtllm_gen_fmha_module(): # use `get_cubin` to get "flashinferMetaInfo.h" metainfo = get_cubin( - f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_FMHA + f"{include_path}/{header_name}.h", + get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_FMHA}/checksums.txt"), ) # make sure "flashinferMetaInfo.h" is downloaded or cached @@ -1591,7 +1592,7 @@ def gen_trtllm_gen_fmha_module(): extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path], extra_cuda_cflags=[ f'-DTLLM_GEN_FMHA_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_FMHA}\\"', - f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{MetaInfoHash.TRTLLM_GEN_FMHA}\\"', + f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_FMHA}/checksums.txt")}\\"', ], ) diff --git a/flashinfer/jit/cubin_loader.py b/flashinfer/jit/cubin_loader.py index 286aefabf6..0c2c8a06cd 100644 --- a/flashinfer/jit/cubin_loader.py +++ b/flashinfer/jit/cubin_loader.py @@ -127,6 +127,19 @@ def download_file( return False +def get_meta_hash(checksum_path: str) -> str: + """ + Load the file from local cache (checksums.txt) + and get the hash of corresponding flashinferMetaInfo.h file + """ + with open(checksum_path, "r") as f: + for line in f: + sha256, filename = line.strip().split() + if ".h" in filename: + return sha256 + raise ValueError(f"Invalid path: checksums.txt not found in {checksum_path}") + + def load_cubin(cubin_path: str, sha256: str) -> bytes: """ Load a cubin from the provide local path and diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py index 56c7d2e751..e83a6cdbf9 100644 --- a/flashinfer/jit/fused_moe.py +++ b/flashinfer/jit/fused_moe.py @@ -17,10 +17,10 @@ from typing import List from . import env as jit_env -from ..artifacts import ArtifactPath, MetaInfoHash +from ..artifacts import ArtifactPath from .core import JitSpec, gen_jit_spec, current_compilation_context, sm90a_nvcc_flags from .cpp_ext import is_cuda_version_at_least -from .cubin_loader import get_cubin +from .cubin_loader import get_cubin, get_meta_hash from .gemm.cutlass.generate_kernels import generate_gemm_operations @@ -179,7 +179,10 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec: header_name = "flashinferMetaInfo" # use `get_cubin` to get "flashinferMetaInfo.h" - metainfo = get_cubin(f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_BMM) + metainfo = get_cubin( + f"{include_path}/{header_name}.h", + get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_BMM}/checksums.txt"), + ) # make sure "flashinferMetaInfo.h" is downloaded or cached assert metainfo, f"{header_name}.h not found" diff --git a/flashinfer/jit/gemm/core.py b/flashinfer/jit/gemm/core.py index 9f55cb0920..a494122852 100644 --- a/flashinfer/jit/gemm/core.py +++ b/flashinfer/jit/gemm/core.py @@ -20,7 +20,7 @@ import jinja2 import torch -from ...artifacts import ArtifactPath, MetaInfoHash +from ...artifacts import ArtifactPath from .. import env as jit_env from ..core import ( JitSpec, @@ -30,7 +30,7 @@ sm100f_nvcc_flags, current_compilation_context, ) -from ..cubin_loader import get_cubin +from ..cubin_loader import get_cubin, get_meta_hash from ..utils import dtype_cutlass_map, filename_safe_dtype_map, write_if_different @@ -364,7 +364,7 @@ def gen_trtllm_gen_gemm_module() -> JitSpec: # use `get_cubin` to get "flashinferMetaInfo.h" metainfo = get_cubin( f"{include_path}/{header_name}.h", - MetaInfoHash.TRTLLM_GEN_GEMM, + get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_GEMM}/checksums.txt"), ) # make sure "flashinferMetaInfo.h" is downloaded or cached assert metainfo, f"{header_name}.h not found" From d429957eefd076d6a37fa9bb9526a58eaeff57b5 Mon Sep 17 00:00:00 2001 From: jimmzhou Date: Thu, 9 Oct 2025 06:33:50 +0000 Subject: [PATCH 09/10] update checksum check --- flashinfer/artifacts.py | 22 ++++++++----------- flashinfer/jit/attention/modules.py | 4 ++-- flashinfer/jit/cubin_loader.py | 17 +++++++++++++- flashinfer/jit/fused_moe.py | 2 +- flashinfer/jit/gemm/core.py | 2 +- .../test_load_cubin_compile_race_condition.py | 7 +----- 6 files changed, 30 insertions(+), 24 deletions(-) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 0648bb3fd0..14f2894e84 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -32,10 +32,10 @@ from .jit.cubin_loader import ( FLASHINFER_CUBINS_REPOSITORY, - download_file, safe_urljoin, FLASHINFER_CUBIN_DIR, download_file, + verify_cubin, ) @@ -91,22 +91,11 @@ class ArtifactPath: DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/" -# TODO (jimmyzho): Should be deprecated except DEEPGEMM @dataclass(frozen=True) class MetaInfoHash: - TRTLLM_GEN_FMHA: str = ( - "2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d" - ) - TRTLLM_GEN_BMM: str = ( - "9490085267aed30a387bfff024a0605e1ca4d39dfe06a5abc159d7d7e129bdf4" - ) DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48" - TRTLLM_GEN_GEMM: str = ( - "7d8ef4e6d89b6990e3e90a3d3a21e96918824d819f8f897a9bfd994925b9ea67" - ) -# @dataclass(frozen=True) class CheckSumHash: TRTLLM_GEN_FMHA: str = ( "b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4" @@ -197,12 +186,13 @@ def download_artifacts() -> None: with tqdm_logging_redirect( total=len(cubin_files), desc="Downloading cubins" ) as pbar: + def update_pbar_cb(_) -> None: pbar.update(1) with ThreadPoolExecutor(num_threads) as pool: futures = [] - for name in cubin_files: + for name, _ in cubin_files: source = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, name) local_path = FLASHINFER_CUBIN_DIR / name # Ensure parent directory exists @@ -219,6 +209,12 @@ def update_pbar_cb(_) -> None: if not all_success: raise RuntimeError("Failed to download cubins") + # Check checksums of all downloaded cubins + for name, checksum in cubin_files: + local_path = FLASHINFER_CUBIN_DIR / name + if not verify_cubin(str(local_path), checksum): + raise RuntimeError("Failed to download cubins: checksum mismatch") + def get_artifacts_status() -> tuple[tuple[str, bool], ...]: """ diff --git a/flashinfer/jit/attention/modules.py b/flashinfer/jit/attention/modules.py index ff25928a90..cf468534c1 100644 --- a/flashinfer/jit/attention/modules.py +++ b/flashinfer/jit/attention/modules.py @@ -1568,7 +1568,7 @@ def gen_fmha_cutlass_sm100a_module( def gen_trtllm_gen_fmha_module(): - from ...artifacts import ArtifactPath, MetaInfoHash + from ...artifacts import ArtifactPath include_path = f"{ArtifactPath.TRTLLM_GEN_FMHA}/include" header_name = "flashInferMetaInfo" @@ -1576,7 +1576,7 @@ def gen_trtllm_gen_fmha_module(): # use `get_cubin` to get "flashinferMetaInfo.h" metainfo = get_cubin( f"{include_path}/{header_name}.h", - get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_FMHA}/checksums.txt"), + get_meta_hash(ArtifactPath.TRTLLM_GEN_FMHA), ) # make sure "flashinferMetaInfo.h" is downloaded or cached diff --git a/flashinfer/jit/cubin_loader.py b/flashinfer/jit/cubin_loader.py index 0c2c8a06cd..61d298caed 100644 --- a/flashinfer/jit/cubin_loader.py +++ b/flashinfer/jit/cubin_loader.py @@ -132,7 +132,7 @@ def get_meta_hash(checksum_path: str) -> str: Load the file from local cache (checksums.txt) and get the hash of corresponding flashinferMetaInfo.h file """ - with open(checksum_path, "r") as f: + with open(checksum_path + "/checksums.txt", "r") as f: for line in f: sha256, filename = line.strip().split() if ".h" in filename: @@ -140,6 +140,21 @@ def get_meta_hash(checksum_path: str) -> str: raise ValueError(f"Invalid path: checksums.txt not found in {checksum_path}") +def verify_cubin(cubin_path: str, expected_sha256: str) -> bool: + """ + Verify the cubin file against the sha256 checksum. + """ + with open(cubin_path, "rb") as f: + data = f.read() + actual_sha256 = hashlib.sha256(data).hexdigest() + if actual_sha256 != expected_sha256: + logger.warning( + f"sha256 mismatch (expected {expected_sha256} actual {actual_sha256}) for {cubin_path}" + ) + return False + return True + + def load_cubin(cubin_path: str, sha256: str) -> bytes: """ Load a cubin from the provide local path and diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py index e83a6cdbf9..77e03d4388 100644 --- a/flashinfer/jit/fused_moe.py +++ b/flashinfer/jit/fused_moe.py @@ -181,7 +181,7 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec: # use `get_cubin` to get "flashinferMetaInfo.h" metainfo = get_cubin( f"{include_path}/{header_name}.h", - get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_BMM}/checksums.txt"), + get_meta_hash(ArtifactPath.TRTLLM_GEN_BMM), ) # make sure "flashinferMetaInfo.h" is downloaded or cached assert metainfo, f"{header_name}.h not found" diff --git a/flashinfer/jit/gemm/core.py b/flashinfer/jit/gemm/core.py index a494122852..7446d92ffb 100644 --- a/flashinfer/jit/gemm/core.py +++ b/flashinfer/jit/gemm/core.py @@ -364,7 +364,7 @@ def gen_trtllm_gen_gemm_module() -> JitSpec: # use `get_cubin` to get "flashinferMetaInfo.h" metainfo = get_cubin( f"{include_path}/{header_name}.h", - get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_GEMM}/checksums.txt"), + get_meta_hash(ArtifactPath.TRTLLM_GEN_GEMM), ) # make sure "flashinferMetaInfo.h" is downloaded or cached assert metainfo, f"{header_name}.h not found" diff --git a/tests/utils/test_load_cubin_compile_race_condition.py b/tests/utils/test_load_cubin_compile_race_condition.py index 29b2165a54..23e9975706 100644 --- a/tests/utils/test_load_cubin_compile_race_condition.py +++ b/tests/utils/test_load_cubin_compile_race_condition.py @@ -34,17 +34,12 @@ def worker_process(temp_dir): os.environ["FLASHINFER_CUBIN_DIR"] = temp_dir # Import here to ensure FLASHINFER_CUBIN_DIR is set before module loads - from flashinfer.artifacts import ArtifactPath, MetaInfoHash - from flashinfer.jit.cubin_loader import get_cubin + from flashinfer.artifacts import ArtifactPath # Define the target file - same for all processes include_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/include" header_name = "flashinferMetaInfo" - # Use get_cubin to get "flashinferMetaInfo.h" - # Note: all processes target the same file name - metainfo = get_cubin(f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_BMM) # noqa: F841 - # Read the file from FLASHINFER_CUBIN_DIR # NOTE(Zihao): instead of using metainfo, we directly read from the file path, # that aligns with how we compile the kernel. From 12dc8809d192db7293c9a20581c763b169b1453a Mon Sep 17 00:00:00 2001 From: jimmzhou Date: Thu, 9 Oct 2025 16:19:51 +0000 Subject: [PATCH 10/10] pathfix --- flashinfer/jit/cubin_loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flashinfer/jit/cubin_loader.py b/flashinfer/jit/cubin_loader.py index 61d298caed..4d97555b63 100644 --- a/flashinfer/jit/cubin_loader.py +++ b/flashinfer/jit/cubin_loader.py @@ -132,7 +132,8 @@ def get_meta_hash(checksum_path: str) -> str: Load the file from local cache (checksums.txt) and get the hash of corresponding flashinferMetaInfo.h file """ - with open(checksum_path + "/checksums.txt", "r") as f: + local_path = FLASHINFER_CUBIN_DIR / safe_urljoin(checksum_path, "checksums.txt") + with open(local_path, "r") as f: for line in f: sha256, filename = line.strip().split() if ".h" in filename: