Skip to content

Commit 2ce62b9

Browse files
committed
undo unwanted changes
1 parent 40bb078 commit 2ce62b9

File tree

1 file changed

+26
-6
lines changed

1 file changed

+26
-6
lines changed

include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,16 @@
2727
#include "flashinferMetaInfo.h"
2828
#endif // TLLM_GEN_EXPORT_INTERFACE
2929

30+
#ifdef TLLM_GEN_GEMM_CUBIN_PATH
31+
static const std::string tllm_gen_gemm_cubin_path = std::string(TLLM_GEN_GEMM_CUBIN_PATH);
32+
#else
33+
static_assert(false, "TLLM_GEN_GEMM_CUBIN_PATH macro is not defined when compiling");
34+
#endif
35+
36+
// Declaration only: getCubin takes a kernel name and a SHA-256 string and
// returns a std::string. No definition appears in this excerpt.
namespace flashinfer {
namespace trtllm_cubin_loader {
std::string getCubin(const std::string& kernelName, const std::string& sha256);
}  // namespace trtllm_cubin_loader
}  // namespace flashinfer
39+
3040
namespace gemm {
3141

3242
namespace gemm {
@@ -315,7 +325,8 @@ GemmConfig const* GemmInterface::getGemmConfigs() const {
315325

316326
size_t GemmInterface::getNumGemmConfigs() const {
317327
#ifdef TLLM_GEN_EXPORT_INTERFACE
318-
return tensorrt_llm::kernels::tllmGenGemmListLen;
328+
return sizeof(tensorrt_llm::kernels::tllmGenGemmList) /
329+
sizeof(tensorrt_llm::kernels::tllmGenGemmList[0]);
319330
#else
320331
return 0;
321332
#endif
@@ -480,6 +491,17 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
480491
CUmodule cuModule;
481492
CUfunction cuFunction;
482493

494+
auto fiModuleLoadData = [&](CUmodule* module) {
495+
const std::string sha256 = config.mHash ? config.mHash : "";
496+
std::string fname_cubin = config.mFunctionName;
497+
if (!fname_cubin.empty()) {
498+
fname_cubin[0] = static_cast<char>(std::toupper(static_cast<unsigned char>(fname_cubin[0])));
499+
}
500+
fname_cubin = tllm_gen_gemm_cubin_path + "/" + fname_cubin + ".cubin";
501+
std::string cubin = flashinfer::trtllm_cubin_loader::getCubin(fname_cubin, sha256);
502+
cuModuleLoadData(&cuModule, cubin.c_str());
503+
};
504+
483505
if (moduleCache.has_value()) {
484506
ModuleCache& moduleCacheRef = moduleCache.value().get();
485507

@@ -501,12 +523,12 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
501523
if (module != moduleCacheRef.end()) {
502524
cuFunction = std::get<1>(module->second);
503525
} else {
504-
cuModuleLoadData(&cuModule, config.mData);
526+
fiModuleLoadData(&cuModule);
505527
cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
506528
moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction)));
507529
}
508530
} else {
509-
cuModuleLoadData(&cuModule, config.mData);
531+
fiModuleLoadData(&cuModule);
510532
cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
511533
}
512534

@@ -534,9 +556,7 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
534556
return -1;
535557
}
536558
#else
537-
config.mCudaRunner->run((void*)&kernelParams, (void*)cudaStream, grid,
538-
/*cluster*/ {},
539-
/*instanceId*/ config.mInstanceIdx);
559+
config.mCudaRunner->run((void*)&kernelParams, (void*)cudaStream, grid);
540560
#endif
541561

542562
return 0;

0 commit comments

Comments
 (0)