27
27
#include " flashinferMetaInfo.h"
28
28
#endif // TLLM_GEN_EXPORT_INTERFACE
29
29
30
+ #ifdef TLLM_GEN_GEMM_CUBIN_PATH
31
+ static const std::string tllm_gen_gemm_cubin_path = std::string(TLLM_GEN_GEMM_CUBIN_PATH);
32
+ #else
33
+ static_assert (false , " TLLM_GEN_GEMM_CUBIN_PATH macro is not defined when compiling" );
34
+ #endif
35
+
36
namespace flashinfer {
namespace trtllm_cubin_loader {
// Forward declaration only; no definition is visible in this part of the file.
// Takes a kernel name and a SHA-256 string and returns a std::string
// (presumably the cubin contents -- TODO confirm against the definition).
std::string getCubin(const std::string& kernelName, const std::string& sha256);
}  // namespace trtllm_cubin_loader
}  // namespace flashinfer
39
+
30
40
namespace gemm {
31
41
32
42
namespace gemm {
@@ -315,7 +325,8 @@ GemmConfig const* GemmInterface::getGemmConfigs() const {
315
325
316
326
size_t GemmInterface::getNumGemmConfigs () const {
317
327
#ifdef TLLM_GEN_EXPORT_INTERFACE
318
- return tensorrt_llm::kernels::tllmGenGemmListLen;
328
+ return sizeof (tensorrt_llm::kernels::tllmGenGemmList) /
329
+ sizeof (tensorrt_llm::kernels::tllmGenGemmList[0 ]);
319
330
#else
320
331
return 0 ;
321
332
#endif
@@ -480,6 +491,17 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
480
491
CUmodule cuModule;
481
492
CUfunction cuFunction;
482
493
494
+ auto fiModuleLoadData = [&](CUmodule* module ) {
495
+ const std::string sha256 = config.mHash ? config.mHash : " " ;
496
+ std::string fname_cubin = config.mFunctionName ;
497
+ if (!fname_cubin.empty ()) {
498
+ fname_cubin[0 ] = static_cast <char >(std::toupper (static_cast <unsigned char >(fname_cubin[0 ])));
499
+ }
500
+ fname_cubin = tllm_gen_gemm_cubin_path + " /" + fname_cubin + " .cubin" ;
501
+ std::string cubin = flashinfer::trtllm_cubin_loader::getCubin (fname_cubin, sha256);
502
+ cuModuleLoadData (&cuModule, cubin.c_str ());
503
+ };
504
+
483
505
if (moduleCache.has_value ()) {
484
506
ModuleCache& moduleCacheRef = moduleCache.value ().get ();
485
507
@@ -501,12 +523,12 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
501
523
if (module != moduleCacheRef.end ()) {
502
524
cuFunction = std::get<1 >(module ->second );
503
525
} else {
504
- cuModuleLoadData (&cuModule, config. mData );
526
+ fiModuleLoadData (&cuModule);
505
527
cuModuleGetFunction (&cuFunction, cuModule, config.mFunctionName );
506
528
moduleCacheRef.insert (std::make_pair (moduleKey, std::make_tuple (cuModule, cuFunction)));
507
529
}
508
530
} else {
509
- cuModuleLoadData (&cuModule, config. mData );
531
+ fiModuleLoadData (&cuModule);
510
532
cuModuleGetFunction (&cuFunction, cuModule, config.mFunctionName );
511
533
}
512
534
@@ -534,9 +556,7 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
534
556
return -1 ;
535
557
}
536
558
#else
537
- config.mCudaRunner ->run ((void *)&kernelParams, (void *)cudaStream, grid,
538
- /* cluster*/ {},
539
- /* instanceId*/ config.mInstanceIdx );
559
+ config.mCudaRunner ->run ((void *)&kernelParams, (void *)cudaStream, grid);
540
560
#endif
541
561
542
562
return 0 ;
0 commit comments