
Commit 3e8f69a

revert init upd
1 parent 2fe5331 commit 3e8f69a


12 files changed: +44, -19,904 lines

flashinfer/decode.py

Lines changed: 1 addition & 7 deletions
@@ -33,7 +33,6 @@
     get_batch_prefill_uri,
     get_single_decode_uri,
     setup_cubin_loader,
-    setup_metainfo_loader,
     trtllm_gen_fmha_module,
 )
 from .page import get_seq_lens
@@ -304,7 +303,6 @@ def get_trtllm_gen_fmha_module():
     mod = trtllm_gen_fmha_module()
     op = mod.build_and_load()
     setup_cubin_loader(mod.get_library_path())
-    setup_metainfo_loader(mod.get_library_path())
     return op


@@ -1833,13 +1831,9 @@ def __init__(self):
         self._sm_count: Optional[int] = None
         self._mod = trtllm_gen_fmha_module()
         self._op = self._mod.build_and_load()
-        from flashinfer.jit.cubin_loader import (
-            setup_cubin_loader,
-            setup_metainfo_loader,
-        )
+        from flashinfer.jit.cubin_loader import setup_cubin_loader

         setup_cubin_loader(self._mod.get_library_path())
-        setup_metainfo_loader(self._mod.get_library_path())


 @functools.cache
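Note: together with the matching prefill.py change below, this leaves a simpler load path for the trtllm-gen FMHA module: build the JIT spec, load it, and register only the cubin loader. A minimal sketch, assuming the same import block decode.py uses (the wrapper name here is hypothetical):

    # Sketch of the load path after this revert; load_trtllm_gen_fmha is a
    # hypothetical name, the calls mirror get_trtllm_gen_fmha_module() above.
    from flashinfer.jit import setup_cubin_loader, trtllm_gen_fmha_module  # assumed import path

    def load_trtllm_gen_fmha():
        mod = trtllm_gen_fmha_module()              # JIT spec for the trtllm-gen FMHA kernels
        op = mod.build_and_load()                   # compile (or reuse a cached build) and load
        setup_cubin_loader(mod.get_library_path())  # register only the cubin callback
        return op                                   # no setup_metainfo_loader call any more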

flashinfer/fused_moe.py

Lines changed: 17 additions & 0 deletions
@@ -39,6 +39,7 @@
 from .jit import JitSpec
 from .jit import env as jit_env
 from .jit import gen_jit_spec, setup_cubin_loader, sm100a_nvcc_flags
+from .jit.cubin_loader import get_cubin
 from .utils import _check_shape_dtype_device, register_custom_op, register_fake_op


@@ -773,6 +774,18 @@ def cutlass_fused_moe(


 def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
+    hash = "6b93c394210c89dccef13833c89797f1b8f8aefb"
+    tllm_gen_commit = "ce8ce46"
+    tllm_gen_config_hash = "2dc78d9"
+    include_path = (
+        f"{hash}/batched_gemm-{tllm_gen_commit}-{tllm_gen_config_hash}/include"
+    )
+    metainfo = get_cubin(
+        f"{include_path}/flashinferMetaInfo",
+        "b24fd5e7ae6b20e903c866ecb1d4a68f238301ba9b76df6a536056f2059a0d56",
+        ".h",
+    )
+    assert metainfo, "KernelMetaInfo.h not found"
     return gen_jit_spec(
         "fused_moe_sm100",
         [
@@ -788,8 +801,12 @@ def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
             "-DENABLE_BF16",
             "-DENABLE_FP8",
             "-DENABLE_FP4",
+            f'-DPIPELINE_HASH=\\"{hash}\\"',
+            f'-DTLLM_GEN_COMMIT=\\"{tllm_gen_commit}\\"',
+            f'-DTLLM_GEN_BATCHED_GEMM_CONFIG_HASH=\\"{tllm_gen_config_hash}\\"',
         ]
         + sm100a_nvcc_flags,
+        extra_include_paths=[jit_env.FLASHINFER_CACHE_DIR / "cubins" / include_path],
         extra_ldflags=["-lcuda"],
     )
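Note: the escaped -D flags added here turn the pinned hashes into C++ string-literal macros that BatchedGemmInterface.h concatenates into the runtime cubin path. A small sketch of how such a flag renders (the shell-escaping interpretation is an assumption; the value is the one pinned in this hunk):

    # Sketch: the doubled backslash keeps the quotes intact on the build
    # command line, so the macro expands to a quoted string on the C++ side.
    hash = "6b93c394210c89dccef13833c89797f1b8f8aefb"
    flag = f'-DPIPELINE_HASH=\\"{hash}\\"'
    print(flag)
    # -DPIPELINE_HASH=\"6b93c394210c89dccef13833c89797f1b8f8aefb\"
    # which lets BatchedGemmInterface.h build std::string(PIPELINE_HASH) + "/..."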

flashinfer/jit/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@
 from .core import gen_jit_spec as gen_jit_spec
 from .core import sm90a_nvcc_flags as sm90a_nvcc_flags
 from .core import sm100a_nvcc_flags as sm100a_nvcc_flags
-from .cubin_loader import setup_cubin_loader, setup_metainfo_loader
+from .cubin_loader import setup_cubin_loader


 @functools.cache
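Note: after this one-line change the package-level re-export only carries the cubin loader, so any caller that previously imported setup_metainfo_loader from flashinfer.jit will now fail at import time. Quick check (sketch):

    from flashinfer.jit import setup_cubin_loader       # still re-exported
    # from flashinfer.jit import setup_metainfo_loader  # ImportError after this commit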

flashinfer/jit/attention/pytorch.py

Lines changed: 11 additions & 0 deletions
@@ -22,6 +22,7 @@

 from .. import env as jit_env
 from ..core import JitSpec, gen_jit_spec, logger, sm90a_nvcc_flags, sm100a_nvcc_flags
+from ..cubin_loader import get_cubin
 from ..utils import (
     dtype_map,
     filename_safe_dtype_map,
@@ -1487,13 +1488,23 @@ def gen_fmha_cutlass_sm100a_module(


 def trtllm_gen_fmha_module():
+    hash = "6b93c394210c89dccef13833c89797f1b8f8aefb"
+    include_path = f"{hash}/fmha/trtllm-gen/include"
+    metainfo = get_cubin(
+        f"{include_path}/flashInferMetaInfo",
+        "ba35dc13249cd09bf39eed43e785b088d329acaf81a3f940a615904b81bfa02f",
+        ".h",
+    )
+    assert metainfo, "flashInferMetaInfo.h not found"
     return gen_jit_spec(
         "fmha_gen",
         [
             jit_env.FLASHINFER_CSRC_DIR / "trtllm_fmha_runner.cu",
             jit_env.FLASHINFER_CSRC_DIR / "trtllm_fmha_kernel_launcher.cu",
         ],
+        extra_include_paths=[jit_env.FLASHINFER_CACHE_DIR / "cubins" / include_path],
         extra_ldflags=["-lcuda"],
+        extra_cuda_cflags=[f'-DPIPELINE_HASH=\\"{hash}\\"'],
     )
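Note: the FMHA module follows the same pattern as fused_moe.py above: the metainfo header is fetched into the cubin cache before the spec is built and its directory is added to the include path, so the removed metainfo callback is not needed at run time. A minimal usage sketch (importing from the file's own module path; whether it is re-exported elsewhere is not shown in this diff):

    from flashinfer.jit.attention.pytorch import trtllm_gen_fmha_module

    spec = trtllm_gen_fmha_module()  # downloads flashInferMetaInfo.h into the cubin cache first
    mod = spec.build_and_load()      # nvcc then finds the header via extra_include_paths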

flashinfer/jit/cubin_loader.py

Lines changed: 0 additions & 29 deletions
@@ -188,32 +188,3 @@ def get_cubin_callback(name, sha256):
     dll_cubin_handlers[dll_path] = cb

     _LIB.FlashInferSetCubinCallback(cb)
-
-
-dll_metainfo_handlers = {}
-
-
-def setup_metainfo_loader(dll_path: str):
-    if dll_path in dll_metainfo_handlers:
-        return
-
-    _LIB = ctypes.CDLL(dll_path)
-
-    # Define the correct callback type
-    CALLBACK_TYPE = ctypes.CFUNCTYPE(
-        None, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p
-    )
-
-    def get_metainfo_callback(name, sha256, extension):
-        metainfo = get_cubin(
-            name.decode("utf-8"), sha256.decode("utf-8"), extension.decode("utf-8")
-        )
-        _LIB.FlashInferSetCurrentMetaInfo(
-            convert_to_ctypes_char_p(metainfo), ctypes.c_int(len(metainfo))
-        )
-
-    # Create the callback and keep a reference to prevent GC
-    cb = CALLBACK_TYPE(get_metainfo_callback)
-    dll_metainfo_handlers[dll_path] = cb
-
-    _LIB.FlashInferSetMetaInfoCallback(cb)
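Note: the surviving setup_cubin_loader uses the same ctypes pattern as the code removed here: open the extension library, wrap a Python function in a CFUNCTYPE, keep a reference so it is not garbage-collected, and pass the function pointer to the library. A generic sketch of that pattern (register_callback is illustrative, not the real setup_cubin_loader):

    import ctypes

    _handlers = {}  # keep callback objects alive so they are not garbage-collected

    def register_callback(dll_path, symbol, py_func, argtypes):
        # Illustrative helper mirroring the structure of the removed setup_metainfo_loader.
        if dll_path in _handlers:
            return
        lib = ctypes.CDLL(dll_path)
        cb_type = ctypes.CFUNCTYPE(None, *argtypes)  # void callback with the given argument types
        cb = cb_type(py_func)
        _handlers[dll_path] = cb                     # prevent GC of the callback trampoline
        getattr(lib, symbol)(cb)                     # e.g. lib.FlashInferSetCubinCallback(cb)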

flashinfer/prefill.py

Lines changed: 0 additions & 3 deletions
@@ -31,7 +31,6 @@
     get_batch_prefill_uri,
     get_single_prefill_uri,
     setup_cubin_loader,
-    setup_metainfo_loader,
     trtllm_gen_fmha_module,
 )
 from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens
@@ -92,7 +91,6 @@ def get_trtllm_gen_prefill_module():
     mod = trtllm_gen_fmha_module()
     op = mod.build_and_load()
     setup_cubin_loader(mod.get_library_path())
-    setup_metainfo_loader(mod.get_library_path())

     def _paged_run(
         query: torch.Tensor,
@@ -2946,7 +2944,6 @@ def get_trtllm_gen_fmha_module():
     mod = trtllm_gen_fmha_module()
     op = mod.build_and_load()
     setup_cubin_loader(mod.get_library_path())
-    setup_metainfo_loader(mod.get_library_path())
     return op

include/flashinfer/cubin_loader.h

Lines changed: 0 additions & 28 deletions
@@ -56,31 +56,3 @@ std::string getCubin(const std::string& name, const std::string& sha256) {
   callbackGetCubin(name.c_str(), sha256.c_str());
   return current_cubin;
 }
-
-void (*callbackGetMetaInfo)(const char* path, const char* sha256, const char* extension) = nullptr;
-
-// Set the python callback, called by the python code using ctypes.
-extern "C" void FlashInferSetMetaInfoCallback(void (*callback)(const char* path, const char* sha256,
-                                                               const char* extension)) {
-  callbackGetMetaInfo = callback;
-}
-
-// Thread-local variable that stores the current metainfo.
-// It is reset on every call to `getMetaInfo()`.
-thread_local std::string raw_metainfo;
-
-// Called by the callback to set the current metainfo.
-extern "C" void FlashInferSetCurrentMetaInfo(const char* binary, int size) {
-  raw_metainfo = std::string(binary, size);
-}
-
-// Get the metainfo from the python callback.
-// This is the API for the native library to use.
-std::string getMetaInfo(const std::string& name, const std::string& sha256,
-                        const std::string& extension) {
-  if (!callbackGetMetaInfo) {
-    throw std::runtime_error("FlashInferSetMetaInfoCallback not set");
-  }
-  callbackGetMetaInfo(name.c_str(), sha256.c_str(), extension.c_str());
-  return raw_metainfo;
-}

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 4 additions & 4 deletions
@@ -24,7 +24,7 @@
 #include "trtllm/gen/CudaKernelLauncher.h"

 #ifdef TLLM_GEN_EXPORT_INTERFACE
-#include "KernelMetaInfo.h"
+#include "flashinferMetaInfo.h"
 #endif  // TLLM_GEN_EXPORT_INTERFACE

 namespace flashinfer::trtllm_cubin_loader {
@@ -466,7 +466,8 @@ BatchedGemmConfig const* BatchedGemmInterface::getBatchedGemmConfigs() const {

 size_t BatchedGemmInterface::getNumBatchedGemmConfigs() const {
 #ifdef TLLM_GEN_EXPORT_INTERFACE
-  return tensorrt_llm::kernels::tllmGenBatchedGemmListLen;
+  return sizeof(tensorrt_llm::kernels::tllmGenBatchedGemmList) /
+         sizeof(tensorrt_llm::kernels::tllmGenBatchedGemmList[0]);
 #else
   return 0;
 #endif
@@ -645,8 +646,7 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa

   auto fiModuleLoadData = [&](CUmodule* module) {
     const std::string sha256 = config.mHash ? config.mHash : "";
-    const std::string pipeline_hash = "991e7438224199de85ef08a2730ce18c12b4e0aa";
-    const std::string cubin_path = pipeline_hash + "/" + std::string("batched_gemm-") +
+    const std::string cubin_path = std::string(PIPELINE_HASH) + "/" + std::string("batched_gemm-") +
                                    TLLM_GEN_COMMIT + "-" + TLLM_GEN_BATCHED_GEMM_CONFIG_HASH + "/";
     std::string fname_cubin = config.mFunctionName;
     if (!fname_cubin.empty()) {

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h

Lines changed: 0 additions & 2 deletions
@@ -302,8 +302,6 @@ struct BatchedGemmConfig {
   // defined. In this case, the cubins will be loaded from the provided data and function name.
   // Otherwise, the kernel will be loaded from the CudaRunner.
 #ifdef TLLM_GEN_EXPORT_INTERFACE
-  uint8_t const* mData{nullptr};
-  uint32_t const mSize{0};
   uint32_t const mSharedMemSize{0};
   char const* mFunctionName{nullptr};
   uint32_t const mNumThreadsPerCTA{0};
