Skip to content

Commit c2ec25d

Browse files
committed
update trtllmGen_gemm_export
1 parent acc4fa7 commit c2ec25d

File tree

6 files changed

+169
-80
lines changed

6 files changed

+169
-80
lines changed

include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,9 @@
2424
#include "trtllm/gen/CudaKernelLauncher.h"
2525

2626
#ifdef TLLM_GEN_EXPORT_INTERFACE
27-
#include "flashinferMetaInfo.h"
27+
#include "KernelMetaInfo.h"
2828
#endif // TLLM_GEN_EXPORT_INTERFACE
2929

30-
#ifdef TLLM_GEN_GEMM_CUBIN_PATH
31-
static const std::string tllm_gen_gemm_cubin_path = std::string(TLLM_GEN_GEMM_CUBIN_PATH);
32-
#else
33-
static_assert(false, "TLLM_GEN_GEMM_CUBIN_PATH macro is not defined when compiling");
34-
#endif
35-
36-
namespace flashinfer::trtllm_cubin_loader {
37-
std::string getCubin(const std::string& kernelName, const std::string& sha256);
38-
} // namespace flashinfer::trtllm_cubin_loader
39-
4030
namespace gemm {
4131

4232
namespace gemm {
@@ -285,6 +275,12 @@ class GemmInterface {
285275
template <typename Dtype>
286276
inline Dtype* alignPtr(Dtype* ptr, int64_t alignment) const;
287277

278+
// Returns the number of tiles and number of CTAs for Z dimension.
279+
std::tuple<int32_t, int32_t, int32_t> getGridSize(int32_t M, int32_t N, int32_t tileM,
280+
int32_t tileN, int32_t clusterDimX,
281+
int32_t clusterDimY,
282+
int32_t numSlicesForSplitK) const;
283+
288284
// Creates GemmOptions from kernel and data.
289285
GemmOptions getOptionsFromConfigAndData(GemmConfig const& config, GemmData const& data) const;
290286

@@ -319,15 +315,28 @@ GemmConfig const* GemmInterface::getGemmConfigs() const {
319315

320316
size_t GemmInterface::getNumGemmConfigs() const {
321317
#ifdef TLLM_GEN_EXPORT_INTERFACE
322-
return sizeof(tensorrt_llm::kernels::tllmGenGemmList) /
323-
sizeof(tensorrt_llm::kernels::tllmGenGemmList[0]);
318+
return tensorrt_llm::kernels::tllmGenGemmListLen;
324319
#else
325320
return 0;
326321
#endif
327322
}
328323

329324
////////////////////////////////////////////////////////////////////////////////////////////////////
330325

326+
std::tuple<int32_t, int32_t, int32_t> GemmInterface::getGridSize(int32_t M, int32_t N,
327+
int32_t tileM, int32_t tileN,
328+
int32_t clusterDimX,
329+
int32_t clusterDimY,
330+
int32_t numSlicesForSplitK) const {
331+
// The number of tiles in the M dimension.
332+
auto numTilesM = gemm::divUpMul(gemm::divUp(M, tileM), clusterDimX);
333+
// The number of tiles in the N dimension.
334+
auto numTilesN = gemm::divUpMul(gemm::divUp(N, tileN), clusterDimY);
335+
return std::make_tuple(numTilesM, numTilesN, numSlicesForSplitK);
336+
}
337+
338+
////////////////////////////////////////////////////////////////////////////////////////////////////
339+
331340
GemmOptions GemmInterface::getOptionsFromConfigAndData(GemmConfig const& config,
332341
GemmData const& data) const {
333342
// Create options from config and data.
@@ -363,10 +372,10 @@ std::vector<size_t> GemmInterface::getWorkspaceSizesInBytes(GemmConfig const& co
363372
// Get options from config.
364373
auto& options = config.mOptions;
365374

366-
// The number of tiles in the M dimension.
367-
int32_t numTilesM = gemm::divUp(data.mProblemDimensions.mM, options.mTileM);
368-
// The number of tiles in the N dimension.
369-
int32_t numTilesN = gemm::divUp(data.mProblemDimensions.mN, options.mTileN);
375+
// Get the number of tiles and cluster dimension Z.
376+
auto [numTilesM, numTilesN, gridDimZ] = getGridSize(
377+
data.mProblemDimensions.mM, data.mProblemDimensions.mN, options.mTileM, options.mTileN,
378+
options.mClusterDimX, options.mClusterDimY, options.mNumSlicesForSplitK);
370379

371380
std::vector<size_t> workspaceSizes;
372381

@@ -439,10 +448,10 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
439448
}
440449
}
441450

442-
// The number of tiles in the M dimension.
443-
int numTilesM = gemm::divUp(options.mM, options.mTileM);
444-
// The number of tiles in the N dimension.
445-
int numTilesN = gemm::divUp(options.mN, options.mTileN);
451+
// Get the number of tiles and number of CTAs for Z dimension.
452+
auto [numTilesM, numTilesN, gridDimZ] =
453+
getGridSize(options.mM, options.mN, options.mTileM, options.mTileN, options.mClusterDimX,
454+
options.mClusterDimY, options.mNumSlicesForSplitK);
446455

447456
// Create kernel params.
448457
auto kernelParams = gemm::KernelParamsSetup::setKernelParams(
@@ -455,9 +464,8 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
455464
data.mAllReduceBuffers.mPtrMultiMemCompletionBars, dPtrSplitKCompletionBars,
456465
/* dPtrNumNonExitingCtas */ nullptr, data.mProblemDimensions.mRank,
457466
data.mProblemDimensions.mWorldSize);
458-
459467
// The size of the grid.
460-
std::vector<int32_t> grid{numTilesM, numTilesN, options.mNumSlicesForSplitK};
468+
std::vector<int32_t> grid{numTilesM, numTilesN, gridDimZ};
461469

462470
// When split-k is enabled and to guarantee the forward progress, we must ensure that the number
463471
// of tiles is less than number of SMs. This way, at least one CTA in the grid can make forward.
@@ -472,16 +480,6 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
472480
CUmodule cuModule;
473481
CUfunction cuFunction;
474482

475-
auto fiModuleLoadData = [&](CUmodule* module) {
476-
const std::string sha256 = config.mHash ? config.mHash : "";
477-
std::string fname_cubin = config.mFunctionName;
478-
if (!fname_cubin.empty()) {
479-
fname_cubin[0] = static_cast<char>(std::toupper(static_cast<unsigned char>(fname_cubin[0])));
480-
}
481-
fname_cubin = tllm_gen_gemm_cubin_path + "/" + fname_cubin + ".cubin";
482-
std::string cubin = flashinfer::trtllm_cubin_loader::getCubin(fname_cubin, sha256);
483-
cuModuleLoadData(&cuModule, cubin.c_str());
484-
};
485483
if (moduleCache.has_value()) {
486484
ModuleCache& moduleCacheRef = moduleCache.value().get();
487485

@@ -503,12 +501,12 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
503501
if (module != moduleCacheRef.end()) {
504502
cuFunction = std::get<1>(module->second);
505503
} else {
506-
fiModuleLoadData(&cuModule);
504+
cuModuleLoadData(&cuModule, config.mData);
507505
cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
508506
moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction)));
509507
}
510508
} else {
511-
fiModuleLoadData(&cuModule);
509+
cuModuleLoadData(&cuModule, config.mData);
512510
cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
513511
}
514512

@@ -536,7 +534,9 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
536534
return -1;
537535
}
538536
#else
539-
config.mCudaRunner->run((void*)&kernelParams, (void*)cudaStream, grid);
537+
config.mCudaRunner->run((void*)&kernelParams, (void*)cudaStream, grid,
538+
/*cluster*/ {},
539+
/*instanceId*/ config.mInstanceIdx);
540540
#endif
541541

542542
return 0;
@@ -564,10 +564,11 @@ int32_t GemmInterface::runInitBeforeWorldSync(GemmConfig const& config, GemmData
564564
return 1;
565565
}
566566
}
567-
// The number of tiles in the M dimension.
568-
int numTilesM = gemm::divUp(options.mM, options.mTileM);
569-
// The number of tiles in the N dimension.
570-
int numTilesN = gemm::divUp(options.mN, options.mTileN);
567+
568+
// Get the number of tiles and number of CTAs for Z dimension.
569+
auto [numTilesM, numTilesN, gridDimZ] =
570+
getGridSize(options.mM, options.mN, options.mTileM, options.mTileN, options.mClusterDimX,
571+
options.mClusterDimY, options.mNumSlicesForSplitK);
571572
// The number of bytes for the tile barriers.
572573
int32_t numBytesTileBars = numTilesM * numTilesN * sizeof(uint32_t);
573574
// Sanitize system barriers.

0 commit comments

Comments
 (0)