Skip to content

Commit fd07386

Browse files
committed
deprecate metainfohash
1 parent bfad066 commit fd07386

File tree

5 files changed

+27
-10
lines changed

5 files changed

+27
-10
lines changed

flashinfer/artifacts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class ArtifactPath:
9191
DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"
9292

9393

94-
# TODO: Should be deprecated
94+
# TODO (jimmyzho): Should be deprecated except DEEPGEMM
9595
@dataclass(frozen=True)
9696
class MetaInfoHash:
9797
TRTLLM_GEN_FMHA: str = (

flashinfer/jit/attention/modules.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
sm90a_nvcc_flags,
2929
current_compilation_context,
3030
)
31-
from ...jit.cubin_loader import get_cubin
31+
from ...jit.cubin_loader import get_cubin, get_meta_hash
3232
from ..utils import (
3333
dtype_map,
3434
filename_safe_dtype_map,
@@ -1575,7 +1575,8 @@ def gen_trtllm_gen_fmha_module():
15751575

15761576
# use `get_cubin` to get "flashinferMetaInfo.h"
15771577
metainfo = get_cubin(
1578-
f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_FMHA
1578+
f"{include_path}/{header_name}.h",
1579+
get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_FMHA}/checksums.txt"),
15791580
)
15801581

15811582
# make sure "flashinferMetaInfo.h" is downloaded or cached
@@ -1591,7 +1592,7 @@ def gen_trtllm_gen_fmha_module():
15911592
extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path],
15921593
extra_cuda_cflags=[
15931594
f'-DTLLM_GEN_FMHA_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_FMHA}\\"',
1594-
f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{MetaInfoHash.TRTLLM_GEN_FMHA}\\"',
1595+
f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_FMHA}/checksums.txt")}\\"',
15951596
],
15961597
)
15971598

flashinfer/jit/cubin_loader.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,19 @@ def download_file(
127127
return False
128128

129129

130+
def get_meta_hash(checksum_path: str) -> str:
131+
"""
132+
Load the file from local cache (checksums.txt)
133+
and get the hash of corresponding flashinferMetaInfo.h file
134+
"""
135+
with open(checksum_path, "r") as f:
136+
for line in f:
137+
sha256, filename = line.strip().split()
138+
if ".h" in filename:
139+
return sha256
140+
raise ValueError(f"Invalid path: checksums.txt not found in {checksum_path}")
141+
142+
130143
def load_cubin(cubin_path: str, sha256: str) -> bytes:
131144
"""
132145
Load a cubin from the provide local path and

flashinfer/jit/fused_moe.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
from typing import List
1818

1919
from . import env as jit_env
20-
from ..artifacts import ArtifactPath, MetaInfoHash
20+
from ..artifacts import ArtifactPath
2121
from .core import JitSpec, gen_jit_spec, current_compilation_context, sm90a_nvcc_flags
2222
from .cpp_ext import is_cuda_version_at_least
23-
from .cubin_loader import get_cubin
23+
from .cubin_loader import get_cubin, get_meta_hash
2424
from .gemm.cutlass.generate_kernels import generate_gemm_operations
2525

2626

@@ -179,7 +179,10 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec:
179179
header_name = "flashinferMetaInfo"
180180

181181
# use `get_cubin` to get "flashinferMetaInfo.h"
182-
metainfo = get_cubin(f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_BMM)
182+
metainfo = get_cubin(
183+
f"{include_path}/{header_name}.h",
184+
get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_BMM}/checksums.txt"),
185+
)
183186
# make sure "flashinferMetaInfo.h" is downloaded or cached
184187
assert metainfo, f"{header_name}.h not found"
185188

flashinfer/jit/gemm/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import jinja2
2121
import torch
2222

23-
from ...artifacts import ArtifactPath, MetaInfoHash
23+
from ...artifacts import ArtifactPath
2424
from .. import env as jit_env
2525
from ..core import (
2626
JitSpec,
@@ -30,7 +30,7 @@
3030
sm100f_nvcc_flags,
3131
current_compilation_context,
3232
)
33-
from ..cubin_loader import get_cubin
33+
from ..cubin_loader import get_cubin, get_meta_hash
3434
from ..utils import dtype_cutlass_map, filename_safe_dtype_map, write_if_different
3535

3636

@@ -364,7 +364,7 @@ def gen_trtllm_gen_gemm_module() -> JitSpec:
364364
# use `get_cubin` to get "flashinferMetaInfo.h"
365365
metainfo = get_cubin(
366366
f"{include_path}/{header_name}.h",
367-
MetaInfoHash.TRTLLM_GEN_GEMM,
367+
get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_GEMM}/checksums.txt"),
368368
)
369369
# make sure "flashinferMetaInfo.h" is downloaded or cached
370370
assert metainfo, f"{header_name}.h not found"

0 commit comments

Comments
 (0)