Skip to content

Commit 343a186

Browse files
committed
deprecate metainfohash
1 parent 480f860 commit 343a186

File tree

5 files changed

+28
-11
lines changed

5 files changed

+28
-11
lines changed

flashinfer/artifacts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class ArtifactPath:
8585
DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"
8686

8787

88-
# TODO: Should be deprecated
88+
# TODO (jimmyzho): Should be deprecated except DEEPGEMM
8989
@dataclass(frozen=True)
9090
class MetaInfoHash:
9191
TRTLLM_GEN_FMHA: str = (

flashinfer/jit/attention/modules.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import jinja2
2121
import torch
2222

23-
from ...artifacts import ArtifactPath, MetaInfoHash
23+
from ...artifacts import ArtifactPath
2424
from .. import env as jit_env
2525
from ..core import (
2626
JitSpec,
@@ -29,7 +29,7 @@
2929
sm90a_nvcc_flags,
3030
current_compilation_context,
3131
)
32-
from ...jit.cubin_loader import get_cubin
32+
from ...jit.cubin_loader import get_cubin, get_meta_hash
3333
from ..utils import (
3434
dtype_map,
3535
filename_safe_dtype_map,
@@ -1574,7 +1574,8 @@ def gen_trtllm_gen_fmha_module():
15741574

15751575
# use `get_cubin` to get "flashinferMetaInfo.h"
15761576
metainfo = get_cubin(
1577-
f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_FMHA
1577+
f"{include_path}/{header_name}.h",
1578+
get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_FMHA}/checksums.txt"),
15781579
)
15791580

15801581
# make sure "flashinferMetaInfo.h" is downloaded or cached
@@ -1590,7 +1591,7 @@ def gen_trtllm_gen_fmha_module():
15901591
extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path],
15911592
extra_cuda_cflags=[
15921593
f'-DTLLM_GEN_FMHA_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_FMHA}\\"',
1593-
f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{MetaInfoHash.TRTLLM_GEN_FMHA}\\"',
1594+
f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_FMHA}/checksums.txt")}\\"',
15941595
],
15951596
)
15961597

flashinfer/jit/cubin_loader.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,19 @@ def download_file(
120120
return False
121121

122122

123+
def get_meta_hash(checksum_path: str) -> str:
124+
"""
125+
Load the file from local cache (checksums.txt)
126+
and get the hash of corresponding flashinferMetaInfo.h file
127+
"""
128+
with open(checksum_path, "r") as f:
129+
for line in f:
130+
sha256, filename = line.strip().split()
131+
if ".h" in filename:
132+
return sha256
133+
raise ValueError(f"Invalid path: checksums.txt not found in {checksum_path}")
134+
135+
123136
def load_cubin(cubin_path: str, sha256: str) -> bytes:
124137
"""
125138
Load a cubin from the provide local path and

flashinfer/jit/fused_moe.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
from typing import List
1818

1919
from . import env as jit_env
20-
from ..artifacts import ArtifactPath, MetaInfoHash
20+
from ..artifacts import ArtifactPath
2121
from .core import JitSpec, gen_jit_spec, current_compilation_context, sm90a_nvcc_flags
2222
from .cpp_ext import is_cuda_version_at_least
23-
from .cubin_loader import get_cubin
23+
from .cubin_loader import get_cubin, get_meta_hash
2424
from .gemm.cutlass.generate_kernels import generate_gemm_operations
2525

2626

@@ -179,7 +179,10 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec:
179179
header_name = "flashinferMetaInfo"
180180

181181
# use `get_cubin` to get "flashinferMetaInfo.h"
182-
metainfo = get_cubin(f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_BMM)
182+
metainfo = get_cubin(
183+
f"{include_path}/{header_name}.h",
184+
get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_BMM}/checksums.txt"),
185+
)
183186
# make sure "flashinferMetaInfo.h" is downloaded or cached
184187
assert metainfo, f"{header_name}.h not found"
185188

flashinfer/jit/gemm/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import jinja2
2121
import torch
2222

23-
from ...artifacts import ArtifactPath, MetaInfoHash
23+
from ...artifacts import ArtifactPath
2424
from .. import env as jit_env
2525
from ..core import (
2626
JitSpec,
@@ -30,7 +30,7 @@
3030
sm100f_nvcc_flags,
3131
current_compilation_context,
3232
)
33-
from ..cubin_loader import get_cubin
33+
from ..cubin_loader import get_cubin, get_meta_hash
3434
from ..utils import dtype_cutlass_map, filename_safe_dtype_map, write_if_different
3535

3636

@@ -364,7 +364,7 @@ def gen_trtllm_gen_gemm_module() -> JitSpec:
364364
# use `get_cubin` to get "flashinferMetaInfo.h"
365365
metainfo = get_cubin(
366366
f"{include_path}/{header_name}.h",
367-
MetaInfoHash.TRTLLM_GEN_GEMM,
367+
get_meta_hash(f"{ArtifactPath.TRTLLM_GEN_GEMM}/checksums.txt"),
368368
)
369369
# make sure "flashinferMetaInfo.h" is downloaded or cached
370370
assert metainfo, f"{header_name}.h not found"

0 commit comments

Comments
 (0)