diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index fe4aeab60b..14f2894e84 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -32,9 +32,10 @@ from .jit.cubin_loader import ( FLASHINFER_CUBINS_REPOSITORY, - download_file, safe_urljoin, FLASHINFER_CUBIN_DIR, + download_file, + verify_cubin, ) @@ -78,50 +79,100 @@ def get_available_cubin_files( return tuple() -@dataclass(frozen=True) class ArtifactPath: - TRTLLM_GEN_FMHA: str = "7206d64e67f4c8949286246d6e2e07706af5d223/fmha/trtllm-gen" + TRTLLM_GEN_FMHA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( - "e6f22dcc3fdeb29ff87af2f4a2cb3d30b8d273e0/batched_gemm-45beda1-ee6a802" + "a72d85b019dc125b9f711300cb989430f762f5a6/batched_gemm-145d1b1-9e1d49a/" ) TRTLLM_GEN_GEMM: str = ( - "037e528e719ec3456a7d7d654f26b805e44c63b1/gemm-8704aa4-f91dc9e" + "a72d85b019dc125b9f711300cb989430f762f5a6/gemm-145d1b1-f91dc9e/" ) - CUDNN_SDPA: str = "4c623163877c8fef5751c9c7a59940cd2baae02e/fmha/cudnn" - DEEPGEMM: str = "51d730202c9eef782f06ecc950005331d85c5d4b/deep-gemm" + CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/" + DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/" @dataclass(frozen=True) class MetaInfoHash: + DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48" + + +class CheckSumHash: TRTLLM_GEN_FMHA: str = ( - "2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d" + "b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4" ) TRTLLM_GEN_BMM: str = ( - "c98b4ce69a39fd41556d67033c30ea814ef76b0a2fe16e798e55baf0104acc34" + "efb9379c924193f6d3cb792bafb12b0811cab8eaa12bf324c7c410636c7769cd" ) - DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48" + DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf" TRTLLM_GEN_GEMM: str = ( - "0345358c916d990709f9670e113e93f35c76aa22715e2d5128ec2ca8740be5ba" + "e475e37989eed16418e0e858e2868ff07cb4b650cc48759cc23012f1afea310a" ) - - -def get_cubin_file_list() -> Generator[str, None, None]: + map_checksums: dict[str, str] = { + safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "checksums.txt"): TRTLLM_GEN_FMHA, + safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "checksums.txt"): TRTLLM_GEN_BMM, + safe_urljoin(ArtifactPath.DEEPGEMM, "checksums.txt"): DEEPGEMM, + safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "checksums.txt"): TRTLLM_GEN_GEMM, + } + + +def get_checksums(subdirs): + checksums = {} + for subdir in subdirs: + uri = safe_urljoin( + FLASHINFER_CUBINS_REPOSITORY, safe_urljoin(subdir, "checksums.txt") + ) + checksum_path = FLASHINFER_CUBIN_DIR / safe_urljoin(subdir, "checksums.txt") + download_file(uri, checksum_path) + with open(checksum_path, "r") as f: + for line in f: + sha256, filename = line.strip().split() + + # Distinguish between all meta info header files + if ".h" in filename: + filename = safe_urljoin(subdir, filename) + checksums[filename] = sha256 + return checksums + + +def get_subdir_file_list() -> Generator[tuple[str, str], None, None]: base = FLASHINFER_CUBINS_REPOSITORY - # The meta info header files first. - yield safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h") - yield safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h") - yield safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h") - - # All the actual kernel cubin's. - for kernel in [ + cubin_dirs = [ ArtifactPath.TRTLLM_GEN_FMHA, ArtifactPath.TRTLLM_GEN_BMM, ArtifactPath.TRTLLM_GEN_GEMM, ArtifactPath.DEEPGEMM, - ]: - for name in get_available_cubin_files(safe_urljoin(base, kernel)): - yield safe_urljoin(kernel, name) + ] + + # Get checksums of all files + checksums = get_checksums(cubin_dirs) + + # The meta info header files first. + yield ( + safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h"), + checksums[ + safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h") + ], + ) + yield ( + safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h"), + checksums[ + safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h") + ], + ) + yield ( + safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h"), + checksums[ + safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h") + ], + ) + + # All the actual kernel cubin's. + for cubin_dir in cubin_dirs: + checksum_path = safe_urljoin(cubin_dir, "checksums.txt") + yield (checksum_path, CheckSumHash.map_checksums[checksum_path]) + for name in get_available_cubin_files(safe_urljoin(base, cubin_dir)): + yield (safe_urljoin(cubin_dir, name), checksums[name]) def download_artifacts() -> None: @@ -130,8 +181,7 @@ def download_artifacts() -> None: # use a shared session to make use of HTTP keep-alive and reuse of # HTTPS connections. session = requests.Session() - - cubin_files = list(get_cubin_file_list()) + cubin_files = list(get_subdir_file_list()) num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4")) with tqdm_logging_redirect( total=len(cubin_files), desc="Downloading cubins" @@ -142,7 +192,7 @@ def update_pbar_cb(_) -> None: with ThreadPoolExecutor(num_threads) as pool: futures = [] - for name in cubin_files: + for name, _ in cubin_files: source = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, name) local_path = FLASHINFER_CUBIN_DIR / name # Ensure parent directory exists @@ -159,13 +209,19 @@ def update_pbar_cb(_) -> None: if not all_success: raise RuntimeError("Failed to download cubins") + # Check checksums of all downloaded cubins + for name, checksum in cubin_files: + local_path = FLASHINFER_CUBIN_DIR / name + if not verify_cubin(str(local_path), checksum): + raise RuntimeError("Failed to download cubins: checksum mismatch") + def get_artifacts_status() -> tuple[tuple[str, bool], ...]: """ Check which cubins are already downloaded and return (num_downloaded, total). Does not download any cubins. """ - cubin_files = get_cubin_file_list() + cubin_files = get_subdir_file_list() def _check_file_status(file_name: str) -> tuple[str, bool]: # get_cubin stores cubins in FLASHINFER_CUBIN_DIR with the same relative path @@ -174,7 +230,7 @@ def _check_file_status(file_name: str) -> tuple[str, bool]: exists = os.path.isfile(local_path) return (file_name, exists) - return tuple(_check_file_status(file_name) for file_name in cubin_files) + return tuple(_check_file_status(file_name) for file_name, _ in cubin_files) def clear_cubin(): diff --git a/flashinfer/jit/attention/modules.py b/flashinfer/jit/attention/modules.py index 3fb4a289d3..cf468534c1 100644 --- a/flashinfer/jit/attention/modules.py +++ b/flashinfer/jit/attention/modules.py @@ -28,7 +28,7 @@ sm90a_nvcc_flags, current_compilation_context, ) -from ...jit.cubin_loader import get_cubin +from ...jit.cubin_loader import get_cubin, get_meta_hash from ..utils import ( dtype_map, filename_safe_dtype_map, @@ -1568,14 +1568,15 @@ def gen_fmha_cutlass_sm100a_module( def gen_trtllm_gen_fmha_module(): - from ...artifacts import ArtifactPath, MetaInfoHash + from ...artifacts import ArtifactPath include_path = f"{ArtifactPath.TRTLLM_GEN_FMHA}/include" header_name = "flashInferMetaInfo" # use `get_cubin` to get "flashinferMetaInfo.h" metainfo = get_cubin( - f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_FMHA + f"{include_path}/{header_name}.h", + get_meta_hash(ArtifactPath.TRTLLM_GEN_FMHA), ) # make sure "flashinferMetaInfo.h" is downloaded or cached @@ -1591,7 +1592,7 @@ def gen_trtllm_gen_fmha_module(): extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path], extra_cuda_cflags=[ f'-DTLLM_GEN_FMHA_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_FMHA}\\"', - f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{MetaInfoHash.TRTLLM_GEN_FMHA}\\"', + f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_FMHA}/checksums.txt")}\\"', ], ) diff --git a/flashinfer/jit/cubin_loader.py b/flashinfer/jit/cubin_loader.py index 286aefabf6..4d97555b63 100644 --- a/flashinfer/jit/cubin_loader.py +++ b/flashinfer/jit/cubin_loader.py @@ -127,6 +127,35 @@ def download_file( return False +def get_meta_hash(checksum_path: str) -> str: + """ + Load the file from local cache (checksums.txt) + and get the hash of corresponding flashinferMetaInfo.h file + """ + local_path = FLASHINFER_CUBIN_DIR / safe_urljoin(checksum_path, "checksums.txt") + with open(local_path, "r") as f: + for line in f: + sha256, filename = line.strip().split() + if ".h" in filename: + return sha256 + raise ValueError(f"Invalid path: checksums.txt not found in {checksum_path}") + + +def verify_cubin(cubin_path: str, expected_sha256: str) -> bool: + """ + Verify the cubin file against the sha256 checksum. + """ + with open(cubin_path, "rb") as f: + data = f.read() + actual_sha256 = hashlib.sha256(data).hexdigest() + if actual_sha256 != expected_sha256: + logger.warning( + f"sha256 mismatch (expected {expected_sha256} actual {actual_sha256}) for {cubin_path}" + ) + return False + return True + + def load_cubin(cubin_path: str, sha256: str) -> bytes: """ Load a cubin from the provide local path and diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py index 56c7d2e751..77e03d4388 100644 --- a/flashinfer/jit/fused_moe.py +++ b/flashinfer/jit/fused_moe.py @@ -17,10 +17,10 @@ from typing import List from . import env as jit_env -from ..artifacts import ArtifactPath, MetaInfoHash +from ..artifacts import ArtifactPath from .core import JitSpec, gen_jit_spec, current_compilation_context, sm90a_nvcc_flags from .cpp_ext import is_cuda_version_at_least -from .cubin_loader import get_cubin +from .cubin_loader import get_cubin, get_meta_hash from .gemm.cutlass.generate_kernels import generate_gemm_operations @@ -179,7 +179,10 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec: header_name = "flashinferMetaInfo" # use `get_cubin` to get "flashinferMetaInfo.h" - metainfo = get_cubin(f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_BMM) + metainfo = get_cubin( + f"{include_path}/{header_name}.h", + get_meta_hash(ArtifactPath.TRTLLM_GEN_BMM), + ) # make sure "flashinferMetaInfo.h" is downloaded or cached assert metainfo, f"{header_name}.h not found" diff --git a/flashinfer/jit/gemm/core.py b/flashinfer/jit/gemm/core.py index 9f55cb0920..7446d92ffb 100644 --- a/flashinfer/jit/gemm/core.py +++ b/flashinfer/jit/gemm/core.py @@ -20,7 +20,7 @@ import jinja2 import torch -from ...artifacts import ArtifactPath, MetaInfoHash +from ...artifacts import ArtifactPath from .. import env as jit_env from ..core import ( JitSpec, @@ -30,7 +30,7 @@ sm100f_nvcc_flags, current_compilation_context, ) -from ..cubin_loader import get_cubin +from ..cubin_loader import get_cubin, get_meta_hash from ..utils import dtype_cutlass_map, filename_safe_dtype_map, write_if_different @@ -364,7 +364,7 @@ def gen_trtllm_gen_gemm_module() -> JitSpec: # use `get_cubin` to get "flashinferMetaInfo.h" metainfo = get_cubin( f"{include_path}/{header_name}.h", - MetaInfoHash.TRTLLM_GEN_GEMM, + get_meta_hash(ArtifactPath.TRTLLM_GEN_GEMM), ) # make sure "flashinferMetaInfo.h" is downloaded or cached assert metainfo, f"{header_name}.h not found" diff --git a/tests/utils/test_load_cubin_compile_race_condition.py b/tests/utils/test_load_cubin_compile_race_condition.py index 29b2165a54..23e9975706 100644 --- a/tests/utils/test_load_cubin_compile_race_condition.py +++ b/tests/utils/test_load_cubin_compile_race_condition.py @@ -34,17 +34,12 @@ def worker_process(temp_dir): os.environ["FLASHINFER_CUBIN_DIR"] = temp_dir # Import here to ensure FLASHINFER_CUBIN_DIR is set before module loads - from flashinfer.artifacts import ArtifactPath, MetaInfoHash - from flashinfer.jit.cubin_loader import get_cubin + from flashinfer.artifacts import ArtifactPath # Define the target file - same for all processes include_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/include" header_name = "flashinferMetaInfo" - # Use get_cubin to get "flashinferMetaInfo.h" - # Note: all processes target the same file name - metainfo = get_cubin(f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_BMM) # noqa: F841 - # Read the file from FLASHINFER_CUBIN_DIR # NOTE(Zihao): instead of using metainfo, we directly read from the file path, # that aligns with how we compile the kernel.