|
23 | 23 | convert_to_block_layout,
|
24 | 24 | get_w2_permute_indices_with_cache,
|
25 | 25 | )
|
| 26 | +from flashinfer.jit.gemm.core import gen_trtllm_low_latency_gemm_module |
26 | 27 | import torch
|
27 | 28 |
|
28 |
| -from flashinfer.artifacts import ArtifactPath, MetaInfoHash |
29 | 29 | from flashinfer.autotuner import (
|
30 | 30 | AutoTuner,
|
31 | 31 | TuningConfig,
|
|
38 | 38 | get_last_power_of_2_num_tokens_buckets,
|
39 | 39 | last_positive_power_of_2,
|
40 | 40 | )
|
41 |
| -from flashinfer.jit import setup_cubin_loader, JitSpec, gen_jit_spec, sm100a_nvcc_flags |
42 |
| -from flashinfer.jit import env as jit_env |
43 |
| -from flashinfer.jit.cubin_loader import get_cubin |
| 41 | +from flashinfer.jit import setup_cubin_loader |
44 | 42 | from flashinfer.utils import _get_cache_buf
|
45 | 43 |
|
46 | 44 |
|
47 |
| -def gen_trtllm_low_latency_gemm_module() -> JitSpec: |
48 |
| - include_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/include" |
49 |
| - header_name = "flashinferMetaInfo" |
50 |
| - |
51 |
| - # use `get_cubin` to get "flashinferMetaInfo.h" |
52 |
| - metainfo = get_cubin( |
53 |
| - f"{include_path}/{header_name}.h", |
54 |
| - MetaInfoHash.TRTLLM_GEN_GEMM, |
55 |
| - ) |
56 |
| - # make sure "flashinferMetaInfo.h" is downloaded or cached |
57 |
| - assert metainfo, f"{header_name}.h not found" |
58 |
| - return gen_jit_spec( |
59 |
| - "trtllm_gemm", |
60 |
| - [ |
61 |
| - jit_env.FLASHINFER_CSRC_DIR / "trtllm_low_latency_gemm_runner.cu", |
62 |
| - ], |
63 |
| - extra_cuda_cflags=[ |
64 |
| - "-DTLLM_GEN_EXPORT_INTERFACE", |
65 |
| - "-DTLLM_ENABLE_CUDA", |
66 |
| - f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"', |
67 |
| - ] |
68 |
| - + sm100a_nvcc_flags, |
69 |
| - # link "include" sub-directory in cache |
70 |
| - extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path], |
71 |
| - extra_ldflags=["-lcuda"], |
72 |
| - ) |
73 |
| - |
74 |
| - |
75 | 45 | @functools.cache
|
76 | 46 | def get_trtllm_low_latency_gemm_module():
|
77 | 47 | mod = gen_trtllm_low_latency_gemm_module()
|
|
0 commit comments