@@ -90,10 +90,10 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
      bool usePerTokenSfB, bool useTmaStore, bool useTwoTmaLoadWarps, bool useTwoMmaWarps,
      tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC,
      int32_t sfReshapeFactor, gemm::TileScheduler tileScheduler, gemmGatedAct::ActType actType,
-     std::vector<int> batchedM, std::vector<int> batchedN, BatchMode batchMode, int numBatches,
-     bool isStaticBatch, int numTokens, RouteImpl routeImpl, bool gridWaitForPrimaryRouting,
-     bool fusedAct, int numRegsPerThreadNonEpilogueWarp, int numRegsPerThreadEpilogueWarp,
-     int numRegsCastAWarps)
+     bool clampBeforeAct, std::vector<int> batchedM, std::vector<int> batchedN,
+     BatchMode batchMode, int numBatches, bool isStaticBatch, int numTokens, RouteImpl routeImpl,
+     bool gridWaitForPrimaryRouting, bool fusedAct, int numRegsPerThreadNonEpilogueWarp,
+     int numRegsPerThreadEpilogueWarp, int numRegsCastAWarps, bool useTmaOobOpt)
      : gemmGatedAct::GemmGatedActOptions(
            gemm::GemmOptions(
                allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, dtypeAcc,
@@ -109,48 +109,49 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
                useCustomMmaSchedule, useHoistTryWaitForCustomMmaSchedule, useDeepSeekFp8,
                usePerTokenSfA, usePerTokenSfB, useTmaStore, useTwoTmaLoadWarps, useTwoMmaWarps,
                sfLayoutA, sfLayoutB, sfLayoutC, sfReshapeFactor, tileScheduler),
-           actType),
+           actType, clampBeforeAct),
        mBatchedM(batchedM),
        mBatchedN(batchedN),
        mBatchMode(BatchMode(batchMode)),
-       mNumBatches(numBatches),
-       mIsStaticBatch(isStaticBatch),
-       mNumTokens(numTokens),
-       mRouteImpl(routeImpl),
-       mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting),
        mFusedAct(fusedAct),
+       mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting),
+       mIsStaticBatch(isStaticBatch),
+       mNumBatches(numBatches),
        mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp),
        mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp),
-       mNumRegsCastAWarps(numRegsCastAWarps) {}
+       mNumRegsCastAWarps(numRegsCastAWarps),
+       mNumTokens(numTokens),
+       mRouteImpl(routeImpl),
+       mUseTmaOobOpt(useTmaOobOpt) {}
 
   // Batched M-dimensions of GEMM.
   std::vector<int> mBatchedM;
   // Batched N-dimensions of GEMM.
   std::vector<int> mBatchedN;
   // Whether batching M or N.
   BatchMode mBatchMode{BatchMode::BatchM};
-  // Number of Gemm batches.
-  int mNumBatches;
-
-  // Whether the batch size is static (i.e. known at kernel launch time).
-  bool mIsStaticBatch{true};
-  // Total number of tokens.
-  int mNumTokens{32};
-  // Whether to load the input tokens and do routing.
-  RouteImpl mRouteImpl{RouteImpl::NoRoute};
+  // Whether to perform a fused gated activation.
+  bool mFusedAct{false};
   // Whether the loads that load from ptrRouteMap, ptrTotalNumPaddedTokens,
   // ptrCtaIdxXyToBatchIdx, etc. should wait on a grid dependency.
   bool mGridWaitForPrimaryRouting{true};
-
-  // Whether to perform a fused gated activation.
-  bool mFusedAct{false};
-
+  // Whether the batch size is static (i.e. known at kernel launch time).
+  bool mIsStaticBatch{true};
+  // Number of Gemm batches.
+  int mNumBatches;
   // Number of registers per thread for non-epilogue warps.
   int mNumRegsPerThreadNonEpilogueWarp{0};
   // Number of registers per thread for epilogue warps.
   int mNumRegsPerThreadEpilogueWarp{0};
   // Number of registers for the cast A warps.
   int mNumRegsCastAWarps{0};
+  // Total number of tokens.
+  int mNumTokens{32};
+  // Whether to load the input tokens and do routing.
+  RouteImpl mRouteImpl{RouteImpl::NoRoute};
+  // Whether to use the TMA out-of-bounds optimization to reduce wasted traffic. See details in
+  // BatchedGemm/KernelParamsDecl.h.
+  bool mUseTmaOobOpt{false};
 };
 ////////////////////////////////////////////////////////////////////////////////////////////////////
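For reference, a minimal sketch of what the batching members encode, based only on the member
comments above; the concrete values and the assumption that the batch count matches the vector
size are illustrative, not taken from this patch:

  #include <vector>

  // In BatchMode::BatchM each batch has its own M while N and K are shared;
  // BatchMode::BatchN flips the roles (per the "Whether batching M or N" comment).
  std::vector<int> batchedM = {128, 256, 64};  // per-batch M dimensions
  int numBatches = static_cast<int>(batchedM.size());  // assumed to match mBatchedM.size()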
@@ -159,6 +160,16 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
 bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackwell,
                                       bool updateOptions = true) {
   bool isValid = true;
+  if (options.mUseTmaOobOpt && !options.mUseTwoTmaLoadWarps) {
+    if (updateOptions) {
+      // Any routing (mRouteAct != NoRoute) requires mUseTwoTmaLoadWarps == true, so a single
+      // TMA load warp is not the target use case for the OOB optimization.
+      options.mUseTmaOobOpt = false;
+    } else {
+      TLLM_CHECK_ERROR(false, "TMA OOB optimization requires two TMA load warps.");
+      return false;
+    }
+  }
   if (options.mFusedAct) {
     // Ensure that we check the fused options as well.
     isValid = gemmGatedAct::checkAndUpdateGemmGatedActOptions(options, isBlackwell, updateOptions);
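The new guard shows the function's two modes. A minimal caller sketch of both paths, where
makeOptions() is a hypothetical helper that fills in an otherwise valid BatchedGemmOptions:

  BatchedGemmOptions options = makeOptions();  // hypothetical helper
  options.mUseTmaOobOpt = true;
  options.mUseTwoTmaLoadWarps = false;

  // Fix-up mode: the conflicting flag is silently cleared and validation continues.
  checkAndUpdateBatchedGemmOptions(options, /*isBlackwell=*/true, /*updateOptions=*/true);
  // options.mUseTmaOobOpt == false at this point.

  // Strict mode: the same combination trips TLLM_CHECK_ERROR and the function returns
  // false (or throws first, depending on how TLLM_CHECK_ERROR is configured).
  options.mUseTmaOobOpt = true;
  bool ok = checkAndUpdateBatchedGemmOptions(options, /*isBlackwell=*/true, /*updateOptions=*/false);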
@@ -302,6 +313,8 @@ struct BatchedGemmConfig {
   // defined. In this case, the cubins will be loaded from the provided data and function name.
   // Otherwise, the kernel will be loaded from the CudaRunner.
 #ifdef TLLM_GEN_EXPORT_INTERFACE
+  uint8_t const* mData{nullptr};
+  uint32_t const mSize{0};
   uint32_t const mSharedMemSize{0};
   char const* mFunctionName{nullptr};
   uint32_t const mNumThreadsPerCTA{0};
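mData and mSize hold the embedded cubin image referred to in the comment above. A minimal sketch
of how such an in-memory image can be loaded with the CUDA driver API, assuming a current CUDA
context; this is illustrative, not the project's actual loader:

  #include <cuda.h>

  // Load an embedded cubin image and resolve one of its kernels.
  CUfunction loadEmbeddedKernel(uint8_t const* data, char const* functionName) {
    CUmodule module = nullptr;
    // cuModuleLoadData reads the image directly from memory; the cubin encodes its
    // own size, which is why only the pointer is required here.
    if (cuModuleLoadData(&module, data) != CUDA_SUCCESS) {
      return nullptr;
    }
    CUfunction func = nullptr;
    cuModuleGetFunction(&func, module, functionName);
    return func;
  }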
@@ -334,7 +347,8 @@ inline std::string dumpOptions(BatchedGemmOptions const& options) {
      << std::endl;
   ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << ","
      << std::endl;
-  ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << std::endl;
+  ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl;
+  ss << "mUseTmaOobOpt=" << options.mUseTmaOobOpt << std::endl;
   return ss.str();
 }
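With the added line, the tail of a dumpOptions() string ends as follows (values illustrative;
bools print as 0/1 with default stream formatting):

  mNumRegsCastAWarps=0,
  mUseTmaOobOpt=0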