Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 85 additions & 29 deletions flashinfer/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@

from .jit.cubin_loader import (
FLASHINFER_CUBINS_REPOSITORY,
download_file,
safe_urljoin,
FLASHINFER_CUBIN_DIR,
download_file,
verify_cubin,
)


Expand Down Expand Up @@ -78,50 +79,100 @@ def get_available_cubin_files(
return tuple()


@dataclass(frozen=True)
class ArtifactPath:
TRTLLM_GEN_FMHA: str = "7206d64e67f4c8949286246d6e2e07706af5d223/fmha/trtllm-gen"
TRTLLM_GEN_FMHA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/trtllm-gen/"
TRTLLM_GEN_BMM: str = (
"e6f22dcc3fdeb29ff87af2f4a2cb3d30b8d273e0/batched_gemm-45beda1-ee6a802"
"a72d85b019dc125b9f711300cb989430f762f5a6/batched_gemm-145d1b1-9e1d49a/"
)
TRTLLM_GEN_GEMM: str = (
"037e528e719ec3456a7d7d654f26b805e44c63b1/gemm-8704aa4-f91dc9e"
"a72d85b019dc125b9f711300cb989430f762f5a6/gemm-145d1b1-f91dc9e/"
)
CUDNN_SDPA: str = "4c623163877c8fef5751c9c7a59940cd2baae02e/fmha/cudnn"
DEEPGEMM: str = "51d730202c9eef782f06ecc950005331d85c5d4b/deep-gemm"
CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/"
DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"


@dataclass(frozen=True)
class MetaInfoHash:
DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"


class CheckSumHash:
TRTLLM_GEN_FMHA: str = (
"2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d"
"b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
)
TRTLLM_GEN_BMM: str = (
"c98b4ce69a39fd41556d67033c30ea814ef76b0a2fe16e798e55baf0104acc34"
"efb9379c924193f6d3cb792bafb12b0811cab8eaa12bf324c7c410636c7769cd"
)
DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
TRTLLM_GEN_GEMM: str = (
"0345358c916d990709f9670e113e93f35c76aa22715e2d5128ec2ca8740be5ba"
"e475e37989eed16418e0e858e2868ff07cb4b650cc48759cc23012f1afea310a"
)


def get_cubin_file_list() -> Generator[str, None, None]:
map_checksums: dict[str, str] = {
safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "checksums.txt"): TRTLLM_GEN_FMHA,
safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "checksums.txt"): TRTLLM_GEN_BMM,
safe_urljoin(ArtifactPath.DEEPGEMM, "checksums.txt"): DEEPGEMM,
safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "checksums.txt"): TRTLLM_GEN_GEMM,
}


def get_checksums(subdirs: list[str]) -> dict[str, str]:
    """
    Download ``checksums.txt`` for every artifact subdirectory and merge the
    entries into one ``{filename: sha256}`` mapping.

    Each manifest is fetched from FLASHINFER_CUBINS_REPOSITORY into the same
    relative location under FLASHINFER_CUBIN_DIR and parsed line by line as
    ``<sha256> <filename>``. Blank lines are ignored.

    Header entries (names ending in ``.h``) are keyed by
    ``safe_urljoin(subdir, filename)`` so that the meta info headers of
    different subdirectories do not overwrite each other. Every other entry
    is keyed by the bare filename found in the manifest.

    Args:
        subdirs: artifact subdirectories, relative to the repository root.

    Returns:
        Mapping from file name to its expected sha256 hex digest.

    Raises:
        FileNotFoundError: if a manifest could not be downloaded and no
            local copy of it exists.
        ValueError: if a non-empty manifest line has fewer than two fields.
    """
    import os  # function-scope import keeps this block self-contained

    checksums: dict[str, str] = {}
    for subdir in subdirs:
        manifest = safe_urljoin(subdir, "checksums.txt")
        uri = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, manifest)
        checksum_path = FLASHINFER_CUBIN_DIR / manifest

        # Only fail when there is truly nothing to read: a falsy result with
        # an existing local file still proceeds, exactly as before.
        downloaded = download_file(uri, checksum_path)
        if not downloaded and not os.path.isfile(checksum_path):
            raise FileNotFoundError(
                f"Failed to download {uri} and no local copy exists at {checksum_path}"
            )

        with open(checksum_path, "r") as f:
            for line in f:
                entry = line.strip()
                if not entry:
                    # Tolerate blank or trailing empty lines in the manifest.
                    continue
                sha256, filename = entry.split(maxsplit=1)

                # Distinguish between all meta info header files. Use a
                # suffix test so cubin names that merely contain ".h" are
                # not mistaken for headers.
                if filename.endswith(".h"):
                    filename = safe_urljoin(subdir, filename)
                checksums[filename] = sha256
    return checksums


def get_subdir_file_list() -> Generator[tuple[str, str], None, None]:
base = FLASHINFER_CUBINS_REPOSITORY

# The meta info header files first.
yield safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h")
yield safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h")
yield safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h")

# All the actual kernel cubins.
for kernel in [
cubin_dirs = [
ArtifactPath.TRTLLM_GEN_FMHA,
ArtifactPath.TRTLLM_GEN_BMM,
ArtifactPath.TRTLLM_GEN_GEMM,
ArtifactPath.DEEPGEMM,
]:
for name in get_available_cubin_files(safe_urljoin(base, kernel)):
yield safe_urljoin(kernel, name)
]

# Get checksums of all files
checksums = get_checksums(cubin_dirs)

# The meta info header files first.
yield (
safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h"),
checksums[
safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h")
],
)
yield (
safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h"),
checksums[
safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h")
],
)
yield (
safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h"),
checksums[
safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h")
],
)

# All the actual kernel cubins.
for cubin_dir in cubin_dirs:
checksum_path = safe_urljoin(cubin_dir, "checksums.txt")
yield (checksum_path, CheckSumHash.map_checksums[checksum_path])
for name in get_available_cubin_files(safe_urljoin(base, cubin_dir)):
yield (safe_urljoin(cubin_dir, name), checksums[name])


def download_artifacts() -> None:
Expand All @@ -130,8 +181,7 @@ def download_artifacts() -> None:
# use a shared session to make use of HTTP keep-alive and reuse of
# HTTPS connections.
session = requests.Session()

cubin_files = list(get_cubin_file_list())
cubin_files = list(get_subdir_file_list())
num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4"))
with tqdm_logging_redirect(
total=len(cubin_files), desc="Downloading cubins"
Expand All @@ -142,7 +192,7 @@ def update_pbar_cb(_) -> None:

with ThreadPoolExecutor(num_threads) as pool:
futures = []
for name in cubin_files:
for name, _ in cubin_files:
source = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, name)
local_path = FLASHINFER_CUBIN_DIR / name
# Ensure parent directory exists
Expand All @@ -159,13 +209,19 @@ def update_pbar_cb(_) -> None:
if not all_success:
raise RuntimeError("Failed to download cubins")

# Check checksums of all downloaded cubins
for name, checksum in cubin_files:
local_path = FLASHINFER_CUBIN_DIR / name
if not verify_cubin(str(local_path), checksum):
raise RuntimeError("Failed to download cubins: checksum mismatch")


def get_artifacts_status() -> tuple[tuple[str, bool], ...]:
"""
Check which cubins are already downloaded and return a (file_name, exists) pair for each one.
Does not download any cubins.
"""
cubin_files = get_cubin_file_list()
cubin_files = get_subdir_file_list()

def _check_file_status(file_name: str) -> tuple[str, bool]:
# get_cubin stores cubins in FLASHINFER_CUBIN_DIR with the same relative path
Expand All @@ -174,7 +230,7 @@ def _check_file_status(file_name: str) -> tuple[str, bool]:
exists = os.path.isfile(local_path)
return (file_name, exists)

return tuple(_check_file_status(file_name) for file_name in cubin_files)
return tuple(_check_file_status(file_name) for file_name, _ in cubin_files)


def clear_cubin():
Expand Down
9 changes: 5 additions & 4 deletions flashinfer/jit/attention/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
sm90a_nvcc_flags,
current_compilation_context,
)
from ...jit.cubin_loader import get_cubin
from ...jit.cubin_loader import get_cubin, get_meta_hash
from ..utils import (
dtype_map,
filename_safe_dtype_map,
Expand Down Expand Up @@ -1568,14 +1568,15 @@ def gen_fmha_cutlass_sm100a_module(


def gen_trtllm_gen_fmha_module():
from ...artifacts import ArtifactPath, MetaInfoHash
from ...artifacts import ArtifactPath

include_path = f"{ArtifactPath.TRTLLM_GEN_FMHA}/include"
header_name = "flashInferMetaInfo"

# use `get_cubin` to get "flashinferMetaInfo.h"
metainfo = get_cubin(
f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_FMHA
f"{include_path}/{header_name}.h",
get_meta_hash(ArtifactPath.TRTLLM_GEN_FMHA),
)

# make sure "flashinferMetaInfo.h" is downloaded or cached
Expand All @@ -1591,7 +1592,7 @@ def gen_trtllm_gen_fmha_module():
extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path],
extra_cuda_cflags=[
f'-DTLLM_GEN_FMHA_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_FMHA}\\"',
f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{MetaInfoHash.TRTLLM_GEN_FMHA}\\"',
f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_FMHA}/checksums.txt")}\\"',
],
)

Expand Down
28 changes: 28 additions & 0 deletions flashinfer/jit/cubin_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,34 @@ def download_file(
return False


def get_meta_hash(checksum_path: str) -> str:
    """
    Return the sha256 recorded in a local ``checksums.txt`` for the meta info
    header (the first entry whose file name ends in ``.h``).

    ``checksum_path`` may be either the directory that contains
    ``checksums.txt`` or the path of ``checksums.txt`` itself. The path is
    used as given when that file exists; otherwise it is resolved relative to
    FLASHINFER_CUBIN_DIR, which is where the artifact subdirectories are
    cached.

    Args:
        checksum_path: directory holding ``checksums.txt`` (or the manifest
            path itself), absolute or relative to the cubin cache directory.

    Returns:
        The sha256 hex digest listed for the header file.

    Raises:
        FileNotFoundError: if no ``checksums.txt`` exists at either location.
        ValueError: if the manifest lists no ``.h`` file, or a non-empty line
            has fewer than two fields.
    """
    import os  # function-scope import keeps this block self-contained

    manifest_name = "checksums.txt"
    if os.path.basename(checksum_path) == manifest_name:
        # Caller already passed the manifest itself; do not append it twice.
        manifest = checksum_path
    else:
        manifest = os.path.join(checksum_path, manifest_name)

    if not os.path.isfile(manifest):
        # Artifact paths are repository-relative, so fall back to the local
        # cubin cache rather than the current working directory.
        manifest = os.path.join(FLASHINFER_CUBIN_DIR, manifest)

    with open(manifest, "r") as f:
        for line in f:
            entry = line.strip()
            if not entry:
                continue
            sha256, filename = entry.split(maxsplit=1)
            # Suffix test: a cubin whose name merely contains ".h" is not a header.
            if filename.endswith(".h"):
                return sha256
    raise ValueError(f"Invalid checksums.txt: no header (.h) entry found in {manifest}")


def verify_cubin(cubin_path: str, expected_sha256: str) -> bool:
    """
    Verify the cubin file against the sha256 checksum.

    The file is hashed in fixed-size chunks so it is never held in memory all
    at once. The expected digest is compared case-insensitively and with
    surrounding whitespace ignored.

    Args:
        cubin_path: path of the local file to check.
        expected_sha256: expected sha256 hex digest of the file contents.

    Returns:
        True if the digests match; False (after logging a warning) otherwise.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    digest = hashlib.sha256()
    with open(cubin_path, "rb") as f:
        # 1 MiB reads bound memory use regardless of file size.
        while chunk := f.read(1 << 20):
            digest.update(chunk)
    actual_sha256 = digest.hexdigest()

    # hexdigest() is lowercase; normalize the expected value the same way.
    if actual_sha256 != expected_sha256.strip().lower():
        logger.warning(
            f"sha256 mismatch (expected {expected_sha256} actual {actual_sha256}) for {cubin_path}"
        )
        return False
    return True


def load_cubin(cubin_path: str, sha256: str) -> bytes:
"""
Load a cubin from the provided local path and
Expand Down
9 changes: 6 additions & 3 deletions flashinfer/jit/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from typing import List

from . import env as jit_env
from ..artifacts import ArtifactPath, MetaInfoHash
from ..artifacts import ArtifactPath
from .core import JitSpec, gen_jit_spec, current_compilation_context, sm90a_nvcc_flags
from .cpp_ext import is_cuda_version_at_least
from .cubin_loader import get_cubin
from .cubin_loader import get_cubin, get_meta_hash
from .gemm.cutlass.generate_kernels import generate_gemm_operations


Expand Down Expand Up @@ -179,7 +179,10 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec:
header_name = "flashinferMetaInfo"

# use `get_cubin` to get "flashinferMetaInfo.h"
metainfo = get_cubin(f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_BMM)
metainfo = get_cubin(
f"{include_path}/{header_name}.h",
get_meta_hash(ArtifactPath.TRTLLM_GEN_BMM),
)
# make sure "flashinferMetaInfo.h" is downloaded or cached
assert metainfo, f"{header_name}.h not found"

Expand Down
6 changes: 3 additions & 3 deletions flashinfer/jit/gemm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jinja2
import torch

from ...artifacts import ArtifactPath, MetaInfoHash
from ...artifacts import ArtifactPath
from .. import env as jit_env
from ..core import (
JitSpec,
Expand All @@ -30,7 +30,7 @@
sm100f_nvcc_flags,
current_compilation_context,
)
from ..cubin_loader import get_cubin
from ..cubin_loader import get_cubin, get_meta_hash
from ..utils import dtype_cutlass_map, filename_safe_dtype_map, write_if_different


Expand Down Expand Up @@ -364,7 +364,7 @@ def gen_trtllm_gen_gemm_module() -> JitSpec:
# use `get_cubin` to get "flashinferMetaInfo.h"
metainfo = get_cubin(
f"{include_path}/{header_name}.h",
MetaInfoHash.TRTLLM_GEN_GEMM,
get_meta_hash(ArtifactPath.TRTLLM_GEN_GEMM),
)
# make sure "flashinferMetaInfo.h" is downloaded or cached
assert metainfo, f"{header_name}.h not found"
Expand Down
7 changes: 1 addition & 6 deletions tests/utils/test_load_cubin_compile_race_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,12 @@ def worker_process(temp_dir):
os.environ["FLASHINFER_CUBIN_DIR"] = temp_dir

# Import here to ensure FLASHINFER_CUBIN_DIR is set before module loads
from flashinfer.artifacts import ArtifactPath, MetaInfoHash
from flashinfer.jit.cubin_loader import get_cubin
from flashinfer.artifacts import ArtifactPath

# Define the target file - same for all processes
include_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/include"
header_name = "flashinferMetaInfo"

# Use get_cubin to get "flashinferMetaInfo.h"
# Note: all processes target the same file name
metainfo = get_cubin(f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_BMM) # noqa: F841

# Read the file from FLASHINFER_CUBIN_DIR
# NOTE(Zihao): instead of using metainfo, we directly read from the file path,
# that aligns with how we compile the kernel.
Expand Down