24
24
#include " trtllm/gen/CudaKernelLauncher.h"
25
25
26
26
#ifdef TLLM_GEN_EXPORT_INTERFACE
27
- #include " flashinferMetaInfo .h"
27
+ #include " KernelMetaInfo .h"
28
28
#endif // TLLM_GEN_EXPORT_INTERFACE
29
29
30
- #ifdef TLLM_GEN_GEMM_CUBIN_PATH
31
- static const std::string tllm_gen_gemm_cubin_path = std::string(TLLM_GEN_GEMM_CUBIN_PATH);
32
- #else
33
- static_assert (false , " TLLM_GEN_GEMM_CUBIN_PATH macro is not defined when compiling" );
34
- #endif
35
-
36
- namespace flashinfer ::trtllm_cubin_loader {
37
- std::string getCubin (const std::string& kernelName, const std::string& sha256);
38
- } // namespace flashinfer::trtllm_cubin_loader
39
-
40
30
namespace gemm {
41
31
42
32
namespace gemm {
@@ -285,6 +275,12 @@ class GemmInterface {
285
275
template <typename Dtype>
286
276
inline Dtype* alignPtr (Dtype* ptr, int64_t alignment) const ;
287
277
278
+ // Returns the number of tiles and number of CTAs for Z dimension.
279
+ std::tuple<int32_t , int32_t , int32_t > getGridSize (int32_t M, int32_t N, int32_t tileM,
280
+ int32_t tileN, int32_t clusterDimX,
281
+ int32_t clusterDimY,
282
+ int32_t numSlicesForSplitK) const ;
283
+
288
284
// Creates GemmOptions from kernel and data.
289
285
GemmOptions getOptionsFromConfigAndData (GemmConfig const & config, GemmData const & data) const ;
290
286
@@ -319,15 +315,28 @@ GemmConfig const* GemmInterface::getGemmConfigs() const {
319
315
320
316
size_t GemmInterface::getNumGemmConfigs () const {
321
317
#ifdef TLLM_GEN_EXPORT_INTERFACE
322
- return sizeof (tensorrt_llm::kernels::tllmGenGemmList) /
323
- sizeof (tensorrt_llm::kernels::tllmGenGemmList[0 ]);
318
+ return tensorrt_llm::kernels::tllmGenGemmListLen;
324
319
#else
325
320
return 0 ;
326
321
#endif
327
322
}
328
323
329
324
// //////////////////////////////////////////////////////////////////////////////////////////////////
330
325
326
+ std::tuple<int32_t , int32_t , int32_t > GemmInterface::getGridSize (int32_t M, int32_t N,
327
+ int32_t tileM, int32_t tileN,
328
+ int32_t clusterDimX,
329
+ int32_t clusterDimY,
330
+ int32_t numSlicesForSplitK) const {
331
+ // The number of tiles in the M dimension.
332
+ auto numTilesM = gemm::divUpMul (gemm::divUp (M, tileM), clusterDimX);
333
+ // The number of tiles in the N dimension.
334
+ auto numTilesN = gemm::divUpMul (gemm::divUp (N, tileN), clusterDimY);
335
+ return std::make_tuple (numTilesM, numTilesN, numSlicesForSplitK);
336
+ }
337
+
338
+ // //////////////////////////////////////////////////////////////////////////////////////////////////
339
+
331
340
GemmOptions GemmInterface::getOptionsFromConfigAndData (GemmConfig const & config,
332
341
GemmData const & data) const {
333
342
// Create options from config and data.
@@ -363,10 +372,10 @@ std::vector<size_t> GemmInterface::getWorkspaceSizesInBytes(GemmConfig const& co
363
372
// Get options from config.
364
373
auto & options = config.mOptions ;
365
374
366
- // The number of tiles in the M dimension.
367
- int32_t numTilesM = gemm::divUp (data. mProblemDimensions . mM , options. mTileM );
368
- // The number of tiles in the N dimension.
369
- int32_t numTilesN = gemm::divUp (data. mProblemDimensions . mN , options.mTileN );
375
+ // Get the number of tiles and cluster dimension Z .
376
+ auto [ numTilesM, numTilesN, gridDimZ] = getGridSize (
377
+ data. mProblemDimensions . mM , data. mProblemDimensions . mN , options. mTileM , options. mTileN ,
378
+ options. mClusterDimX , options. mClusterDimY , options.mNumSlicesForSplitK );
370
379
371
380
std::vector<size_t > workspaceSizes;
372
381
@@ -439,10 +448,10 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
439
448
}
440
449
}
441
450
442
- // The number of tiles in the M dimension.
443
- int numTilesM = gemm::divUp (options. mM , options. mTileM );
444
- // The number of tiles in the N dimension.
445
- int numTilesN = gemm::divUp ( options.mN , options.mTileN );
451
+ // Get the number of tiles and number of CTAs for Z dimension.
452
+ auto [ numTilesM, numTilesN, gridDimZ] =
453
+ getGridSize (options. mM , options. mN , options. mTileM , options. mTileN , options. mClusterDimX ,
454
+ options.mClusterDimY , options.mNumSlicesForSplitK );
446
455
447
456
// Create kernel params.
448
457
auto kernelParams = gemm::KernelParamsSetup::setKernelParams (
@@ -455,9 +464,8 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
455
464
data.mAllReduceBuffers .mPtrMultiMemCompletionBars , dPtrSplitKCompletionBars,
456
465
/* dPtrNumNonExitingCtas */ nullptr , data.mProblemDimensions .mRank ,
457
466
data.mProblemDimensions .mWorldSize );
458
-
459
467
// The size of the grid.
460
- std::vector<int32_t > grid{numTilesM, numTilesN, options. mNumSlicesForSplitK };
468
+ std::vector<int32_t > grid{numTilesM, numTilesN, gridDimZ };
461
469
462
470
// When split-k is enabled and to guarantee the forward progress, we must ensure that the number
463
471
// of tiles is less than number of SMs. This way, at least one CTA in the grid can make forward.
@@ -472,16 +480,6 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
472
480
CUmodule cuModule;
473
481
CUfunction cuFunction;
474
482
475
- auto fiModuleLoadData = [&](CUmodule* module ) {
476
- const std::string sha256 = config.mHash ? config.mHash : " " ;
477
- std::string fname_cubin = config.mFunctionName ;
478
- if (!fname_cubin.empty ()) {
479
- fname_cubin[0 ] = static_cast <char >(std::toupper (static_cast <unsigned char >(fname_cubin[0 ])));
480
- }
481
- fname_cubin = tllm_gen_gemm_cubin_path + " /" + fname_cubin + " .cubin" ;
482
- std::string cubin = flashinfer::trtllm_cubin_loader::getCubin (fname_cubin, sha256);
483
- cuModuleLoadData (&cuModule, cubin.c_str ());
484
- };
485
483
if (moduleCache.has_value ()) {
486
484
ModuleCache& moduleCacheRef = moduleCache.value ().get ();
487
485
@@ -503,12 +501,12 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
503
501
if (module != moduleCacheRef.end ()) {
504
502
cuFunction = std::get<1 >(module ->second );
505
503
} else {
506
- fiModuleLoadData (&cuModule);
504
+ cuModuleLoadData (&cuModule, config. mData );
507
505
cuModuleGetFunction (&cuFunction, cuModule, config.mFunctionName );
508
506
moduleCacheRef.insert (std::make_pair (moduleKey, std::make_tuple (cuModule, cuFunction)));
509
507
}
510
508
} else {
511
- fiModuleLoadData (&cuModule);
509
+ cuModuleLoadData (&cuModule, config. mData );
512
510
cuModuleGetFunction (&cuFunction, cuModule, config.mFunctionName );
513
511
}
514
512
@@ -536,7 +534,9 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
536
534
return -1 ;
537
535
}
538
536
#else
539
- config.mCudaRunner ->run ((void *)&kernelParams, (void *)cudaStream, grid);
537
+ config.mCudaRunner ->run ((void *)&kernelParams, (void *)cudaStream, grid,
538
+ /* cluster*/ {},
539
+ /* instanceId*/ config.mInstanceIdx );
540
540
#endif
541
541
542
542
return 0 ;
@@ -564,10 +564,11 @@ int32_t GemmInterface::runInitBeforeWorldSync(GemmConfig const& config, GemmData
564
564
return 1 ;
565
565
}
566
566
}
567
- // The number of tiles in the M dimension.
568
- int numTilesM = gemm::divUp (options.mM , options.mTileM );
569
- // The number of tiles in the N dimension.
570
- int numTilesN = gemm::divUp (options.mN , options.mTileN );
567
+
568
+ // Get the number of tiles and number of CTAs for Z dimension.
569
+ auto [numTilesM, numTilesN, gridDimZ] =
570
+ getGridSize (options.mM , options.mN , options.mTileM , options.mTileN , options.mClusterDimX ,
571
+ options.mClusterDimY , options.mNumSlicesForSplitK );
571
572
// The number of bytes for the tile barriers.
572
573
int32_t numBytesTileBars = numTilesM * numTilesN * sizeof (uint32_t );
573
574
// Sanitize system barriers.
0 commit comments