@@ -76,39 +76,43 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
76
76
// FIXME: We create an explicit constructor with all options to work around (WAR) a stubgen issue in TRT-LLM.
77
77
BatchedGemmOptions (
78
78
gemm::AllReduceAlgo allReduceAlgo, gemm::BiasType biasType, int blockK, int clusterDimX,
79
- int clusterDimY, int clusterDimZ, tg::Dtype dtypeAcc, tg::Dtype dtypeA, tg::Dtype dtypeB,
80
- tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, bool enablesEarlyExit,
81
- bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, int epilogueLdtmDps,
82
- int epilogueLdtmBits, int epilogueTileM, int epilogueTileN, bool gridTriggerSecondaryA,
83
- bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA,
84
- bool gridWaitForPrimaryB, bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k,
85
- gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB,
86
- int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n,
87
- int numSlicesForSplitK, int numSlicesForSliceK, int numStages, int numStagesMma,
88
- int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId,
89
- bool outputDebugTensors, bool patchF2fp, std::optional<int32_t > sfBlockSizeA,
90
- tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC,
91
- int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK, int tileM, int tileN,
92
- gemm::TileScheduler tileScheduler, bool transposeMmaOutput, bool useCustomMmaSchedule,
93
- bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, bool usePerTokenSfA,
94
- bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, bool useTwoTmaLoadWarps,
95
- bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize, gemmGatedAct::ActType actType,
96
- bool clampBeforeAct, std::vector<int > batchedM, std::vector<int > batchedN,
97
- BatchMode batchMode, int numBatches, bool isStaticBatch, int numTokens, RouteImpl routeImpl,
98
- bool gridWaitForPrimaryRouting, bool fusedAct, int numRegsPerThreadNonEpilogueWarp,
99
- int numRegsPerThreadEpilogueWarp, int numRegsCastAWarps, bool useTmaOobOpt)
79
+ int clusterDimY, int clusterDimZ, gemm::CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc,
80
+ tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA,
81
+ tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit,
82
+ bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM,
83
+ int epilogueTileN, bool gridTriggerSecondaryA, bool gridTriggerSecondaryB,
84
+ bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB,
85
+ bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, gemm::KernelTraits kernelTraits,
86
+ gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind,
87
+ int mmaM, int mmaN, bool mockAllReduce, int n, int numRegsCastAWarps,
88
+ int numRegsCopySfLdsSttm, int numRegsPerThreadEpilogueWarp,
89
+ int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, int numSlicesForSliceK,
90
+ int numStages, int numStagesMma, int numStagesMmaWithinWorkTile,
91
+ int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, bool patchF2fp,
92
+ std::optional<int32_t > sfBlockSizeA, tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB,
93
+ tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK,
94
+ int tileM, int tileN, gemm::TileScheduler tileScheduler, bool transposeMmaOutput,
95
+ bool useCustomMmaSchedule, bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule,
96
+ bool usePerTokenSfA, bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore,
97
+ bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize,
98
+ gemmGatedAct::ActType actType, bool clampBeforeAct, std::vector<int > batchedM,
99
+ std::vector<int > batchedN, BatchMode batchMode, int numBatches, bool isStaticBatch,
100
+ int numTokens, RouteImpl routeImpl, std::optional<RouteImpl> routeSfsImpl,
101
+ bool gridWaitForPrimaryRouting, bool fusedAct, bool useTmaOobOpt)
100
102
: gemmGatedAct::GemmGatedActOptions(
101
103
gemm::GemmOptions (
102
- allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, dtypeAcc,
103
- dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, enablesEarlyExit,
104
- enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps, epilogueLdtmBits,
105
- epilogueTileM, epilogueTileN, gridTriggerSecondaryA, gridTriggerSecondaryB,
106
- gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, gridWaitForPrimaryB,
107
- hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, layoutA, layoutB, m, mmaK,
108
- mmaKind, mmaM, mmaN, mockAllReduce, n, numSlicesForSplitK, numSlicesForSliceK,
109
- numStages, numStagesMma, numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile,
110
- numStagesWorkId, outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB,
111
- sfLayoutC, sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN, tileScheduler,
104
+ allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ,
105
+ ctaSwizzleType, dtypeAcc, dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB,
106
+ enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps,
107
+ epilogueLdtmBits, epilogueTileM, epilogueTileN, gridTriggerSecondaryA,
108
+ gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA,
109
+ gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits,
110
+ layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n, numRegsCastAWarps,
111
+ numRegsCopySfLdsSttm, numRegsPerThreadEpilogueWarp, numRegsPerThreadNonEpilogueWarp,
112
+ numSlicesForSplitK, numSlicesForSliceK, numStages, numStagesMma,
113
+ numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId,
114
+ outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB, sfLayoutC,
115
+ sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN, tileScheduler,
112
116
transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8,
113
117
useHoistTryWaitForCustomMmaSchedule, usePerTokenSfA, usePerTokenSfB,
114
118
useShuffledMatrixA, useTmaStore, useTwoTmaLoadWarps, useTwoMmaWarps,
@@ -121,11 +125,9 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
121
125
mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting),
122
126
mIsStaticBatch(isStaticBatch),
123
127
mNumBatches(numBatches),
124
- mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp),
125
- mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp),
126
- mNumRegsCastAWarps(numRegsCastAWarps),
127
128
mNumTokens(numTokens),
128
129
mRouteImpl(routeImpl),
130
+ mRouteSfsImpl(routeSfsImpl),
129
131
mUseTmaOobOpt(useTmaOobOpt) {}
130
132
131
133
// Batched M-dimensions of GEMM.
@@ -143,16 +145,12 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
143
145
bool mIsStaticBatch {true };
144
146
// Number of Gemm batches.
145
147
int mNumBatches ;
146
- // Number of registers per thread for non-epilogue warps
147
- int mNumRegsPerThreadNonEpilogueWarp {0 };
148
- // Number of registers per thread for epilogue warps
149
- int mNumRegsPerThreadEpilogueWarp {0 };
150
- // Number of registers for the cast A warps.
151
- int mNumRegsCastAWarps {0 };
152
148
// Total number of tokens.
153
149
int mNumTokens {32 };
154
150
// Whether to load the input tokens and do routing.
155
151
RouteImpl mRouteImpl {RouteImpl::NoRoute};
152
+ // Routing logic for scaling factors. If not specified, mRouteImpl is used.
153
+ std::optional<RouteImpl> mRouteSfsImpl {std::nullopt };
156
154
// Whether to use TMA out-of-bounds optimization to reduce wasted traffic. See details in
157
155
// BatchedGemm/KernelParamsDecl.h.
158
156
bool mUseTmaOobOpt {false };
@@ -235,6 +233,18 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
235
233
" E2m1 is not supported with DeepSeek FP8" );
236
234
}
237
235
236
+ if (options.mRouteSfsImpl .has_value () && options.mRouteSfsImpl .value () != options.mRouteImpl ) {
237
+ TLLM_CHECK_ERROR (
238
+ options.mRouteSfsImpl .value () == RouteImpl::Ldgsts && options.mRouteImpl == RouteImpl::Tma,
239
+ " RouteSfsImpl must be equal to RouteImpl, or Ldgsts, when RouteImpl is Tma" );
240
+ } else if (!options.mRouteSfsImpl .has_value ()) {
241
+ if (updateOptions) {
242
+ options.mRouteSfsImpl = options.mRouteImpl ;
243
+ } else {
244
+ TLLM_LOG_ERROR (" RouteSfsImpl must be specified" );
245
+ return false ;
246
+ }
247
+ }
238
248
if (batchM) {
239
249
if (options.mDtypeA == tg::Dtype::MxE2m1 && options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4) {
240
250
TLLM_CHECK_ERROR (doesRouteImplUseNoRoute (options.mRouteImpl ),
@@ -269,18 +279,20 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
269
279
}
270
280
}
271
281
272
- if (doesRouteImplUseTma (options.mRouteImpl )) {
282
+ if (doesRouteImplUseTma (options.mRouteSfsImpl . value () )) {
273
283
TLLM_CHECK_ERROR (!batchM, " UTMALDG.GATHER4 only supported for batch N." );
274
284
275
285
if (tg::mmaKindIsBlockFmt (options.mMmaKind )) {
276
286
auto dtypeRoute = batchM ? options.mDtypeA : options.mDtypeB ;
277
- TLLM_CHECK_ERROR (options.mTileK % tg::dtypeNumEltsPerSf (dtypeRoute) == 0 ,
278
- " tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA)." );
279
287
TLLM_CHECK_ERROR (options.mTileK % (tg::dtypeNumEltsPerSf (dtypeRoute) * 16 ) == 0 ,
280
288
" tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA)." );
281
289
}
282
290
}
283
291
292
+ if (options.mClusterDimX > 1 ) {
293
+ TLLM_CHECK_ERROR (!batchM, " 2CTA Gemm currently only supports batch N." );
294
+ }
295
+
284
296
if (!batchM || doesRouteImplUseNoRoute (options.mRouteImpl )) {
285
297
TLLM_CHECK_ERROR (options.mSfLayoutA == tg::SfLayout::R128c4,
286
298
" options.mSfLayoutA has to be tg::SfLayout::R128c4 when not being routed" );
@@ -301,6 +313,11 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
301
313
TLLM_CHECK_ERROR (options.mK % options.mTileK == 0 , " K must be a multiple of TileK" );
302
314
}
303
315
316
+ if (options.mClusterDimX > 1 && batchM && options.mRouteImpl != RouteImpl::NoRoute) {
317
+ TLLM_CHECK_ERROR (false ,
318
+ " 2CTA BatchedGemm does not support routing along M dimension. To support it, "
319
+ " change the input routing data layout to be padded to clusterDimX size." );
320
+ }
304
321
return isValid;
305
322
}
306
323
@@ -323,6 +340,7 @@ struct BatchedGemmConfig {
323
340
char const * mHash {nullptr };
324
341
#else
325
342
trtllm::gen::CudaRunner* mCudaRunner {nullptr };
343
+ int32_t mInstanceIdx {0 };
326
344
#endif
327
345
328
346
BatchedGemmOptions mOptions ;
@@ -343,13 +361,10 @@ inline std::string dumpOptions(BatchedGemmOptions const& options) {
343
361
ss << " mNumTokens=" << options.mNumTokens << " ," << std::endl;
344
362
ss << " mRouteImpl=batchedGemm::RouteImpl(" << static_cast <int32_t >(options.mRouteImpl ) << " ),"
345
363
<< std::endl;
364
+ ss << " mRouteSfsImpl={batchedGemm::RouteImpl("
365
+ << static_cast <int32_t >(options.mRouteSfsImpl .value ()) << " )}," << std::endl;
346
366
ss << " mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << " ," << std::endl;
347
367
ss << " mFusedAct=" << options.mFusedAct << " ," << std::endl;
348
- ss << " mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << " ,"
349
- << std::endl;
350
- ss << " mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << " ,"
351
- << std::endl;
352
- ss << " mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << " ," << std::endl;
353
368
ss << " mUseTmaOobOpt=" << options.mUseTmaOobOpt << std::endl;
354
369
return ss.str ();
355
370
}
0 commit comments