Skip to content

Commit e0eca6a

Browse files
committed
upd
1 parent 1aedbf2 commit e0eca6a

File tree

5 files changed

+23
-31
lines changed

5 files changed

+23
-31
lines changed

flashinfer/fused_moe.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -774,10 +774,15 @@ def cutlass_fused_moe(
774774

775775

776776
def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
777-
hash = "f5deee96023f1d74b1ff71ac69f782a96741a053"
777+
hash = "6b93c394210c89dccef13833c89797f1b8f8aefb"
778+
tllm_gen_commit = "ce8ce46"
779+
tllm_gen_config_hash = "2dc78d9"
780+
include_path = (
781+
f"{hash}/batched_gemm-{tllm_gen_commit}-{tllm_gen_config_hash}/include"
782+
)
778783
metainfo = get_cubin(
779-
f"{hash}/batched_gemm-c603ed2-3fa89e1/include/KernelMetaInfo",
780-
"d789c63aaeee1aa0a68ebf22fa693b6b82a7c2319bd933a00a10306ca08d9e0e",
784+
f"{include_path}/flashinferMetaInfo",
785+
"b24fd5e7ae6b20e903c866ecb1d4a68f238301ba9b76df6a536056f2059a0d56",
781786
".h",
782787
)
783788
assert metainfo, "KernelMetaInfo.h not found"
@@ -796,15 +801,12 @@ def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
796801
"-DENABLE_BF16",
797802
"-DENABLE_FP8",
798803
"-DENABLE_FP4",
804+
f'-DPIPELINE_HASH=\\"{hash}\\"',
805+
f'-DTLLM_GEN_COMMIT=\\"{tllm_gen_commit}\\"',
806+
f'-DTLLM_GEN_BATCHED_GEMM_CONFIG_HASH=\\"{tllm_gen_config_hash}\\"',
799807
]
800808
+ sm100a_nvcc_flags,
801-
extra_include_paths=[
802-
jit_env.FLASHINFER_CACHE_DIR
803-
/ "cubins"
804-
/ hash
805-
/ "batched_gemm-c603ed2-3fa89e1"
806-
/ "include"
807-
],
809+
extra_include_paths=[jit_env.FLASHINFER_CACHE_DIR / "cubins" / include_path],
808810
extra_ldflags=["-lcuda"],
809811
)
810812

flashinfer/jit/attention/pytorch.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,10 +1483,11 @@ def gen_fmha_cutlass_sm100a_module(
14831483

14841484

14851485
def trtllm_gen_fmha_module():
1486-
hash = "5f2779e6df822bc0b26940b6d3b0059c86f0a6a1"
1486+
hash = "6b93c394210c89dccef13833c89797f1b8f8aefb"
1487+
include_path = f"{hash}/fmha/trtllm-gen/include"
14871488
metainfo = get_cubin(
1488-
f"{hash}/fmha/trtllm-gen/include/flashInferMetaInfo",
1489-
"11f31dc81f996e39c3f1d85d773864c9113c5837619e21418a846befa4f8dddd",
1489+
f"{include_path}/flashInferMetaInfo",
1490+
"ba35dc13249cd09bf39eed43e785b088d329acaf81a3f940a615904b81bfa02f",
14901491
".h",
14911492
)
14921493
assert metainfo, "flashInferMetaInfo.h not found"
@@ -1496,15 +1497,9 @@ def trtllm_gen_fmha_module():
14961497
jit_env.FLASHINFER_CSRC_DIR / "trtllm_fmha_runner.cu",
14971498
jit_env.FLASHINFER_CSRC_DIR / "trtllm_fmha_kernel_launcher.cu",
14981499
],
1499-
extra_include_paths=[
1500-
jit_env.FLASHINFER_CACHE_DIR
1501-
/ "cubins"
1502-
/ hash
1503-
/ "fmha"
1504-
/ "trtllm-gen"
1505-
/ "include"
1506-
],
1500+
extra_include_paths=[jit_env.FLASHINFER_CACHE_DIR / "cubins" / include_path],
15071501
extra_ldflags=["-lcuda"],
1502+
extra_cuda_cflags=[f'-DPIPELINE_HASH=\\"{hash}\\"'],
15081503
)
15091504

15101505

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
#include "trtllm/gen/CudaKernelLauncher.h"
2525

2626
#ifdef TLLM_GEN_EXPORT_INTERFACE
27-
#include "KernelMetaInfo.h"
27+
#include "flashinferMetaInfo.h"
2828
#endif // TLLM_GEN_EXPORT_INTERFACE
2929

3030
namespace flashinfer::trtllm_cubin_loader {
@@ -466,7 +466,8 @@ BatchedGemmConfig const* BatchedGemmInterface::getBatchedGemmConfigs() const {
466466

467467
size_t BatchedGemmInterface::getNumBatchedGemmConfigs() const {
468468
#ifdef TLLM_GEN_EXPORT_INTERFACE
469-
return tensorrt_llm::kernels::tllmGenBatchedGemmListLen;
469+
return sizeof(tensorrt_llm::kernels::tllmGenBatchedGemmList) /
470+
sizeof(tensorrt_llm::kernels::tllmGenBatchedGemmList[0]);
470471
#else
471472
return 0;
472473
#endif
@@ -645,8 +646,7 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa
645646

646647
auto fiModuleLoadData = [&](CUmodule* module) {
647648
const std::string sha256 = config.mHash ? config.mHash : "";
648-
const std::string pipeline_hash = "f5deee96023f1d74b1ff71ac69f782a96741a053";
649-
const std::string cubin_path = pipeline_hash + "/" + std::string("batched_gemm-") +
649+
const std::string cubin_path = std::string(PIPELINE_HASH) + "/" + std::string("batched_gemm-") +
650650
TLLM_GEN_COMMIT + "-" + TLLM_GEN_BATCHED_GEMM_CONFIG_HASH + "/";
651651
std::string fname_cubin = config.mFunctionName;
652652
if (!fname_cubin.empty()) {

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,6 @@ struct BatchedGemmConfig {
302302
// defined. In this case, the cubins will be loaded from the provided data and function name.
303303
// Otherwise, the kernel will be loaded from the CudaRunner.
304304
#ifdef TLLM_GEN_EXPORT_INTERFACE
305-
uint8_t const* mData{nullptr};
306-
uint32_t const mSize{0};
307305
uint32_t const mSharedMemSize{0};
308306
char const* mFunctionName{nullptr};
309307
uint32_t const mNumThreadsPerCTA{0};

include/flashinfer/trtllm/fmha/fmhaKernels.cuh

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,7 @@ class TllmGenFmhaKernel {
231231
}
232232

233233
static std::string getCubinPath() {
234-
const char* env_hash = std::getenv("FLASHINFER_CUBIN_ARTIFACTORY_HASH");
235-
std::string hash =
236-
env_hash ? std::string(env_hash) : "4c7bdebb4eba13311fc652a069e64782d5c0723d";
237-
std::string cubin_path = hash + "/fmha/trtllm-gen/";
234+
std::string cubin_path = std::string(PIPELINE_HASH) + "/fmha/trtllm-gen/";
238235
return cubin_path;
239236
}
240237

0 commit comments

Comments
 (0)