Skip to content

Commit 25abf20

Browse files
committed
move module generation function to jit module
1 parent 2e925c7 commit 25abf20

File tree

2 files changed

+30
-32
lines changed

2 files changed

+30
-32
lines changed

flashinfer/jit/gemm/core.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,3 +493,31 @@ def gen_gemm_sm90_module() -> JitSpec:
493493
source_paths,
494494
extra_cuda_cflags=sm90a_nvcc_flags,
495495
)
496+
497+
498+
def gen_trtllm_low_latency_gemm_module() -> JitSpec:
    """Build the JIT spec for the TRT-LLM low-latency GEMM runner.

    Fetches ``flashinferMetaInfo.h`` via ``get_cubin`` and then describes
    how ``trtllm_low_latency_gemm_runner.cu`` is compiled, with the
    header's ``include`` directory on the include path.

    Returns:
        The ``JitSpec`` created by ``gen_jit_spec`` for the
        ``"trtllm_gemm"`` module.

    Raises:
        AssertionError: If ``get_cubin`` returns a falsy result for the
            meta-info header.
    """
    include_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/include"
    header_name = "flashinferMetaInfo"

    # use `get_cubin` to get "flashinferMetaInfo.h"
    metainfo = get_cubin(
        f"{include_path}/{header_name}.h",
        MetaInfoHash.TRTLLM_GEN_GEMM,
    )
    # Make sure "flashinferMetaInfo.h" is downloaded or cached.
    # An explicit raise is used instead of `assert` so the check is not
    # stripped under `python -O`; AssertionError is kept so existing
    # callers observe the same exception type as before.
    if not metainfo:
        raise AssertionError(f"{header_name}.h not found")

    return gen_jit_spec(
        "trtllm_gemm",
        [
            jit_env.FLASHINFER_CSRC_DIR / "trtllm_low_latency_gemm_runner.cu",
        ],
        extra_cuda_cflags=[
            "-DTLLM_GEN_EXPORT_INTERFACE",
            "-DTLLM_ENABLE_CUDA",
            f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"',
        ]
        + sm100a_nvcc_flags,
        # link "include" sub-directory in cache
        extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path],
        extra_ldflags=["-lcuda"],
    )

flashinfer/trtllm_low_latency_gemm.py

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
convert_to_block_layout,
2424
get_w2_permute_indices_with_cache,
2525
)
26+
from flashinfer.jit.gemm.core import gen_trtllm_low_latency_gemm_module
2627
import torch
2728

28-
from flashinfer.artifacts import ArtifactPath, MetaInfoHash
2929
from flashinfer.autotuner import (
3030
AutoTuner,
3131
TuningConfig,
@@ -38,40 +38,10 @@
3838
get_last_power_of_2_num_tokens_buckets,
3939
last_positive_power_of_2,
4040
)
41-
from flashinfer.jit import setup_cubin_loader, JitSpec, gen_jit_spec, sm100a_nvcc_flags
42-
from flashinfer.jit import env as jit_env
43-
from flashinfer.jit.cubin_loader import get_cubin
41+
from flashinfer.jit import setup_cubin_loader
4442
from flashinfer.utils import _get_cache_buf
4543

4644

47-
def gen_trtllm_low_latency_gemm_module() -> JitSpec:
    """Build the JIT spec for the TRT-LLM low-latency GEMM runner.

    Fetches ``flashinferMetaInfo.h`` via ``get_cubin`` and then describes
    how ``trtllm_low_latency_gemm_runner.cu`` is compiled, with the
    header's ``include`` directory on the include path.

    Returns:
        The ``JitSpec`` created by ``gen_jit_spec`` for the
        ``"trtllm_gemm"`` module.

    Raises:
        AssertionError: If ``get_cubin`` returns a falsy result for the
            meta-info header.
    """
    include_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/include"
    header_name = "flashinferMetaInfo"

    # use `get_cubin` to get "flashinferMetaInfo.h"
    metainfo = get_cubin(
        f"{include_path}/{header_name}.h",
        MetaInfoHash.TRTLLM_GEN_GEMM,
    )
    # Make sure "flashinferMetaInfo.h" is downloaded or cached.
    # An explicit raise is used instead of `assert` so the check is not
    # stripped under `python -O`; AssertionError is kept so existing
    # callers observe the same exception type as before.
    if not metainfo:
        raise AssertionError(f"{header_name}.h not found")

    return gen_jit_spec(
        "trtllm_gemm",
        [
            jit_env.FLASHINFER_CSRC_DIR / "trtllm_low_latency_gemm_runner.cu",
        ],
        extra_cuda_cflags=[
            "-DTLLM_GEN_EXPORT_INTERFACE",
            "-DTLLM_ENABLE_CUDA",
            f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"',
        ]
        + sm100a_nvcc_flags,
        # link "include" sub-directory in cache
        extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path],
        extra_ldflags=["-lcuda"],
    )
73-
74-
7545
@functools.cache
7646
def get_trtllm_low_latency_gemm_module():
7747
mod = gen_trtllm_low_latency_gemm_module()

0 commit comments

Comments
 (0)